import datetime
import mimetypes
import os
import sys
from functools import reduce
import warnings
from contextlib import ExitStack
import gradio as gr
import gradio.utils
from gradio.components.image_editor import Brush
import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensions
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript
from modules.shared import opts, cmd_opts
import modules.generation_parameters_copypaste as parameters_copypaste
import modules.hypernetworks.ui as hypernetworks_ui
import modules.textual_inversion.ui as textual_inversion_ui
import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.generation_parameters_copypaste import image_from_url_text
# Alias for ui_settings.create_setting_component under this module's namespace.
create_setting_component = ui_settings.create_setting_component
# Show ("default") or silence ("ignore") UserWarning and Gradio deprecation
# warnings depending on the corresponding options.
warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gradio_extensions.GradioDeprecationWarning)
# This is a fix for Windows users. Without it, JavaScript files will be served
# with a text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
# Likewise, add an explicit content-type for certain missing image types.
mimetypes.add_type('image/webp', '.webp')
if not cmd_opts.share and not cmd_opts.listen:
    # Fix gradio phoning home: replace its version check with a no-op and make
    # the local IP lookup return the loopback address.
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
if cmd_opts.ngrok is not None:
    # Imported only when an ngrok value was supplied on the command line.
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    # The port falls back to 7860 when no port option was given.
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_options
        )
def gr_show(visible=True):
    """Return an update payload dict carrying the requested visibility.

    Args:
        visible: Value stored under the ``"visible"`` key.

    Returns:
        A dict with ``"visible"`` set to ``visible`` and ``"__type__"`` set
        to ``"update"``.
    """
    payload = {"visible": visible}
    payload["__type__"] = "update"
    return payload
# Path of a sample input image; replaced by None when no file exists at that
# (relative) path.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
# Symbols ending in '\ufe0f' carry an explicit variation selector; do not
# normalise or retype these literals.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️
paste_symbol = '\u2199\ufe0f' # ↙
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0' # 📐
# Alias for ui_common.plaintext_to_html under this module's namespace.
plaintext_to_html = ui_common.plaintext_to_html
def send_gradio_gallery_to_image(x):
    """Convert the first entry of a gallery value with ``image_from_url_text``.

    Args:
        x: Sequence of gallery entries, or ``None``.

    Returns:
        The result of ``image_from_url_text`` for the first entry, or ``None``
        when ``x`` is ``None`` or empty.
    """
    # Treat a missing value like an empty gallery; calling len() on None
    # would raise TypeError instead of yielding "no image".
    if x is None or len(x) == 0:
        return None

    return image_from_url_text(x[0])
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Describe the source and target resolution of the hires pass as text.

    Args:
        enable: When falsy, no calculation is done and ``""`` is returned.
        width, height: First-pass dimensions.
        hr_scale: Upscale factor passed to the processing object.
        hr_resize_x, hr_resize_y: Explicit target dimensions passed to the
            processing object.

    Returns:
        ``""`` when disabled, otherwise ``"from WxH to WxH"`` built from the
        attributes of a ``StableDiffusionProcessingTxt2Img`` instance after
        ``calculate_target_resolution()`` has run.
    """
    if enable:
        preview = processing.StableDiffusionProcessingTxt2Img(
            width=width,
            height=height,
            enable_hr=True,
            hr_scale=hr_scale,
            hr_resize_x=hr_resize_x,
            hr_resize_y=hr_resize_y,
        )
        preview.calculate_target_resolution()

        source_size = f"{preview.width}x{preview.height}"
        # A falsy explicit resize value falls back to the computed upscale target.
        target_width = preview.hr_resize_x or preview.hr_upscale_to_x
        target_height = preview.hr_resize_y or preview.hr_upscale_to_y
        return f"from {source_size} to {target_width}x{target_height}"

    return ""
def resize_from_to_html(width, height, scale_by):
    """Return a short text describing a resize of ``width`` x ``height`` by ``scale_by``.

    Both scaled dimensions are truncated with ``int``. If either of them ends
    up as zero, the text ``"no image selected"`` is returned instead.
    """
    target_width, target_height = int(width * scale_by), int(height * scale_by)

    if target_width and target_height:
        return f"resize: from {width}x{height} to {target_width}x{target_height}"

    return "no image selected"
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run ``interrogation_function`` on one image or on a directory of images.

    Args:
        interrogation_function: Callable taking an image and returning the
            text to show or write.
        mode: Integer selecting the input. 0, 1, 3 and 4 use
            ``ii_singles[mode]`` directly; 2 uses ``ii_singles[mode]["image"]``;
            5 processes every file listed for ``ii_input_dir``.
        ii_input_dir: Directory with the images for mode 5.
        ii_output_dir: Directory receiving the ``.txt`` files for mode 5; an
            empty string means "write next to the inputs".
        *ii_singles: Per-mode single-image inputs, indexed by ``mode``.

    Returns:
        ``[text, None]`` for the single-image modes, ``[gr.update(), None]``
        after a batch run, and ``None`` for any other mode value.

    Raises:
        AssertionError: Mode 5 was requested while ``--hide-ui-dir-config``
            is active.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        # Raised explicitly (same exception type as the former ``assert``) so
        # the directory-access guard is not stripped when Python runs with -O.
        if shared.cmd_opts.hide_ui_dir_config:
            raise AssertionError("Launched with --hide-ui-dir-config, batch img2img disabled")

        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir != "":
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            filename = os.path.basename(image)
            left, _ = os.path.splitext(filename)
            output_path = os.path.join(ii_output_dir, f"{left}.txt")

            # Interrogate before opening the output so a failing interrogation
            # does not leave an empty text file behind; both the image and the
            # text file are closed deterministically.
            with Image.open(image) as img:
                result = interrogation_function(img)
            with open(output_path, 'a', encoding='utf-8') as output_file:
                print(result, file=output_file)

        return [gr.update(), None]
def interrogate(image):
    """Interrogate an RGB conversion of ``image`` with ``shared.interrogator``.

    Returns:
        The prompt produced by the interrogator, or ``gr.update()`` when the
        interrogator returned ``None``.
    """
    run_interrogator = shared.interrogator.interrogate
    caption = run_interrogator(image.convert("RGB"))

    if caption is None:
        return gr.update()
    return caption
def interrogate_deepbooru(image):
    """Tag ``image`` with ``deepbooru.model``.

    Returns:
        The value returned by ``deepbooru.model.tag``, or ``gr.update()`` when
        that value is ``None``.
    """
    tags = deepbooru.model.tag(image)

    if tags is not None:
        return tags
    return gr.update()
def connect_clear_prompt(button):
    """Attach a click handler to ``button`` that runs the ``clear_prompt`` script.

    The click event is registered with ``_js="clear_prompt"`` only: no Python
    callback, no inputs and no outputs.
    """
    click_options = {
        "fn": None,
        "_js": "clear_prompt",
        "inputs": [],
        "outputs": [],
    }
    button.click(**click_options)
def update_token_counter(text, steps, *, is_positive=True):
    """Return the ``"<token_count>/<max_length>"`` text for a prompt.

    Args:
        text: Prompt text.
        steps: Step count passed on to the prompt-schedule parser.
        is_positive: When true, the text is first split with
            ``prompt_parser.get_multicond_prompt_list``; otherwise it is used
            as a single prompt.

    Returns:
        A string built from the ``model_hijack.get_prompt_lengths`` result
        with the largest token count among all scheduled prompt texts.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)

        if not is_positive:
            prompt_flat_list = [text]
        else:
            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])

        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # Concatenate the per-prompt schedules into one list of (step, text) pairs.
    flat_prompts = reduce(lambda merged, schedule: merged + schedule, prompt_schedules)

    prompts = []
    for _step, prompt_text in flat_prompts:
        prompts.append(prompt_text)

    lengths = [model_hijack.get_prompt_lengths(prompt) for prompt in prompts]
    token_count, max_length = max(lengths, key=lambda pair: pair[0])
    return f"{token_count}/{max_length}"
def update_negative_prompt_token_counter(text, steps):
    """Return the token counter text for ``text`` with ``is_positive=False``.

    Thin wrapper around ``update_token_counter``.
    """
    return update_token_counter(text=text, steps=steps, is_positive=False)
def setup_progressbar(*args, **kwargs):
    """Accept any arguments and do nothing.

    NOTE(review): presumably kept so existing callers keep working - confirm.
    """
    return None
def apply_setting(key, value):
    """Store ``value`` as option ``key``, save the options, and return the stored value.

    Args:
        key: Name of the option in ``opts.data_labels``.
        value: New value; ``None`` means "leave the option alone".

    Returns:
        ``gr.update()`` (no change) when ``value`` is ``None``, settings are
        frozen, checkpoint auto-swap is disabled, or no checkpoint matches
        ``value``; ``None`` when the option's component args mark it as not
        visible; otherwise ``getattr(opts, key)`` after the option was written
        and saved.
    """
    if value is None or shared.cmd_opts.freeze_settings:
        return gr.update()

    if key == "sd_model_checkpoint":
        # dont allow model to be swapped when model hash exists in prompt
        if opts.disable_weights_auto_swap:
            return gr.update()

        # Replace the requested name with the title of the closest known checkpoint.
        ckpt_info = sd_models.get_closet_checkpoint_match(value)
        if ckpt_info is None:
            return gr.update()
        value = ckpt_info.title

    option_info = opts.data_labels[key]

    component_args = option_info.component_args
    if component_args and isinstance(component_args, dict) and component_args.get('visible') is False:
        return

    # Coerce the value to the type of the option's default, unless that default is None.
    default_type = type(option_info.default)
    previous_value = opts.data.get(key, None)
    if default_type == type(None):
        opts.data[key] = value
    else:
        opts.data[key] = default_type(value)

    if previous_value != value and option_info.onchange is not None:
        option_info.onchange()

    opts.save(shared.config_filename)
    return getattr(opts, key)
def create_output_panel(tabname, outdir, toprow=None):
    """Delegate to ``ui_common.create_output_panel`` with the same arguments.

    Returns:
        Whatever ``ui_common.create_output_panel`` returns.
    """
    panel = ui_common.create_output_panel(tabname, outdir, toprow)
    return panel
def create_sampler_and_steps_selection(choices, tabname):
    """Create the sampler selector and the sampling-steps slider for a tab.

    With ``opts.samplers_in_dropdown`` set, the sampler is a dropdown created
    before the slider inside a ``FormRow``; otherwise the slider is created
    first, followed by a radio group, inside a ``FormGroup``.

    Args:
        choices: Sampler names; the first one is the initial value.
        tabname: Prefix used for the element ids.

    Returns:
        ``(steps, sampler_name)`` components.
    """
    def make_steps_slider():
        return gr.Slider(
            minimum=1,
            maximum=150,
            step=1,
            elem_id=f"{tabname}_steps",
            label="Sampling steps",
            value=20,
        )

    if not opts.samplers_in_dropdown:
        with FormGroup(elem_id=f"sampler_selection_{tabname}"):
            steps = make_steps_slider()
            sampler_name = gr.Radio(
                label='Sampling method',
                elem_id=f"{tabname}_sampling",
                choices=choices,
                value=choices[0],
            )
        return steps, sampler_name

    with FormRow(elem_id=f"sampler_selection_{tabname}"):
        sampler_name = gr.Dropdown(
            label='Sampling method',
            elem_id=f"{tabname}_sampling",
            choices=choices,
            value=choices[0],
        )
        steps = make_steps_slider()
    return steps, sampler_name
def ordered_ui_categories():
    """Yield UI category names in the order they should be laid out.

    Names listed in ``shared.opts.ui_reorder_list`` get the odd rank
    ``2 * index + 1`` from their position in that list; every other category
    gets the even rank ``2 * index`` from its position in
    ``shared_items.ui_reorder_categories()``. Categories are yielded by
    ascending rank, ties keeping their default order.
    """
    user_rank = {}
    for position, name in enumerate(shared.opts.ui_reorder_list):
        user_rank[name.strip()] = 2 * position + 1

    def rank(indexed_category):
        default_position, name = indexed_category
        return user_rank.get(name, 2 * default_position)

    indexed_categories = list(enumerate(shared_items.ui_reorder_categories()))
    indexed_categories.sort(key=rank)

    for _position, category in indexed_categories:
        yield category
def create_override_settings_dropdown(tabname, row):
    """Create the hidden multi-select "Override settings" dropdown for a tab.

    The dropdown starts with no choices and is invisible; its change event
    makes it visible exactly when its value is truthy.

    Args:
        tabname: Prefix used for the element id.
        row: Not used by this function.

    Returns:
        The created dropdown component.
    """
    def sync_visibility(selected):
        return gr.Dropdown.update(visible=bool(selected))

    dropdown = gr.Dropdown(
        [],
        label="Override settings",
        visible=False,
        elem_id=f"{tabname}_override_settings",
        multiselect=True,
    )
    dropdown.change(fn=sync_visibility, inputs=[dropdown], outputs=[dropdown])

    return dropdown
def create_ui():
import modules.img2img
import modules.txt2img
reload_javascript()
parameters_copypaste.reset()
scripts.scripts_current = scripts.scripts_txt2img
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
dummy_component = gr.Textbox(visible=False)
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
with ExitStack() as stack:
if shared.opts.txt2img_settings_accordion:
stack.enter_context(gr.Accordion("Open for Settings", open=False))
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
scripts.scripts_txt2img.prepare_ui()
for category in ordered_ui_categories():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "cfg":
with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass
elif category == "accordions":
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
with enable_hr.extra():
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution")
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80):
with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80):
with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
scripts.scripts_txt2img.setup_ui_for_section(category)
elif category == "batch":
if not opts.dimensions_and_batch_together:
with FormRow(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "override_settings":
with FormRow(elem_id="txt2img_override_settings_row") as row:
override_settings = create_override_settings_dropdown('txt2img', row)
elif category == "scripts":
with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = scripts.scripts_txt2img.setup_ui()
if category not in {"accordions"}:
scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
for component in hr_resolution_preview_inputs:
event = component.release if isinstance(component, gr.Slider) else component.change
event(
fn=calc_resolution_hires,
inputs=hr_resolution_preview_inputs,
outputs=[hr_final_resolution],
show_progress=False,
)
event(
None,
_js="onCalcResolutionHires",
inputs=hr_resolution_preview_inputs,
outputs=[],
show_progress=False,
)
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit",
inputs=[
dummy_component,
toprow.prompt,
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
batch_count,
batch_size,
cfg_scale,
height,
width,
enable_hr,
denoising_strength,
hr_scale,
hr_upscaler,
hr_second_pass_steps,
hr_resize_x,
hr_resize_y,
hr_checkpoint_name,
hr_sampler_name,
hr_prompt,
hr_negative_prompt,
override_settings,
] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
],
show_progress=False,
)
toprow.prompt.submit(**txt2img_args)
toprow.submit.click(**txt2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
outputs=[
txt2img_gallery,
generation_info,
html_info,
html_log,
],
show_progress=False,
)
txt2img_paste_fields = [
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(cfg_scale, "CFG scale"),
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_name, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
txt2img_preview_params = [
toprow.prompt,
toprow.negative_prompt,
steps,
sampler_name,
cfg_scale,
scripts.scripts_txt2img.script('Seed').seed,
width,
height,
]
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
extra_tabs.__exit__()
scripts.scripts_current = scripts.scripts_img2img
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
extra_tabs.__enter__()
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
with ExitStack() as stack:
if shared.opts.img2img_settings_accordion:
stack.enter_context(gr.Accordion("Open for Settings", open=False))
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
copy_image_buttons = []
copy_image_destinations = {}
def add_copy_image_controls(tab_name, elem):
with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
if name == tab_name:
gr.Button(title, interactive=False)
copy_image_destinations[name] = elem
continue
button = gr.Button(title)
copy_image_buttons.append((button, name, elem))
scripts.scripts_img2img.prepare_ui()
for category in ordered_ui_categories():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "image":
with gr.Tabs(elem_id="mode_img2img"):
img2img_selected_tab = gr.State(0)
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
init_img = gr.ImageEditor(label="Image for img2img", elem_id="img2img_image", show_label=False, interactive=True, type="pil", image_mode="RGBA", height=opts.img2img_editor_height)
add_copy_image_controls('img2img', init_img)
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
sketch = gr.ImageEditor(label="Image for img2img", elem_id="img2img_sketch", show_label=False, interactive=True, type="pil", image_mode="RGB", height=opts.img2img_editor_height, brush=Brush(default_color=opts.img2img_sketch_default_brush_color))
add_copy_image_controls('sketch', sketch)
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
init_img_with_mask = gr.ImageEditor(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", brush=Brush(colors=["opts.img2img_inpaint_mask_brush_color"], color_mode="fixed"), interactive=True, type="pil", image_mode="RGBA", height=opts.img2img_editor_height)
add_copy_image_controls('inpaint', init_img_with_mask)
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
inpaint_color_sketch = gr.ImageEditor(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", brush=Brush(default_color=opts.img2img_inpaint_sketch_default_brush_color), interactive=True, type="pil", image_mode="RGB", height=opts.img2img_editor_height)
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
gr.HTML(
"
Process images in a directory on the same machine where the server is running." +
"
Use an empty output directory to save pictures normally instead of writing to the output directory." +
f"
Add inpaint batch mask directory to enable inpaint batch processing."
f"{hidden}
See wiki for detailed explanation.
") with gr.Row(variant="compact", equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding", id="create_embedding"): new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. 
Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. ex:'0, 0.01, 0'") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") def get_textual_inversion_template_names(): return sorted(textual_inversion.textual_inversion_templates) with gr.Tab(label="Train", id="train"): gr.HTML(value="Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") with FormRow(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks)) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name") with FormRow(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") with FormRow(): clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) with FormRow(): batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") with FormRow(): template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) 
create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") use_weight = gr.Checkbox(label="Use PNG alpha channel as loss weight", value=False, elem_id="use_weight") save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") with gr.Row(): train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") params = script_callbacks.UiTrainTabParams(txt2img_preview_params) script_callbacks.ui_train_tabs_callback(params) with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4) gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") create_embedding.click( fn=textual_inversion_ui.create_embedding, inputs=[ new_embedding_name, initialization_text, nvpt, overwrite_old_embedding, ], outputs=[ train_embedding_name, ti_output, ti_outcome, ] ) create_hypernetwork.click( fn=hypernetworks_ui.create_hypernetwork, inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, overwrite_old_hypernetwork, new_hypernetwork_layer_structure, new_hypernetwork_activation_func, new_hypernetwork_initialization_option, new_hypernetwork_add_layer_norm, new_hypernetwork_use_dropout, new_hypernetwork_dropout_structure ], outputs=[ train_hypernetwork_name, ti_output, ti_outcome, ] ) train_embedding.click( fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), 
_js="start_training_textual_inversion", inputs=[ dummy_component, train_embedding_name, embedding_learn_rate, batch_size, gradient_step, dataset_directory, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, *txt2img_preview_params, ], outputs=[ ti_output, ti_outcome, ] ) train_hypernetwork.click( fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, train_hypernetwork_name, hypernetwork_learn_rate, batch_size, gradient_step, dataset_directory, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_file, preview_from_txt2img, *txt2img_preview_params, ], outputs=[ ti_output, ti_outcome, ] ) interrupt_training.click( fn=lambda: shared.state.interrupt(), inputs=[], outputs=[], ) loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) interfaces = [ (txt2img_interface, "txt2img", "txt2img"), (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "train"), ] interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings.interface, "Settings", "settings")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] shared.tab_names = [] for _interface, label, _ifid in interfaces: shared.tab_names.append(label) with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable 
Diffusion") as demo: settings.add_quicksettings() parameters_copypaste.connect_paste_params_buttons() with gr.Tabs(elem_id="tabs") as tabs: tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)} sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999)) for interface, label, ifid in sorted_interfaces: if label in shared.opts.hidden_tabs: continue with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"): interface.render() if ifid not in ["extensions", "settings"]: loadsave.add_block(interface, ifid) loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs) loadsave.setup_ui() if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio: gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) footer = shared.html("footer.html") footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API") gr.HTML(footer, elem_id="footer") settings.add_functionality(demo) update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) loadsave.dump_defaults() demo.ui_loadsave = loadsave return demo def versions_html(): import torch import launch python_version = ".".join([str(x) for x in sys.version_info[0:3]]) commit = launch.commit_hash() tag = launch.git_tag() if shared.xformers_available: import xformers xformers_version = xformers.__version__ else: xformers_version = "N/A" return f""" version: {tag} • python: {python_version} • torch: {getattr(torch, 
'__long_version__',torch.__version__)} • xformers: {xformers_version} • gradio: {gr.__version__} • checkpoint: N/A """ def setup_ui_api(app): from pydantic import BaseModel, Field class QuicksettingsHint(BaseModel): name: str = Field(title="Name of the quicksettings field") label: str = Field(title="Label of the quicksettings field") def quicksettings_hint(): return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()] app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"]) def download_sysinfo(attachment=False): from fastapi.responses import PlainTextResponse text = sysinfo.get() filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json" return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])