Validate parameters before
Browse files- gradio_demo.py +89 -28
gradio_demo.py
CHANGED
|
@@ -30,7 +30,7 @@ parser.add_argument("--no_llava", action='store_true', default=True)#False
|
|
| 30 |
parser.add_argument("--use_image_slider", action='store_true', default=False)
|
| 31 |
parser.add_argument("--log_history", action='store_true', default=False)
|
| 32 |
parser.add_argument("--loading_half_params", action='store_true', default=True)#False
|
| 33 |
-
parser.add_argument("--use_tile_vae", action='store_true', default=False
|
| 34 |
parser.add_argument("--encoder_tile_size", type=int, default=512)
|
| 35 |
parser.add_argument("--decoder_tile_size", type=int, default=64)
|
| 36 |
parser.add_argument("--load_8bit_llava", action='store_true', default=False)
|
|
@@ -67,15 +67,16 @@ if torch.cuda.device_count() > 0:
|
|
| 67 |
else:
|
| 68 |
llava_agent = None
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def stage1_process(input_image, gamma_correction):
|
| 72 |
print('Start stage1_process')
|
| 73 |
if torch.cuda.device_count() == 0:
|
| 74 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 75 |
return None
|
| 76 |
-
if input_image is None:
|
| 77 |
-
gr.Warning('Please provide an image to restore.')
|
| 78 |
-
return None
|
| 79 |
torch.cuda.set_device(SUPIR_device)
|
| 80 |
LQ = HWC3(input_image)
|
| 81 |
LQ = fix_resize(LQ, 512)
|
|
@@ -92,15 +93,12 @@ def stage1_process(input_image, gamma_correction):
|
|
| 92 |
print('End stage1_process')
|
| 93 |
return LQ
|
| 94 |
|
| 95 |
-
@spaces.GPU(duration=
|
| 96 |
def llave_process(input_image, temperature, top_p, qs=None):
|
| 97 |
print('Start llave_process')
|
| 98 |
if torch.cuda.device_count() == 0:
|
| 99 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 100 |
return 'Set this space to GPU config to make it work.'
|
| 101 |
-
if input_image is None:
|
| 102 |
-
gr.Warning('Please provide an image to restore.')
|
| 103 |
-
return 'Please provide an image to restore.'
|
| 104 |
torch.cuda.set_device(LLaVA_device)
|
| 105 |
if use_llava:
|
| 106 |
LQ = HWC3(input_image)
|
|
@@ -111,7 +109,7 @@ def llave_process(input_image, temperature, top_p, qs=None):
|
|
| 111 |
print('End llave_process')
|
| 112 |
return captions[0]
|
| 113 |
|
| 114 |
-
@spaces.GPU(duration=
|
| 115 |
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
| 116 |
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
| 117 |
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
|
|
@@ -119,9 +117,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
|
|
| 119 |
if torch.cuda.device_count() == 0:
|
| 120 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 121 |
return None, None, None, None
|
| 122 |
-
if input_image is None:
|
| 123 |
-
gr.Warning('Please provide an image to restore.')
|
| 124 |
-
return None, None, None, None
|
| 125 |
torch.cuda.set_device(SUPIR_device)
|
| 126 |
event_id = str(time.time_ns())
|
| 127 |
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
|
|
@@ -279,7 +274,7 @@ with gr.Blocks(title='SUPIR') as interface:
|
|
| 279 |
qs = gr.Textbox(label="Question", info="Describe the image and its style in a very detailed manner", placeholder="The image is a realistic photography, not an art painting.")
|
| 280 |
|
| 281 |
with gr.Accordion("Restoring options", open=False):
|
| 282 |
-
num_samples = gr.Slider(label="Num Samples", info="Number of generated results; I discourage to increase because the process is limited to
|
| 283 |
, value=1, step=1)
|
| 284 |
upscale = gr.Slider(label="Upscale", info="The resolution increase factor", minimum=1, maximum=8, value=1, step=1)
|
| 285 |
edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
|
|
@@ -319,10 +314,10 @@ with gr.Blocks(title='SUPIR') as interface:
|
|
| 319 |
ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
|
| 320 |
interactive=True)
|
| 321 |
with gr.Column():
|
| 322 |
-
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
|
| 323 |
interactive=True)
|
| 324 |
with gr.Column():
|
| 325 |
-
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
|
| 326 |
interactive=True)
|
| 327 |
|
| 328 |
with gr.Column():
|
|
@@ -352,17 +347,83 @@ with gr.Blocks(title='SUPIR') as interface:
|
|
| 352 |
with gr.Row():
|
| 353 |
gr.Markdown(claim_md)
|
| 354 |
event_id = gr.Textbox(label="Event ID", value="", visible=False)
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
interface.queue(10).launch()
|
|
|
|
| 30 |
parser.add_argument("--use_image_slider", action='store_true', default=False)
|
| 31 |
parser.add_argument("--log_history", action='store_true', default=False)
|
| 32 |
parser.add_argument("--loading_half_params", action='store_true', default=True)#False
|
| 33 |
+
parser.add_argument("--use_tile_vae", action='store_true', default=True)#False
|
| 34 |
parser.add_argument("--encoder_tile_size", type=int, default=512)
|
| 35 |
parser.add_argument("--decoder_tile_size", type=int, default=64)
|
| 36 |
parser.add_argument("--load_8bit_llava", action='store_true', default=False)
|
|
|
|
| 67 |
else:
|
| 68 |
llava_agent = None
|
| 69 |
|
| 70 |
+
def check(input_image):
|
| 71 |
+
if input_image is None:
|
| 72 |
+
raise gr.Error("Please provide an image to restore.")
|
| 73 |
+
|
| 74 |
+
@spaces.GPU(duration=180)
|
| 75 |
def stage1_process(input_image, gamma_correction):
|
| 76 |
print('Start stage1_process')
|
| 77 |
if torch.cuda.device_count() == 0:
|
| 78 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 79 |
return None
|
|
|
|
|
|
|
|
|
|
| 80 |
torch.cuda.set_device(SUPIR_device)
|
| 81 |
LQ = HWC3(input_image)
|
| 82 |
LQ = fix_resize(LQ, 512)
|
|
|
|
| 93 |
print('End stage1_process')
|
| 94 |
return LQ
|
| 95 |
|
| 96 |
+
@spaces.GPU(duration=180)
|
| 97 |
def llave_process(input_image, temperature, top_p, qs=None):
|
| 98 |
print('Start llave_process')
|
| 99 |
if torch.cuda.device_count() == 0:
|
| 100 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 101 |
return 'Set this space to GPU config to make it work.'
|
|
|
|
|
|
|
|
|
|
| 102 |
torch.cuda.set_device(LLaVA_device)
|
| 103 |
if use_llava:
|
| 104 |
LQ = HWC3(input_image)
|
|
|
|
| 109 |
print('End llave_process')
|
| 110 |
return captions[0]
|
| 111 |
|
| 112 |
+
@spaces.GPU(duration=180)
|
| 113 |
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
|
| 114 |
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
|
| 115 |
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
|
|
|
|
| 117 |
if torch.cuda.device_count() == 0:
|
| 118 |
gr.Warning('Set this space to GPU config to make it work.')
|
| 119 |
return None, None, None, None
|
|
|
|
|
|
|
|
|
|
| 120 |
torch.cuda.set_device(SUPIR_device)
|
| 121 |
event_id = str(time.time_ns())
|
| 122 |
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
|
|
|
|
| 274 |
qs = gr.Textbox(label="Question", info="Describe the image and its style in a very detailed manner", placeholder="The image is a realistic photography, not an art painting.")
|
| 275 |
|
| 276 |
with gr.Accordion("Restoring options", open=False):
|
| 277 |
+
num_samples = gr.Slider(label="Num Samples", info="Number of generated results; I discourage to increase because the process is limited to 3 min", minimum=1, maximum=4 if not args.use_image_slider else 1
|
| 278 |
, value=1, step=1)
|
| 279 |
upscale = gr.Slider(label="Upscale", info="The resolution increase factor", minimum=1, maximum=8, value=1, step=1)
|
| 280 |
edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
|
|
|
|
| 314 |
ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
|
| 315 |
interactive=True)
|
| 316 |
with gr.Column():
|
| 317 |
+
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", info="Wavelet=For JPEG artifacts", value="Wavelet",
|
| 318 |
interactive=True)
|
| 319 |
with gr.Column():
|
| 320 |
+
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", info="Q=Quality, F=Fidelity", value="v0-Q",
|
| 321 |
interactive=True)
|
| 322 |
|
| 323 |
with gr.Column():
|
|
|
|
| 347 |
with gr.Row():
|
| 348 |
gr.Markdown(claim_md)
|
| 349 |
event_id = gr.Textbox(label="Event ID", value="", visible=False)
|
| 350 |
+
|
| 351 |
+
denoise_button.click(fn = check, inputs = [
|
| 352 |
+
input_image
|
| 353 |
+
], outputs = [], queue = False, show_progress = False).success(fn = stage1_process, inputs = [
|
| 354 |
+
input_image,
|
| 355 |
+
gamma_correction
|
| 356 |
+
], outputs=[
|
| 357 |
+
denoise_image
|
| 358 |
+
])
|
| 359 |
+
|
| 360 |
+
llave_button.click(fn = check, inputs = [
|
| 361 |
+
denoise_image
|
| 362 |
+
], outputs = [], queue = False, show_progress = False).success(fn = llave_process, inputs = [
|
| 363 |
+
denoise_image,
|
| 364 |
+
temperature,
|
| 365 |
+
top_p,
|
| 366 |
+
qs
|
| 367 |
+
], outputs = [
|
| 368 |
+
prompt
|
| 369 |
+
])
|
| 370 |
+
|
| 371 |
+
diffusion_button.click(fn = check, inputs = [
|
| 372 |
+
input_image
|
| 373 |
+
], outputs = [], queue = False, show_progress = False).success(fn=stage2_process, inputs = [
|
| 374 |
+
input_image,
|
| 375 |
+
prompt,
|
| 376 |
+
a_prompt,
|
| 377 |
+
n_prompt,
|
| 378 |
+
num_samples,
|
| 379 |
+
upscale,
|
| 380 |
+
edm_steps,
|
| 381 |
+
s_stage1,
|
| 382 |
+
s_stage2,
|
| 383 |
+
s_cfg,
|
| 384 |
+
seed,
|
| 385 |
+
s_churn,
|
| 386 |
+
s_noise,
|
| 387 |
+
color_fix_type,
|
| 388 |
+
diff_dtype,
|
| 389 |
+
ae_dtype,
|
| 390 |
+
gamma_correction,
|
| 391 |
+
linear_CFG,
|
| 392 |
+
linear_s_stage2,
|
| 393 |
+
spt_linear_CFG,
|
| 394 |
+
spt_linear_s_stage2,
|
| 395 |
+
model_select
|
| 396 |
+
], outputs = [
|
| 397 |
+
result_gallery,
|
| 398 |
+
event_id,
|
| 399 |
+
fb_score,
|
| 400 |
+
fb_text
|
| 401 |
+
])
|
| 402 |
+
|
| 403 |
+
restart_button.click(fn = load_and_reset, inputs = [
|
| 404 |
+
param_setting
|
| 405 |
+
], outputs = [
|
| 406 |
+
edm_steps,
|
| 407 |
+
s_cfg,
|
| 408 |
+
s_stage2,
|
| 409 |
+
s_stage1,
|
| 410 |
+
s_churn,
|
| 411 |
+
s_noise,
|
| 412 |
+
a_prompt,
|
| 413 |
+
n_prompt,
|
| 414 |
+
color_fix_type,
|
| 415 |
+
linear_CFG,
|
| 416 |
+
linear_s_stage2,
|
| 417 |
+
spt_linear_CFG,
|
| 418 |
+
spt_linear_s_stage2
|
| 419 |
+
])
|
| 420 |
+
|
| 421 |
+
submit_button.click(fn = submit_feedback, inputs = [
|
| 422 |
+
event_id,
|
| 423 |
+
fb_score,
|
| 424 |
+
fb_text
|
| 425 |
+
], outputs = [
|
| 426 |
+
fb_text
|
| 427 |
+
])
|
| 428 |
|
| 429 |
interface.queue(10).launch()
|