Have upscale button use the same seed as hires fix.

This commit is contained in:
AUTOMATIC1111 2024-01-04 19:47:00 +03:00
parent f903b4dda3
commit 15ec54dd96
5 changed files with 53 additions and 20 deletions

View File

@ -91,6 +91,9 @@ class Script:
setup_for_ui_only = False setup_for_ui_only = False
"""If true, the script setup will only be run in Gradio UI, not in API""" """If true, the script setup will only be run in Gradio UI, not in API"""
controls = None
"""A list of controls retured by the ui()."""
def title(self): def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu.""" """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@ -624,6 +627,7 @@ class ScriptRunner:
import modules.api.models as api_models import modules.api.models as api_models
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
script.controls = controls
if controls is None: if controls is None:
return return
@ -918,6 +922,23 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True) errors.report(f"Error running setup: {script.filename}", exc_info=True)
def set_named_arg(self, args, script_type, arg_elem_id, value):
script = next((x for x in self.scripts if type(x).__name__ == script_type), None)
if script is None:
return
for i, control in enumerate(script.controls):
if arg_elem_id in control.elem_id:
index = script.args_from + i
if isinstance(args, list):
args[index] = value
return args
elif isinstance(args, tuple):
return args[:index] + (value,) + args[index+1:]
else:
return None
scripts_txt2img: ScriptRunner = None scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None scripts_img2img: ScriptRunner = None

View File

@ -1,3 +1,4 @@
import json
from contextlib import closing from contextlib import closing
import modules.scripts import modules.scripts
@ -9,12 +10,19 @@ from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale' assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
geninfo = json.loads(generation_info)
all_seeds = geninfo["all_seeds"]
image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
image = infotext_utils.image_from_url_text(image_info) image = infotext_utils.image_from_url_text(image_info)
gallery_index_from_end = len(gallery) - gallery_index
image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
return txt2img(id_task, request, *args, firstpass_image=image) return txt2img(id_task, request, *args, firstpass_image=image)
@ -22,6 +30,10 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
if firstpass_image is not None: if firstpass_image is not None:
seed = getattr(firstpass_image, 'seed', None)
if seed:
args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed)
enable_hr = True enable_hr = True
batch_size = 1 batch_size = 1
n_iter = 1 n_iter = 1

View File

@ -405,8 +405,8 @@ def create_ui():
txt2img_outputs = [ txt2img_outputs = [
output_panel.gallery, output_panel.gallery,
output_panel.generation_info,
output_panel.infotext, output_panel.infotext,
output_panel.html_info,
output_panel.html_log, output_panel.html_log,
] ]
@ -424,7 +424,7 @@ def create_ui():
output_panel.button_upscale.click( output_panel.button_upscale.click(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
_js="submit_txt2img_upscale", _js="submit_txt2img_upscale",
inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
outputs=txt2img_outputs, outputs=txt2img_outputs,
show_progress=False, show_progress=False,
) )
@ -437,8 +437,8 @@ def create_ui():
inputs=[dummy_component], inputs=[dummy_component],
outputs=[ outputs=[
output_panel.gallery, output_panel.gallery,
output_panel.generation_info,
output_panel.infotext, output_panel.infotext,
output_panel.html_info,
output_panel.html_log, output_panel.html_log,
], ],
show_progress=False, show_progress=False,
@ -766,8 +766,8 @@ def create_ui():
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
output_panel.gallery, output_panel.gallery,
output_panel.generation_info,
output_panel.infotext, output_panel.infotext,
output_panel.html_info,
output_panel.html_log, output_panel.html_log,
], ],
show_progress=False, show_progress=False,
@ -807,8 +807,8 @@ def create_ui():
inputs=[dummy_component], inputs=[dummy_component],
outputs=[ outputs=[
output_panel.gallery, output_panel.gallery,
output_panel.generation_info,
output_panel.infotext, output_panel.infotext,
output_panel.html_info,
output_panel.html_log, output_panel.html_log,
], ],
show_progress=False, show_progress=False,

View File

@ -108,8 +108,8 @@ def save_files(js_data, images, do_make_zip, index):
@dataclasses.dataclass @dataclasses.dataclass
class OutputPanel: class OutputPanel:
gallery = None gallery = None
generation_info = None
infotext = None infotext = None
html_info = None
html_log = None html_log = None
button_upscale = None button_upscale = None
@ -175,17 +175,17 @@ Requested path was: {f}
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group(): with gr.Group():
res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img': if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click( generation_info_button.click(
fn=update_generation_info, fn=update_generation_info,
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }", _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[res.infotext, res.html_info, res.html_info], inputs=[res.generation_info, res.infotext, res.infotext],
outputs=[res.html_info, res.html_info], outputs=[res.infotext, res.infotext],
show_progress=False, show_progress=False,
) )
@ -193,10 +193,10 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files), fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[ inputs=[
res.infotext, res.generation_info,
res.gallery, res.gallery,
res.html_info, res.infotext,
res.html_info, res.infotext,
], ],
outputs=[ outputs=[
download_files, download_files,
@ -209,10 +209,10 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files), fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[ inputs=[
res.infotext, res.generation_info,
res.gallery, res.gallery,
res.html_info, res.infotext,
res.html_info, res.infotext,
], ],
outputs=[ outputs=[
download_files, download_files,
@ -221,8 +221,8 @@ Requested path was: {f}
) )
else: else:
res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = [] paste_field_names = []

View File

@ -49,7 +49,7 @@ def create_ui():
], ],
outputs=[ outputs=[
output_panel.gallery, output_panel.gallery,
output_panel.infotext, output_panel.generation_info,
output_panel.html_log, output_panel.html_log,
], ],
show_progress=False, show_progress=False,