2022-09-03 17:08:45 +08:00
import html
import json
2022-09-10 16:10:00 +08:00
import math
2022-09-03 17:08:45 +08:00
import mimetypes
import os
2022-10-22 19:07:00 +08:00
import platform
2022-09-06 04:08:06 +08:00
import random
2022-09-03 17:08:45 +08:00
import sys
2022-10-15 18:11:28 +08:00
import tempfile
2022-09-03 17:08:45 +08:00
import time
import traceback
2022-10-15 01:04:47 +08:00
from functools import partial , reduce
2023-01-19 04:04:24 +08:00
import warnings
2022-09-03 17:08:45 +08:00
2022-10-22 19:07:00 +08:00
import gradio as gr
import gradio . routes
import gradio . utils
2022-09-07 00:33:51 +08:00
import numpy as np
2022-09-28 22:05:23 +08:00
from PIL import Image , PngImagePlugin
2022-11-28 14:00:10 +08:00
from modules . call_queue import wrap_gradio_gpu_call , wrap_queued_call , wrap_gradio_call
2022-09-03 17:08:45 +08:00
2023-04-30 03:16:54 +08:00
from modules import sd_hijack , sd_models , localization , script_callbacks , ui_extensions , deepbooru , sd_vae , extra_networks , postprocessing , ui_components , ui_common , ui_postprocessing , progress
2023-03-20 21:09:36 +08:00
from modules . ui_components import FormRow , FormColumn , FormGroup , ToolButton , FormHTML
2023-01-26 00:15:42 +08:00
from modules . paths import script_path , data_path
2022-10-17 03:06:21 +08:00
2022-10-17 01:08:23 +08:00
from modules . shared import opts , cmd_opts , restricted_opts
2022-10-14 16:56:41 +08:00
2022-10-22 19:07:00 +08:00
import modules . codeformer_model
2022-10-27 13:36:11 +08:00
import modules . generation_parameters_copypaste as parameters_copypaste
2022-10-22 19:07:00 +08:00
import modules . gfpgan_model
import modules . hypernetworks . ui
2022-09-03 22:21:15 +08:00
import modules . scripts
2022-10-22 19:07:00 +08:00
import modules . shared as shared
2022-09-10 04:16:02 +08:00
import modules . styles
2022-10-22 19:07:00 +08:00
import modules . textual_inversion . ui
2022-10-06 04:16:27 +08:00
from modules import prompt_parser
2022-10-05 00:19:50 +08:00
from modules . images import save_image
2022-10-22 19:07:00 +08:00
from modules . sd_hijack import model_hijack
from modules . sd_samplers import samplers , samplers_for_img2img
2023-01-10 04:35:40 +08:00
from modules . textual_inversion import textual_inversion
2022-10-11 20:51:22 +08:00
import modules . hypernetworks . ui
2022-10-27 13:36:11 +08:00
from modules . generation_parameters_copypaste import image_from_url_text
2023-01-23 19:42:49 +08:00
import modules . extras
2022-09-03 17:08:45 +08:00
2023-01-19 04:04:24 +08:00
# UserWarnings are shown only when the user enabled the show_warnings option; otherwise they are silenced.
if opts.show_warnings:
    warnings.filterwarnings("default", category=UserWarning)
else:
    warnings.filterwarnings("ignore", category=UserWarning)

# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

if not (cmd_opts.share or cmd_opts.listen):
    # fix gradio phoning home: stub out gradio's version check and local-IP lookup for purely local runs
    gradio.utils.version_check = lambda: None
    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'

if cmd_opts.ngrok is not None:
    import modules.ngrok as ngrok
    print('ngrok authtoken detected, trying to connect...')
    # tunnel the configured port, or 7860 when no --port was given
    ngrok.connect(
        cmd_opts.ngrok,
        cmd_opts.port if cmd_opts.port is not None else 7860,
        cmd_opts.ngrok_region
    )
2022-10-11 17:40:27 +08:00
2022-09-03 17:08:45 +08:00
def gr_show(visible=True):
    """Return a raw Gradio update dict that changes only a component's visibility."""
    return dict(visible=visible, __type__="update")
# Path of the bundled img2img sample image; None when the file is not present on disk.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None
2022-09-17 03:20:56 +08:00
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
# "\ufe0f" is VARIATION SELECTOR-16 (requests emoji presentation of the preceding character).

# seed controls
random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
reuse_symbol = '\u267b\ufe0f'  # ♻️

# prompt toolbar
paste_symbol = '\u2199\ufe0f'  # ↙
refresh_symbol = '\U0001f504'  # 🔄
save_style_symbol = '\U0001f4be'  # 💾
apply_style_symbol = '\U0001f4cb'  # 📋
clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️
extra_networks_symbol = '\U0001f3b4'  # 🎴
switch_values_symbol = '\u21c5'  # ⇅
restore_progress_symbol = '\U0001f300'  # 🌀
2023-01-28 02:34:41 +08:00
2022-09-17 03:20:56 +08:00
2022-09-03 17:08:45 +08:00
def plaintext_to_html(text):
    """Convert plain text to HTML by delegating to ui_common.plaintext_to_html."""
    html_text = ui_common.plaintext_to_html(text)
    return html_text
2023-01-22 20:38:39 +08:00
2022-09-03 17:08:45 +08:00
def send_gradio_gallery_to_image(x):
    """Return the first gallery entry converted to an image.

    Args:
        x: the gallery value (a sequence of entries understood by image_from_url_text),
            or None when the gallery holds nothing.

    Returns:
        The image for ``x[0]``, or None when the gallery is empty or missing.
    """
    # `not x` also covers None, which `len(x) == 0` would raise TypeError on
    if not x:
        return None

    return image_from_url_text(x[0])
2022-09-04 18:52:01 +08:00
def visit(x, func, path=""):
    """Recursively walk a Gradio component tree, calling ``func(key, component)``.

    Leaves (objects without ``children``) are reported as ``path + "/" + str(label)``
    when they have a label. Containers themselves are not reported, except gr.Tabs
    with an elem_id, which are reported as ``f"{path}/Tabs@{elem_id}"``. The same
    ``path`` is passed unchanged to every child.
    """
    if not hasattr(x, 'children'):
        if x.label is not None:
            func(path + "/" + str(x.label), x)
        return

    if isinstance(x, gr.Tabs) and x.elem_id is not None:
        # Tabs element can't have a label, have to use elem_id instead
        func(f"{path}/Tabs@{x.elem_id}", x)

    for child in x.children:
        visit(child, func, path)
2022-09-03 17:08:45 +08:00
2022-09-11 22:35:12 +08:00
def add_style(name: str, prompt: str, negative_prompt: str):
    """Save the given prompt pair as a named style and refresh the style dropdowns.

    Args:
        name: style name; None (or an empty string) means no style is saved.
        prompt: positive prompt stored in the style.
        negative_prompt: negative prompt stored in the style.

    Returns:
        A list of two updates, one per style dropdown output. NOTE(review): presumably the
        txt2img and img2img "Styles" dropdowns — confirm against the event wiring.
    """
    if not name:
        # Return as many values as the success path below. The previous code returned 4
        # here while returning 2 on success, which did not match the outputs.
        # An empty name is also rejected so no blank-named style is created.
        return [gr_show() for _ in range(2)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
2022-09-14 22:56:21 +08:00
2023-01-07 14:56:37 +08:00
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    """Describe the hires-fix resize as an HTML snippet for the txt2img UI.

    Returns "" when hires fix is disabled. Otherwise builds a throwaway
    StableDiffusionProcessingTxt2Img with these settings and calls init() on it. It then
    reports that object's width/height (as left after init()) and its final size. The
    final size is the explicit resize target, or hr_upscale_to_x/_y when no explicit
    resize was given.
    """
    from modules import processing, devices

    if not enable:
        return ""

    preview = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)

    with devices.autocast():
        preview.init([""], [0], [0])

    # a zero/empty explicit resize falls back to the computed upscale target
    final_width = preview.hr_resize_x or preview.hr_upscale_to_x
    final_height = preview.hr_resize_y or preview.hr_upscale_to_y

    return f"resize: from <span class='resolution'>{preview.width}x{preview.height}</span> to <span class='resolution'>{final_width}x{final_height}</span>"
2023-01-07 14:56:37 +08:00
2022-09-14 22:56:21 +08:00
2023-03-29 03:23:40 +08:00
def resize_from_to_html(width, height, scale_by):
    """Describe scaling ``width`` x ``height`` by ``scale_by`` as an HTML snippet.

    The target size is truncated to ints. If either target dimension comes out as 0
    (e.g. a 0x0 source), the text "no image selected" is returned instead.
    """
    new_width, new_height = int(width * scale_by), int(height * scale_by)

    if new_width and new_height:
        return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{new_width}x{new_height}</span>"

    return "no image selected"
2023-01-14 19:56:39 +08:00
def apply_styles(prompt, prompt_neg, styles):
    """Merge the selected styles into both prompts and clear the style selection.

    Returns:
        [update for the prompt box, update for the negative prompt box,
         update resetting the styles dropdown to an empty selection].
    """
    styled_prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
    styled_negative = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)

    prompt_update = gr.Textbox.update(value=styled_prompt)
    negative_update = gr.Textbox.update(value=styled_negative)
    clear_selection = gr.Dropdown.update(value=[])

    return [prompt_update, negative_update, clear_selection]
2022-09-10 04:16:02 +08:00
2023-01-19 01:16:52 +08:00
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    """Run an interrogation function on the image(s) selected by ``mode``.

    Args:
        interrogation_function: callable taking an image and returning the caption text
            (e.g. interrogate / interrogate_deepbooru in this file).
        mode: which input to use. NOTE(review): presumably the index of the selected img2img
            input tab — confirm. 0, 1, 3 and 4 use ``ii_singles[mode]`` directly; 2 uses
            ``ii_singles[2]["image"]``; 5 captions every file in ``ii_input_dir``.
        ii_input_dir: input directory for mode 5.
        ii_output_dir: where mode 5 appends ``<image name>.txt`` captions; empty/None means
            "next to the input images".
        *ii_singles: the per-mode single-image inputs, indexed by ``mode``.

    Returns:
        ``[caption, None]`` for single-image modes, ``[gr.update(), None]`` after a batch run.
        Any other mode returns None.

    Raises:
        AssertionError: mode 5 when launched with --hide-ui-dir-config.
    """
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
    elif mode == 2:
        return [interrogation_function(ii_singles[mode]["image"]), None]
    elif mode == 5:
        assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
        images = shared.listfiles(ii_input_dir)
        print(f"Will process {len(images)} images.")
        if ii_output_dir:
            os.makedirs(ii_output_dir, exist_ok=True)
        else:
            ii_output_dir = ii_input_dir

        for image in images:
            try:
                img = Image.open(image)
            except OSError:
                # NOTE(review): shared.listfiles() may include non-image files (e.g. .txt captions
                # written by an earlier run into the same directory); skip instead of aborting the batch.
                print(f"Skipping {image}: not a readable image", file=sys.stderr)
                continue

            # close the image file once it has been interrogated
            with img:
                caption = interrogation_function(img)

            left, _ = os.path.splitext(os.path.basename(image))
            # `with` closes the caption file; the original leaked one handle per image
            with open(os.path.join(ii_output_dir, left + ".txt"), 'a', encoding="utf8") as caption_file:
                print(caption, file=caption_file)

        return [gr.update(), None]
2023-01-19 01:16:52 +08:00
2022-09-11 23:48:36 +08:00
def interrogate(image):
    """Caption ``image`` with shared.interrogator after converting it to RGB.

    Returns the caption, or a no-op gr.update() when the interrogator returns None,
    so the target component keeps its current value.
    """
    caption = shared.interrogator.interrogate(image.convert("RGB"))
    if caption is None:
        return gr.update()
    return caption
2022-09-11 23:48:36 +08:00
2022-09-14 22:56:21 +08:00
2022-10-06 02:50:10 +08:00
def interrogate_deepbooru(image):
    """Tag ``image`` with deepbooru.model.

    Returns the tags, or a no-op gr.update() when the tagger returns None.
    """
    tags = deepbooru.model.tag(image)
    if tags is None:
        return gr.update()
    return tags
2022-10-06 02:50:10 +08:00
2023-01-01 21:51:12 +08:00
def create_seed_inputs(target_interface):
    """Build the seed controls for one tab and wire the events local to them.

    Creates, in this order:
      * a row with the seed input, the random/reuse seed buttons and the 'Extra' checkbox;
      * a hidden row with the variation seed, its random/reuse buttons and the
        variation strength slider;
      * a hidden row with the "Resize seed from width/height" sliders.
    Toggling 'Extra' shows or hides the two hidden rows. The random buttons reset their
    input to -1. The reuse buttons are NOT wired here; the caller connects them.

    Args:
        target_interface: prefix for every elem_id (e.g. 'txt2img').

    Returns:
        (seed, reuse_seed, subseed, reuse_subseed, subseed_strength,
         seed_resize_from_h, seed_resize_from_w, seed_checkbox). Note that height comes
        before width.
    """
    with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
        # cmd_opts.use_textbox_seed selects a free-text seed box instead of a numeric input
        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
        seed.style(container=False)
        random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
        reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')

        seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)

    # Components to show/hide based on the 'Extra' checkbox
    seed_extras = []

    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
        seed_extras.append(seed_extra_row_1)
        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
        subseed.style(container=False)
        random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
        reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')

    with FormRow(visible=False) as seed_extra_row_2:
        seed_extras.append(seed_extra_row_2)
        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')

    # -1 is also the default value of both seed inputs (NOTE(review): presumably "random seed" — confirm)
    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])

    def change_visibility(show):
        # one visibility update per hidden row, keyed by component
        return {comp: gr_show(show) for comp in seed_extras}

    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)

    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
2022-09-17 03:20:56 +08:00
2022-10-22 03:24:14 +08:00
2022-11-02 03:33:55 +08:00
def connect_clear_prompt(button):
    """Wire ``button`` so that clicking it runs the browser-side ``clear_prompt`` JS function.

    No Python callback is attached (fn=None) and no components are passed in or out.
    """
    click_options = {"fn": None, "_js": "clear_prompt", "inputs": [], "outputs": []}
    button.click(**click_options)
2022-10-20 14:08:24 +08:00
2022-09-19 14:02:10 +08:00
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
    """Connect a 'reuse (sub)seed' button so that clicking it copies the last used (sub)seed
    from the generation info into the seed field.

    If copying the subseed and the recorded subseed strength was 0, i.e. no variation seed
    was used, the normal seed value is copied instead.

    Args:
        seed: input that receives the copied value.
        reuse_seed: button whose click triggers the copy.
        generation_info: textbox holding the JSON generation info of the last run.
        dummy_component: helper component. On click, the JS wrapper puts the selected gallery
            index (``selected_gallery_index()``) in its input slot; afterwards it is sent a
            hide update.
        is_subseed: copy the variation seed instead of the main seed.
    """

    def copy_seed(gen_info_string: str, index):
        """Return [seed of the image at gallery ``index`` (-1 if unknown), hide update for the dummy]."""
        res = -1

        try:
            gen_info = json.loads(gen_info_string)
            # shift the gallery index so that it lines up with the per-image seed lists
            index -= gen_info.get('index_of_first_image', 0)

            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
                seeds = gen_info.get('all_subseeds', [-1])
            else:
                seeds = gen_info.get('all_seeds', [-1])

            # an out-of-range index falls back to the first image; an empty list keeps -1
            # instead of raising an uncaught IndexError
            if seeds:
                res = seeds[index if 0 <= index < len(seeds) else 0]
        except json.decoder.JSONDecodeError:
            if gen_info_string != '':
                print("Error parsing JSON generation info:", file=sys.stderr)
                print(gen_info_string, file=sys.stderr)

        return [res, gr_show(False)]

    reuse_seed.click(
        fn=copy_seed,
        _js="(x, y) => [x, selected_gallery_index()]",
        show_progress=False,
        inputs=[generation_info, dummy_component],
        outputs=[seed, dummy_component]
    )
2022-10-04 19:35:12 +08:00
2022-09-30 03:47:06 +08:00
def update_token_counter(text, steps):
    """Render the "tokens used / token limit" counter for a prompt as an HTML span.

    Extra-network syntax is stripped and the prompt is expanded into its per-step schedule.
    The counter shows the largest token count among the scheduled prompt variants. If
    parsing fails (common while the user is still typing), the raw text is counted as a
    single prompt instead.
    """
    try:
        text, _ = extra_networks.parse_prompt(text)
        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
    except Exception:
        # a parsing error can happen here during typing, and we don't want to bother the user with
        # messages related to it in console
        prompt_schedules = [[[steps, text]]]

    # flatten one level: a list of schedules -> a list of (step, prompt) entries
    scheduled_entries = [entry for schedule in prompt_schedules for entry in schedule]
    prompt_variants = [prompt_text for _step, prompt_text in scheduled_entries]

    lengths = [model_hijack.get_prompt_lengths(variant) for variant in prompt_variants]
    token_count, max_length = max(lengths, key=lambda pair: pair[0])

    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
2022-09-19 21:42:56 +08:00
2022-10-04 19:35:12 +08:00
2022-09-14 22:56:21 +08:00
def create_toprow(is_img2img):
    """Build the prompt/action row shown at the top of the txt2img or img2img tab.

    Lays out the prompt and negative prompt boxes, the interrogate buttons (img2img only),
    the Interrupt/Skip/Generate box, the tool buttons, the token counter helpers and the
    styles dropdown. Skip, Interrupt, the clear-prompt button and the styles refresh button
    are wired here. Everything else that is returned is wired by the caller.

    Args:
        is_img2img: True for the img2img tab. Selects the "img2img"/"txt2img" elem_id prefix
            and whether the interrogate buttons are created.

    Returns:
        A 15-tuple: (prompt, prompt_styles, negative_prompt, submit, button_interrogate,
        button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button,
        token_counter, token_button, negative_token_counter, negative_token_button,
        restore_progress_button). button_interrogate and button_deepbooru are None when
        is_img2img is False.
    """
    id_part = "img2img" if is_img2img else "txt2img"

    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")

            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")

        # Interrogate buttons exist only on img2img; txt2img returns None for both.
        button_interrogate = None
        button_deepbooru = None
        if is_img2img:
            with gr.Column(scale=1, elem_classes="interrogate-col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")

        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')

                # Skip/Interrupt just call into shared.state; they take no inputs and update nothing.
                skip.click(
                    fn=lambda: shared.state.skip(),
                    inputs=[],
                    outputs=[],
                )

                interrupt.click(
                    fn=lambda: shared.state.interrupt(),
                    inputs=[],
                    outputs=[],
                )

            with gr.Row(elem_id=f"{id_part}_tools"):
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
                # created hidden; NOTE(review): presumably revealed elsewhere (e.g. from JS) when there is progress to restore — confirm
                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)

                # Token counters start with a "0/75" placeholder. The invisible *_token_button components
                # are NOTE(review): presumably clicked from JS to trigger a recount — confirm.
                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

                # The JS function confirm_clear_prompt produces the new values; the Python fn only passes them through.
                clear_prompt_button.click(
                    fn=lambda *x: x,
                    _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                    outputs=[prompt, negative_prompt],
                )

            with gr.Row(elem_id=f"{id_part}_styles_row"):
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                # the refresh button reloads styles from storage and resets the dropdown's choices
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
2022-09-14 22:56:21 +08:00
2023-01-10 17:29:45 +08:00
def setup_progressbar(*args, **kwargs):
    """No-op: accepts and ignores any arguments and returns None.

    NOTE(review): presumably retained so that existing callers (e.g. extensions) keep working — confirm.
    """
    return None
2022-09-14 22:56:21 +08:00
2022-10-15 00:30:28 +08:00
def apply_setting(key, value):
    """Store ``value`` as the setting ``key``, persist the config, and return the new value.

    A no-op gr.update() is returned (and nothing is stored) when:
      * value is None;
      * settings are frozen (cmd_opts.freeze_settings);
      * key is "sd_model_checkpoint" and either opts.disable_weights_auto_swap is set or no
        checkpoint matches ``value``;
      * the setting's component is declared invisible (component_args {"visible": False}).

    The value is converted to the type of the setting's default. The setting's onchange
    callback runs only when the stored value actually changes.

    Returns:
        The setting's value read back from ``opts``, or gr.update() in the cases above.
    """
    if value is None:
        return gr.update()

    if shared.cmd_opts.freeze_settings:
        return gr.update()

    if key == "sd_model_checkpoint":
        # dont allow model to be swapped when model hash exists in prompt
        if opts.disable_weights_auto_swap:
            return gr.update()

        ckpt_info = sd_models.get_closet_checkpoint_match(value)
        if ckpt_info is None:
            return gr.update()
        value = ckpt_info.title

    info = opts.data_labels[key]
    comp_args = info.component_args
    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
        # previously a bare `return` (None), which would overwrite the output component
        return gr.update()

    valtype = type(info.default)
    oldval = opts.data.get(key, None)
    newval = valtype(value) if valtype is not type(None) else value
    opts.data[key] = newval

    # compare against the converted value, so that e.g. "20" vs 20 does not count as a change
    if oldval != newval and info.onchange is not None:
        info.onchange()
    opts.save(shared.config_filename)

    return getattr(opts, key)
2022-10-15 00:30:28 +08:00
2023-01-30 05:25:30 +08:00
2022-10-21 21:10:51 +08:00
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    """Create a refresh ToolButton that reloads data and updates ``refresh_component``.

    Args:
        refresh_component: the component to update when the button is clicked.
        refresh_method: zero-argument callable run first on every click (e.g. a reload).
        refreshed_args: dict of component attributes to apply, or a zero-argument callable
            returning such a dict (None is treated as "nothing to update").
        elem_id: elem_id of the created button.

    Returns:
        The created button.
    """
    def refresh():
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        # normalise before iterating; previously the `or {}` guard came after args.items()
        # had already raised on None
        args = args or {}

        # mirror the new values onto the Python-side component object as well
        for k, v in args.items():
            setattr(refresh_component, k, v)

        return gr.update(**args)

    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(
        fn=refresh,
        inputs=[],
        outputs=[refresh_component]
    )
    return refresh_button
2022-10-16 12:42:52 +08:00
2022-10-29 13:28:48 +08:00
def create_output_panel(tabname, outdir):
    """Build a tab's output panel by delegating to ui_common.create_output_panel.

    Returns whatever ui_common.create_output_panel returns (callers in this file
    unpack it as gallery, generation_info, html_info, html_log).
    """
    panel = ui_common.create_output_panel(tabname, outdir)
    return panel
2022-10-10 09:26:52 +08:00
2022-10-08 13:09:29 +08:00
2023-01-01 06:19:10 +08:00
def create_sampler_and_steps_selection(choices, tabname):
    """Create the "Sampling method" selector and the "Sampling steps" slider for a tab.

    With opts.samplers_in_dropdown, the sampler is a Dropdown placed before the slider
    inside a FormRow. Otherwise it is a Radio placed after the slider inside a FormGroup.
    Either way the selector reports the chosen sampler's index (type="index") and
    defaults to choices[0].

    Returns:
        (steps, sampler_index) components.
    """
    container_id = f"sampler_selection_{tabname}"
    steps_args = dict(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)

    if not opts.samplers_in_dropdown:
        with FormGroup(elem_id=container_id):
            steps = gr.Slider(**steps_args)
            sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
        return steps, sampler_index

    with FormRow(elem_id=container_id):
        sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index")
        steps = gr.Slider(**steps_args)
    return steps, sampler_index
2022-10-29 13:28:48 +08:00
2022-10-16 12:42:52 +08:00
2023-01-03 15:39:21 +08:00
def ordered_ui_categories():
    """Yield shared.ui_reorder_categories in the order requested by opts.ui_reorder.

    Categories named in the comma-separated ui_reorder setting get odd sort keys
    (2*i + 1, by position in the setting). All other categories get even keys
    (2*i, by default position). Named and unnamed categories therefore never tie,
    and unnamed ones keep their relative default order.
    """
    preferred_rank = {}
    for position, name in enumerate(shared.opts.ui_reorder.split(",")):
        preferred_rank[name.strip()] = position * 2 + 1

    def sort_key(indexed_category):
        default_position, category = indexed_category
        return preferred_rank.get(category, default_position * 2)

    yield from (category for _, category in sorted(enumerate(shared.ui_reorder_categories), key=sort_key))
2023-01-19 23:58:08 +08:00
def get_value_for_setting(key):
    """Build a Gradio update carrying the current value of setting ``key``.

    The setting's component_args (a dict, or a callable returning one) are merged into
    the update, except for 'precision', which is dropped.
    """
    value = getattr(opts, key)
    info = opts.data_labels[key]

    if callable(info.component_args):
        component_args = info.component_args()
    else:
        component_args = info.component_args or {}

    update_args = {name: arg for name, arg in component_args.items() if name != 'precision'}
    return gr.update(value=value, **update_args)
2023-01-30 05:25:30 +08:00
def create_override_settings_dropdown(tabname, row):
    """Create the multiselect "Override settings" dropdown for a tab.

    The dropdown starts hidden. Its own change event shows it whenever it holds at
    least one value and hides it when it is empty.

    Args:
        tabname: elem_id prefix; the dropdown's elem_id is ``f"{tabname}_override_settings"``.
        row: the enclosing row. Currently unused; kept for call compatibility.

    Returns:
        The dropdown component.
    """
    dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)

    dropdown.change(
        # bool(x) also treats a None value as empty; the original len(x) > 0 raised TypeError on None
        fn=lambda x: gr.Dropdown.update(visible=bool(x)),
        inputs=[dropdown],
        outputs=[dropdown],
    )

    return dropdown
2022-11-28 14:00:10 +08:00
def create_ui ( ) :
2022-10-21 21:10:51 +08:00
import modules . img2img
import modules . txt2img
2022-10-16 12:42:52 +08:00
2022-11-02 12:26:31 +08:00
reload_javascript ( )
2022-10-31 22:36:45 +08:00
parameters_copypaste . reset ( )
2022-10-16 12:42:52 +08:00
2022-11-20 00:10:17 +08:00
modules . scripts . scripts_current = modules . scripts . scripts_txt2img
modules . scripts . scripts_txt2img . initialize_scripts ( is_img2img = False )
2022-09-03 17:08:45 +08:00
with gr . Blocks ( analytics_enabled = False ) as txt2img_interface :
2023-04-30 03:16:54 +08:00
txt2img_prompt , txt2img_prompt_styles , txt2img_negative_prompt , submit , _ , _ , txt2img_prompt_style_apply , txt2img_save_style , txt2img_paste , extra_networks_button , token_counter , token_button , negative_token_counter , negative_token_button , restore_progress_button = create_toprow ( is_img2img = False )
2022-10-20 10:23:57 +08:00
2022-09-19 14:02:10 +08:00
dummy_component = gr . Label ( visible = False )
2023-01-19 04:04:24 +08:00
txt_prompt_img = gr . File ( label = " " , elem_id = " txt2img_prompt_image " , file_count = " single " , type = " binary " , visible = False )
2022-09-03 17:08:45 +08:00
2023-01-21 13:36:07 +08:00
with FormRow ( variant = ' compact ' , elem_id = " txt2img_extra_networks " , visible = False ) as extra_networks :
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks . create_ui ( extra_networks , extra_networks_button , ' txt2img ' )
2022-09-03 17:08:45 +08:00
with gr . Row ( ) . style ( equal_height = False ) :
2023-01-14 18:38:10 +08:00
with gr . Column ( variant = ' compact ' , elem_id = " txt2img_settings " ) :
2023-01-03 15:39:21 +08:00
for category in ordered_ui_categories ( ) :
if category == " sampler " :
steps , sampler_index = create_sampler_and_steps_selection ( samplers , " txt2img " )
2022-09-19 21:42:56 +08:00
2023-01-03 15:39:21 +08:00
elif category == " dimensions " :
with FormRow ( ) :
with gr . Column ( elem_id = " txt2img_column_size " , scale = 4 ) :
width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " txt2img_width " )
height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " txt2img_height " )
2023-03-21 13:49:08 +08:00
with gr . Column ( elem_id = " txt2img_dimensions_row " , scale = 1 , elem_classes = " dimensions-tools " ) :
2023-04-18 05:48:28 +08:00
res_switch_btn = ToolButton ( value = switch_values_symbol , elem_id = " txt2img_res_switch_btn " , label = " Switch dims " )
2023-03-20 21:09:36 +08:00
2023-01-03 15:39:21 +08:00
if opts . dimensions_and_batch_together :
with gr . Column ( elem_id = " txt2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " txt2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " txt2img_batch_size " )
elif category == " cfg " :
cfg_scale = gr . Slider ( minimum = 1.0 , maximum = 30.0 , step = 0.5 , label = ' CFG Scale ' , value = 7.0 , elem_id = " txt2img_cfg_scale " )
elif category == " seed " :
seed , reuse_seed , subseed , reuse_subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox = create_seed_inputs ( ' txt2img ' )
elif category == " checkboxes " :
2023-03-20 21:09:36 +08:00
with FormRow ( elem_classes = " checkboxes-row " , variant = " compact " ) :
2023-01-03 15:39:21 +08:00
restore_faces = gr . Checkbox ( label = ' Restore faces ' , value = False , visible = len ( shared . face_restorers ) > 1 , elem_id = " txt2img_restore_faces " )
tiling = gr . Checkbox ( label = ' Tiling ' , value = False , elem_id = " txt2img_tiling " )
enable_hr = gr . Checkbox ( label = ' Hires. fix ' , value = False , elem_id = " txt2img_enable_hr " )
2023-01-07 14:56:37 +08:00
hr_final_resolution = FormHTML ( value = " " , elem_id = " txtimg_hr_finalres " , label = " Upscaled resolution " , interactive = False )
2023-01-03 15:39:21 +08:00
elif category == " hires_fix " :
2023-01-05 03:04:40 +08:00
with FormGroup ( visible = False , elem_id = " txt2img_hires_fix " ) as hr_options :
2023-01-19 05:44:51 +08:00
with FormRow ( elem_id = " txt2img_hires_fix_row1 " , variant = " compact " ) :
2023-01-05 03:04:40 +08:00
hr_upscaler = gr . Dropdown ( label = " Upscaler " , elem_id = " txt2img_hr_upscaler " , choices = [ * shared . latent_upscale_modes , * [ x . name for x in shared . sd_upscalers ] ] , value = shared . latent_upscale_default_mode )
hr_second_pass_steps = gr . Slider ( minimum = 0 , maximum = 150 , step = 1 , label = ' Hires steps ' , value = 0 , elem_id = " txt2img_hires_steps " )
denoising_strength = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = ' Denoising strength ' , value = 0.7 , elem_id = " txt2img_denoising_strength " )
2023-01-19 05:44:51 +08:00
with FormRow ( elem_id = " txt2img_hires_fix_row2 " , variant = " compact " ) :
2023-01-05 03:04:40 +08:00
hr_scale = gr . Slider ( minimum = 1.0 , maximum = 4.0 , step = 0.05 , label = " Upscale by " , value = 2.0 , elem_id = " txt2img_hr_scale " )
hr_resize_x = gr . Slider ( minimum = 0 , maximum = 2048 , step = 8 , label = " Resize width to " , value = 0 , elem_id = " txt2img_hr_resize_x " )
hr_resize_y = gr . Slider ( minimum = 0 , maximum = 2048 , step = 8 , label = " Resize height to " , value = 0 , elem_id = " txt2img_hr_resize_y " )
2023-01-03 15:39:21 +08:00
elif category == " batch " :
if not opts . dimensions_and_batch_together :
with FormRow ( elem_id = " txt2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " txt2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " txt2img_batch_size " )
2023-01-30 05:25:30 +08:00
elif category == " override_settings " :
with FormRow ( elem_id = " txt2img_override_settings_row " ) as row :
override_settings = create_override_settings_dropdown ( ' txt2img ' , row )
2023-01-03 15:39:21 +08:00
elif category == " scripts " :
with FormGroup ( elem_id = " txt2img_script_container " ) :
custom_inputs = modules . scripts . scripts_txt2img . setup_ui ( )
2022-09-03 17:08:45 +08:00
2023-01-07 14:56:37 +08:00
hr_resolution_preview_inputs = [ enable_hr , width , height , hr_scale , hr_resize_x , hr_resize_y ]
for input in hr_resolution_preview_inputs :
2023-01-09 19:57:47 +08:00
input . change (
fn = calc_resolution_hires ,
inputs = hr_resolution_preview_inputs ,
outputs = [ hr_final_resolution ] ,
show_progress = False ,
)
input . change (
None ,
_js = " onCalcResolutionHires " ,
inputs = hr_resolution_preview_inputs ,
outputs = [ ] ,
show_progress = False ,
)
2023-01-07 13:53:53 +08:00
2023-01-01 04:40:55 +08:00
txt2img_gallery , generation_info , html_info , html_log = create_output_panel ( " txt2img " , opts . outdir_txt2img_samples )
2022-09-03 17:08:45 +08:00
2022-09-19 14:02:10 +08:00
connect_reuse_seed ( seed , reuse_seed , generation_info , dummy_component , is_subseed = False )
connect_reuse_seed ( subseed , reuse_subseed , generation_info , dummy_component , is_subseed = True )
2022-09-17 03:20:56 +08:00
2022-09-03 17:08:45 +08:00
txt2img_args = dict (
2023-01-01 04:40:55 +08:00
fn = wrap_gradio_gpu_call ( modules . txt2img . txt2img , extra_outputs = [ None , ' ' , ' ' ] ) ,
2022-09-06 07:09:01 +08:00
_js = " submit " ,
2022-09-03 17:08:45 +08:00
inputs = [
2023-01-15 23:50:56 +08:00
dummy_component ,
2022-09-10 04:16:02 +08:00
txt2img_prompt ,
2022-09-11 22:35:12 +08:00
txt2img_negative_prompt ,
2023-01-14 19:56:39 +08:00
txt2img_prompt_styles ,
2022-09-03 17:08:45 +08:00
steps ,
sampler_index ,
2022-09-07 17:32:28 +08:00
restore_faces ,
2022-09-05 08:25:37 +08:00
tiling ,
2022-09-03 17:08:45 +08:00
batch_count ,
batch_size ,
cfg_scale ,
seed ,
2022-09-21 18:34:10 +08:00
subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox ,
2022-09-03 17:08:45 +08:00
height ,
width ,
2022-09-19 21:42:56 +08:00
enable_hr ,
denoising_strength ,
2023-01-03 00:42:10 +08:00
hr_scale ,
hr_upscaler ,
2023-01-05 03:04:40 +08:00
hr_second_pass_steps ,
hr_resize_x ,
hr_resize_y ,
2023-01-30 05:25:30 +08:00
override_settings ,
2022-09-03 22:21:15 +08:00
] + custom_inputs ,
2022-10-15 22:20:17 +08:00
2022-09-03 17:08:45 +08:00
outputs = [
txt2img_gallery ,
generation_info ,
2023-01-01 04:40:55 +08:00
html_info ,
html_log ,
2022-09-18 16:14:42 +08:00
] ,
show_progress = False ,
2022-09-03 17:08:45 +08:00
)
2022-09-10 04:16:02 +08:00
txt2img_prompt . submit ( * * txt2img_args )
2022-09-03 17:08:45 +08:00
submit . click ( * * txt2img_args )
2023-01-28 13:41:15 +08:00
2023-03-21 11:49:19 +08:00
res_switch_btn . click ( lambda w , h : ( h , w ) , inputs = [ width , height ] , outputs = [ width , height ] , show_progress = False )
2022-09-03 17:08:45 +08:00
2023-04-30 03:16:54 +08:00
restore_progress_button . click (
fn = progress . restore_progress ,
_js = " restoreProgressTxt2img " ,
inputs = [ dummy_component ] ,
outputs = [
txt2img_gallery ,
generation_info ,
html_info ,
html_log ,
] ,
show_progress = False ,
)
2022-10-13 07:17:26 +08:00
txt_prompt_img . change (
fn = modules . images . image_data ,
inputs = [
txt_prompt_img
] ,
outputs = [
txt2img_prompt ,
txt_prompt_img
]
)
2022-09-19 21:42:56 +08:00
enable_hr . change (
fn = lambda x : gr_show ( x ) ,
inputs = [ enable_hr ] ,
outputs = [ hr_options ] ,
2023-01-07 14:56:37 +08:00
show_progress = False ,
2022-09-19 21:42:56 +08:00
)
2022-09-25 14:25:28 +08:00
txt2img_paste_fields = [
( txt2img_prompt , " Prompt " ) ,
( txt2img_negative_prompt , " Negative prompt " ) ,
( steps , " Steps " ) ,
( sampler_index , " Sampler " ) ,
( restore_faces , " Face restoration " ) ,
( cfg_scale , " CFG scale " ) ,
( seed , " Seed " ) ,
( width , " Size-1 " ) ,
( height , " Size-2 " ) ,
( batch_size , " Batch size " ) ,
( subseed , " Variation seed " ) ,
( subseed_strength , " Variation seed strength " ) ,
( seed_resize_from_w , " Seed resize from-1 " ) ,
( seed_resize_from_h , " Seed resize from-2 " ) ,
( denoising_strength , " Denoising strength " ) ,
( enable_hr , lambda d : " Denoising strength " in d ) ,
( hr_options , lambda d : gr . Row . update ( visible = " Denoising strength " in d ) ) ,
2023-01-03 00:42:10 +08:00
( hr_scale , " Hires upscale " ) ,
( hr_upscaler , " Hires upscaler " ) ,
2023-01-05 03:04:40 +08:00
( hr_second_pass_steps , " Hires steps " ) ,
( hr_resize_x , " Hires resize-1 " ) ,
( hr_resize_y , " Hires resize-2 " ) ,
2022-10-22 17:23:45 +08:00
* modules . scripts . scripts_txt2img . infotext_fields
2022-09-25 14:25:28 +08:00
]
2023-02-19 14:30:49 +08:00
parameters_copypaste . add_paste_fields ( " txt2img " , None , txt2img_paste_fields , override_settings )
2023-01-30 05:25:30 +08:00
parameters_copypaste . register_paste_params_button ( parameters_copypaste . ParamBinding (
2023-02-19 14:30:49 +08:00
paste_button = txt2img_paste , tabname = " txt2img " , source_text_component = txt2img_prompt , source_image_component = None ,
2023-01-30 05:25:30 +08:00
) )
2022-10-15 01:31:49 +08:00
txt2img_preview_params = [
txt2img_prompt ,
txt2img_negative_prompt ,
steps ,
sampler_index ,
cfg_scale ,
seed ,
width ,
height ,
]
2022-11-28 14:00:10 +08:00
token_button . click ( fn = wrap_queued_call ( update_token_counter ) , inputs = [ txt2img_prompt , steps ] , outputs = [ token_counter ] )
2023-01-20 15:18:41 +08:00
negative_token_button . click ( fn = wrap_queued_call ( update_token_counter ) , inputs = [ txt2img_negative_prompt , steps ] , outputs = [ negative_token_counter ] )
2022-09-24 03:49:21 +08:00
2023-01-21 13:36:07 +08:00
ui_extra_networks . setup_ui ( extra_networks_ui , txt2img_gallery )
2022-11-20 00:10:17 +08:00
modules . scripts . scripts_current = modules . scripts . scripts_img2img
modules . scripts . scripts_img2img . initialize_scripts ( is_img2img = True )
2022-09-24 03:49:21 +08:00
2022-09-03 17:08:45 +08:00
with gr . Blocks ( analytics_enabled = False ) as img2img_interface :
2023-04-30 03:16:54 +08:00
img2img_prompt , img2img_prompt_styles , img2img_negative_prompt , submit , img2img_interrogate , img2img_deepbooru , img2img_prompt_style_apply , img2img_save_style , img2img_paste , extra_networks_button , token_counter , token_button , negative_token_counter , negative_token_button , restore_progress_button = create_toprow ( is_img2img = True )
2022-09-03 17:08:45 +08:00
2023-01-19 04:04:24 +08:00
img2img_prompt_img = gr . File ( label = " " , elem_id = " img2img_prompt_image " , file_count = " single " , type = " binary " , visible = False )
2022-09-22 17:11:48 +08:00
2023-01-21 13:36:07 +08:00
with FormRow ( variant = ' compact ' , elem_id = " img2img_extra_networks " , visible = False ) as extra_networks :
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks . create_ui ( extra_networks , extra_networks_button , ' img2img ' )
2023-01-03 14:04:29 +08:00
with FormRow ( ) . style ( equal_height = False ) :
2023-01-14 18:38:10 +08:00
with gr . Column ( variant = ' compact ' , elem_id = " img2img_settings " ) :
2023-01-15 03:43:01 +08:00
copy_image_buttons = [ ]
copy_image_destinations = { }
def add_copy_image_controls(tab_name, elem):
    """Render a "Copy image to:" button row for one img2img image tab.

    One button is created per image tab. The button whose destination name
    equals *tab_name* is rendered disabled, and *elem* is recorded as that
    tab's copy destination in ``copy_image_destinations``. Every other button
    is appended to ``copy_image_buttons`` as ``(button, destination_name, elem)``
    so its click handlers can be wired up later.
    """
    # (button title, destination key) pairs, in display order.
    copy_targets = (
        ('img2img', 'img2img'),
        ('sketch', 'sketch'),
        ('inpaint', 'inpaint'),
        ('inpaint sketch', 'inpaint_sketch'),
    )
    with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
        gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
        for button_title, destination in copy_targets:
            if destination != tab_name:
                copy_image_buttons.append((gr.Button(button_title), destination, elem))
            else:
                # A tab cannot copy to itself: show an inert button and
                # register this tab's component as the copy target.
                gr.Button(button_title, interactive=False)
                copy_image_destinations[destination] = elem
2023-01-12 01:33:24 +08:00
with gr . Tabs ( elem_id = " mode_img2img " ) :
2023-03-29 03:23:40 +08:00
img2img_selected_tab = gr . State ( 0 )
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' img2img ' , id = ' img2img ' , elem_id = " img2img_img2img_tab " ) as tab_img2img :
init_img = gr . Image ( label = " Image for img2img " , elem_id = " img2img_image " , show_label = False , source = " upload " , interactive = True , type = " pil " , tool = " editor " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-15 03:43:01 +08:00
add_copy_image_controls ( ' img2img ' , init_img )
2022-09-10 00:43:16 +08:00
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' Sketch ' , id = ' img2img_sketch ' , elem_id = " img2img_img2img_sketch_tab " ) as tab_sketch :
sketch = gr . Image ( label = " Image for img2img " , elem_id = " img2img_sketch " , show_label = False , source = " upload " , interactive = True , type = " pil " , tool = " color-sketch " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-15 03:43:01 +08:00
add_copy_image_controls ( ' sketch ' , sketch )
2022-09-03 17:08:45 +08:00
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' Inpaint ' , id = ' inpaint ' , elem_id = " img2img_inpaint_tab " ) as tab_inpaint :
init_img_with_mask = gr . Image ( label = " Image for inpainting with mask " , show_label = False , elem_id = " img2maskimg " , source = " upload " , interactive = True , type = " pil " , tool = " sketch " , image_mode = " RGBA " ) . style ( height = 480 )
2023-01-15 03:43:01 +08:00
add_copy_image_controls ( ' inpaint ' , init_img_with_mask )
2022-09-22 17:11:48 +08:00
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' Inpaint sketch ' , id = ' inpaint_sketch ' , elem_id = " img2img_inpaint_sketch_tab " ) as tab_inpaint_color :
inpaint_color_sketch = gr . Image ( label = " Color sketch inpainting " , show_label = False , elem_id = " inpaint_sketch " , source = " upload " , interactive = True , type = " pil " , tool = " color-sketch " , image_mode = " RGBA " ) . style ( height = 480 )
inpaint_color_sketch_orig = gr . State ( None )
2023-01-15 03:43:01 +08:00
add_copy_image_controls ( ' inpaint_sketch ' , inpaint_color_sketch )
2022-09-03 17:08:45 +08:00
2023-01-12 01:33:24 +08:00
def update_orig(image, state):
    """Decide what to keep as the stored "original" of the inpaint-sketch image.

    Args:
        image: the current value of the color-sketch component (an image
            object exposing ``.size`` and convertible with ``np.array``), or None.
        state: the previously stored original image, or None.

    Returns:
        None when *image* is None.
        *image* when there is no stored original, or when its size differs
        from the stored one.
        Otherwise, *state* if at least one pixel of *image* exactly matches
        the corresponding pixel of *state*, else *image*.
        (Presumably a partial pixel match means the user drew over the same
        picture, while zero matches mean a new picture was loaded.)
    """
    if image is None:
        return None
    # The size check must come before the pixel comparison: comparing
    # differently shaped arrays raises ValueError on recent NumPy, and on
    # older NumPy it yields a scalar that makes np.all(axis=-1) raise.
    if state is None or state.size != image.size:
        return image
    # Per-pixel equality across the last (channel) axis, then "any pixel matches".
    has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
    return state if has_exact_match else image
2022-09-22 17:11:48 +08:00
2023-01-12 01:33:24 +08:00
inpaint_color_sketch . change ( update_orig , [ inpaint_color_sketch , inpaint_color_sketch_orig ] , inpaint_color_sketch_orig )
2022-09-22 17:11:48 +08:00
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' Inpaint upload ' , id = ' inpaint_upload ' , elem_id = " img2img_inpaint_upload_tab " ) as tab_inpaint_upload :
init_img_inpaint = gr . Image ( label = " Image for img2img " , show_label = False , source = " upload " , interactive = True , type = " pil " , elem_id = " img_inpaint_base " )
init_mask_inpaint = gr . Image ( label = " Mask " , source = " upload " , interactive = True , type = " pil " , elem_id = " img_inpaint_mask " )
2022-09-22 17:11:48 +08:00
2023-01-12 01:33:24 +08:00
with gr . TabItem ( ' Batch ' , id = ' batch ' , elem_id = " img2img_batch_tab " ) as tab_batch :
2022-09-24 21:29:20 +08:00
hidden = ' <br>Disabled when launched with --hide-ui-dir-config. ' if shared . cmd_opts . hide_ui_dir_config else ' '
2023-01-28 09:32:31 +08:00
gr . HTML (
f " <p style= ' padding-bottom: 1em; ' class= \" text-gray-500 \" >Process images in a directory on the same machine where the server is running. " +
f " <br>Use an empty output directory to save pictures normally instead of writing to the output directory. " +
f " <br>Add inpaint batch mask directory to enable inpaint batch processing. "
f " { hidden } </p> "
)
2023-01-01 21:51:12 +08:00
img2img_batch_input_dir = gr . Textbox ( label = " Input directory " , * * shared . hide_dirs , elem_id = " img2img_batch_input_dir " )
img2img_batch_output_dir = gr . Textbox ( label = " Output directory " , * * shared . hide_dirs , elem_id = " img2img_batch_output_dir " )
2023-01-28 09:32:31 +08:00
img2img_batch_inpaint_mask_dir = gr . Textbox ( label = " Inpaint batch mask directory (required for inpaint batch processing only) " , * * shared . hide_dirs , elem_id = " img2img_batch_inpaint_mask_dir " )
2022-09-03 17:08:45 +08:00
2023-03-29 03:23:40 +08:00
img2img_tabs = [ tab_img2img , tab_sketch , tab_inpaint , tab_inpaint_color , tab_inpaint_upload , tab_batch ]
img2img_image_inputs = [ init_img , sketch , init_img_with_mask , inpaint_color_sketch ]
for i , tab in enumerate ( img2img_tabs ) :
tab . select ( fn = lambda tabnum = i : tabnum , inputs = [ ] , outputs = [ img2img_selected_tab ] )
2023-01-15 03:43:01 +08:00
def copy_image(img):
    """Unwrap an image value before copying it to another img2img tab.

    A dict carrying an 'image' key is reduced to that entry; any other value
    (including None) is passed through unchanged.
    """
    wrapped = isinstance(img, dict) and 'image' in img
    return img['image'] if wrapped else img
for button , name , elem in copy_image_buttons :
button . click (
fn = copy_image ,
inputs = [ elem ] ,
outputs = [ copy_image_destinations [ name ] ] ,
)
button . click (
fn = lambda : None ,
_js = " switch_to_ " + name . replace ( " " , " _ " ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2023-01-03 14:04:29 +08:00
with FormRow ( ) :
2023-03-29 01:36:57 +08:00
resize_mode = gr . Radio ( label = " Resize mode " , elem_id = " resize_mode " , choices = [ " Just resize " , " Crop and resize " , " Resize and fill " , " Just resize (latent upscale) " ] , type = " index " , value = " Just resize " )
2022-09-22 17:11:48 +08:00
2023-01-03 15:39:21 +08:00
for category in ordered_ui_categories ( ) :
if category == " sampler " :
steps , sampler_index = create_sampler_and_steps_selection ( samplers_for_img2img , " img2img " )
2022-09-04 02:02:38 +08:00
2023-01-03 15:39:21 +08:00
elif category == " dimensions " :
with FormRow ( ) :
with gr . Column ( elem_id = " img2img_column_size " , scale = 4 ) :
2023-03-29 03:23:40 +08:00
selected_scale_tab = gr . State ( value = 0 )
with gr . Tabs ( ) :
with gr . Tab ( label = " Resize to " ) as tab_scale_to :
2023-04-29 23:20:11 +08:00
with FormRow ( ) :
with gr . Column ( elem_id = " img2img_column_size " , scale = 4 ) :
width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " img2img_width " )
height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " img2img_height " )
with gr . Column ( elem_id = " img2img_dimensions_row " , scale = 1 , elem_classes = " dimensions-tools " ) :
res_switch_btn = ToolButton ( value = switch_values_symbol , elem_id = " img2img_res_switch_btn " )
2023-03-29 03:23:40 +08:00
with gr . Tab ( label = " Resize by " ) as tab_scale_by :
scale_by = gr . Slider ( minimum = 0.05 , maximum = 4.0 , step = 0.05 , label = " Scale " , value = 1.0 , elem_id = " img2img_scale " )
with FormRow ( ) :
scale_by_html = FormHTML ( resize_from_to_html ( 0 , 0 , 0.0 ) , elem_id = " img2img_scale_resolution_preview " )
gr . Slider ( label = " Unused " , elem_id = " img2img_unused_scale_by_slider " )
2023-04-30 00:39:22 +08:00
button_update_resize_to = gr . Button ( visible = False , elem_id = " img2img_update_resize_to " )
2023-03-29 03:23:40 +08:00
2023-04-30 00:39:22 +08:00
on_change_args = dict (
2023-03-29 03:23:40 +08:00
fn = resize_from_to_html ,
_js = " currentImg2imgSourceResolution " ,
inputs = [ dummy_component , dummy_component , scale_by ] ,
outputs = scale_by_html ,
show_progress = False ,
)
2023-04-30 00:39:22 +08:00
scale_by . release ( * * on_change_args )
button_update_resize_to . click ( * * on_change_args )
2023-05-01 18:47:46 +08:00
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
for component in [ init_img , sketch ] :
2023-04-30 00:39:22 +08:00
component . change ( fn = lambda : None , _js = " updateImg2imgResizeToTextAfterChangingImage " , inputs = [ ] , outputs = [ ] , show_progress = False )
2023-03-29 03:23:40 +08:00
tab_scale_to . select ( fn = lambda : 0 , inputs = [ ] , outputs = [ selected_scale_tab ] )
tab_scale_by . select ( fn = lambda : 1 , inputs = [ ] , outputs = [ selected_scale_tab ] )
2023-03-20 21:09:36 +08:00
2023-01-03 15:39:21 +08:00
if opts . dimensions_and_batch_together :
with gr . Column ( elem_id = " img2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " img2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " img2img_batch_size " )
2022-09-03 17:08:45 +08:00
2023-01-03 15:39:21 +08:00
elif category == " cfg " :
with FormGroup ( ) :
2023-02-04 16:18:44 +08:00
with FormRow ( ) :
cfg_scale = gr . Slider ( minimum = 1.0 , maximum = 30.0 , step = 0.5 , label = ' CFG Scale ' , value = 7.0 , elem_id = " img2img_cfg_scale " )
2023-05-02 14:08:00 +08:00
image_cfg_scale = gr . Slider ( minimum = 0 , maximum = 3.0 , step = 0.05 , label = ' Image CFG Scale ' , value = 1.5 , elem_id = " img2img_image_cfg_scale " , visible = False )
2023-03-29 01:36:57 +08:00
denoising_strength = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = ' Denoising strength ' , value = 0.75 , elem_id = " img2img_denoising_strength " )
2022-09-03 17:08:45 +08:00
2023-01-03 15:39:21 +08:00
elif category == " seed " :
seed , reuse_seed , subseed , reuse_subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox = create_seed_inputs ( ' img2img ' )
2022-09-03 17:08:45 +08:00
2023-01-03 15:39:21 +08:00
elif category == " checkboxes " :
2023-03-20 21:09:36 +08:00
with FormRow ( elem_classes = " checkboxes-row " , variant = " compact " ) :
2023-01-03 15:39:21 +08:00
restore_faces = gr . Checkbox ( label = ' Restore faces ' , value = False , visible = len ( shared . face_restorers ) > 1 , elem_id = " img2img_restore_faces " )
tiling = gr . Checkbox ( label = ' Tiling ' , value = False , elem_id = " img2img_tiling " )
2022-09-03 17:08:45 +08:00
2023-01-03 15:39:21 +08:00
elif category == " batch " :
if not opts . dimensions_and_batch_together :
with FormRow ( elem_id = " img2img_column_batch " ) :
batch_count = gr . Slider ( minimum = 1 , step = 1 , label = ' Batch count ' , value = 1 , elem_id = " img2img_batch_count " )
batch_size = gr . Slider ( minimum = 1 , maximum = 8 , step = 1 , label = ' Batch size ' , value = 1 , elem_id = " img2img_batch_size " )
2022-09-03 22:21:15 +08:00
2023-01-30 05:25:30 +08:00
elif category == " override_settings " :
with FormRow ( elem_id = " img2img_override_settings_row " ) as row :
override_settings = create_override_settings_dropdown ( ' img2img ' , row )
2023-01-03 15:39:21 +08:00
elif category == " scripts " :
with FormGroup ( elem_id = " img2img_script_container " ) :
custom_inputs = modules . scripts . scripts_img2img . setup_ui ( )
2022-09-17 17:38:15 +08:00
2023-01-16 04:32:38 +08:00
elif category == " inpaint " :
2023-01-15 07:26:45 +08:00
with FormGroup ( elem_id = " inpaint_controls " , visible = False ) as inpaint_controls :
with FormRow ( ) :
mask_blur = gr . Slider ( label = ' Mask blur ' , minimum = 0 , maximum = 64 , step = 1 , value = 4 , elem_id = " img2img_mask_blur " )
mask_alpha = gr . Slider ( label = " Mask transparency " , visible = False , elem_id = " img2img_mask_alpha " )
with FormRow ( ) :
inpainting_mask_invert = gr . Radio ( label = ' Mask mode ' , choices = [ ' Inpaint masked ' , ' Inpaint not masked ' ] , value = ' Inpaint masked ' , type = " index " , elem_id = " img2img_mask_mode " )
with FormRow ( ) :
inpainting_fill = gr . Radio ( label = ' Masked content ' , choices = [ ' fill ' , ' original ' , ' latent noise ' , ' latent nothing ' ] , value = ' original ' , type = " index " , elem_id = " img2img_inpainting_fill " )
with FormRow ( ) :
with gr . Column ( ) :
inpaint_full_res = gr . Radio ( label = " Inpaint area " , choices = [ " Whole picture " , " Only masked " ] , type = " index " , value = " Whole picture " , elem_id = " img2img_inpaint_full_res " )
with gr . Column ( scale = 4 ) :
inpaint_full_res_padding = gr . Slider ( label = ' Only masked padding, pixels ' , minimum = 0 , maximum = 256 , step = 4 , value = 32 , elem_id = " img2img_inpaint_full_res_padding " )
def select_img2img_tab(tab):
    """Return visibility updates for (inpaint_controls, mask_alpha) for a tab index.

    Indices follow ``img2img_tabs``: 2 = Inpaint, 3 = Inpaint sketch,
    4 = Inpaint upload. The inpaint controls are shown on all three; the mask
    transparency slider is shown only on Inpaint sketch.
    """
    inpaint_controls_visible = tab in (2, 3, 4)
    mask_alpha_visible = tab == 3
    return gr.update(visible=inpaint_controls_visible), gr.update(visible=mask_alpha_visible)
2023-03-29 03:23:40 +08:00
for i , elem in enumerate ( img2img_tabs ) :
2023-01-15 07:26:45 +08:00
elem . select (
fn = lambda tab = i : select_img2img_tab ( tab ) ,
inputs = [ ] ,
outputs = [ inpaint_controls , mask_alpha ] ,
)
2023-01-01 04:40:55 +08:00
img2img_gallery , generation_info , html_info , html_log = create_output_panel ( " img2img " , opts . outdir_img2img_samples )
2022-09-03 17:08:45 +08:00
2022-09-19 14:02:10 +08:00
connect_reuse_seed ( seed , reuse_seed , generation_info , dummy_component , is_subseed = False )
connect_reuse_seed ( subseed , reuse_subseed , generation_info , dummy_component , is_subseed = True )
2022-09-17 03:20:56 +08:00
2022-10-13 07:17:26 +08:00
img2img_prompt_img . change (
fn = modules . images . image_data ,
inputs = [
2022-10-14 23:15:03 +08:00
img2img_prompt_img
2022-10-13 07:17:26 +08:00
] ,
outputs = [
img2img_prompt ,
img2img_prompt_img
]
)
2022-09-03 17:08:45 +08:00
img2img_args = dict (
2023-01-01 04:40:55 +08:00
fn = wrap_gradio_gpu_call ( modules . img2img . img2img , extra_outputs = [ None , ' ' , ' ' ] ) ,
2022-09-22 17:11:48 +08:00
_js = " submit_img2img " ,
2022-09-03 17:08:45 +08:00
inputs = [
2023-01-15 23:50:56 +08:00
dummy_component ,
2022-09-22 17:11:48 +08:00
dummy_component ,
2022-09-10 04:16:02 +08:00
img2img_prompt ,
2022-09-11 22:35:12 +08:00
img2img_negative_prompt ,
2023-01-14 19:56:39 +08:00
img2img_prompt_styles ,
2022-09-03 17:08:45 +08:00
init_img ,
2023-01-12 01:33:24 +08:00
sketch ,
2022-09-03 17:08:45 +08:00
init_img_with_mask ,
2023-01-12 01:33:24 +08:00
inpaint_color_sketch ,
inpaint_color_sketch_orig ,
2022-09-22 17:11:48 +08:00
init_img_inpaint ,
init_mask_inpaint ,
2022-09-03 17:08:45 +08:00
steps ,
sampler_index ,
mask_blur ,
2022-11-09 11:06:29 +08:00
mask_alpha ,
2022-09-03 17:08:45 +08:00
inpainting_fill ,
2022-09-07 17:32:28 +08:00
restore_faces ,
2022-09-05 08:25:37 +08:00
tiling ,
2022-09-03 17:08:45 +08:00
batch_count ,
batch_size ,
cfg_scale ,
2023-02-04 07:19:56 +08:00
image_cfg_scale ,
2022-09-03 17:08:45 +08:00
denoising_strength ,
seed ,
2022-09-21 18:34:10 +08:00
subseed , subseed_strength , seed_resize_from_h , seed_resize_from_w , seed_checkbox ,
2023-03-29 03:23:40 +08:00
selected_scale_tab ,
2022-09-03 17:08:45 +08:00
height ,
width ,
2023-03-29 03:23:40 +08:00
scale_by ,
2022-09-03 17:08:45 +08:00
resize_mode ,
inpaint_full_res ,
2022-09-22 17:11:48 +08:00
inpaint_full_res_padding ,
2022-09-04 02:02:38 +08:00
inpainting_mask_invert ,
2022-09-22 17:11:48 +08:00
img2img_batch_input_dir ,
img2img_batch_output_dir ,
2023-01-30 07:40:26 +08:00
img2img_batch_inpaint_mask_dir ,
override_settings ,
2022-09-03 22:21:15 +08:00
] + custom_inputs ,
2022-09-03 17:08:45 +08:00
outputs = [
img2img_gallery ,
generation_info ,
2023-01-01 04:40:55 +08:00
html_info ,
html_log ,
2022-09-18 16:14:42 +08:00
] ,
show_progress = False ,
2022-09-03 17:08:45 +08:00
)
2023-01-19 01:16:52 +08:00
interrogate_args = dict (
_js = " get_img2img_tab_index " ,
inputs = [
dummy_component ,
img2img_batch_input_dir ,
img2img_batch_output_dir ,
init_img ,
sketch ,
init_img_with_mask ,
inpaint_color_sketch ,
init_img_inpaint ,
] ,
outputs = [ img2img_prompt , dummy_component ] ,
)
2022-09-10 04:16:02 +08:00
img2img_prompt . submit ( * * img2img_args )
2022-09-03 17:08:45 +08:00
submit . click ( * * img2img_args )
2023-03-21 11:49:19 +08:00
res_switch_btn . click ( lambda w , h : ( h , w ) , inputs = [ width , height ] , outputs = [ width , height ] , show_progress = False )
2022-09-03 17:08:45 +08:00
2023-04-30 03:16:54 +08:00
restore_progress_button . click (
fn = progress . restore_progress ,
_js = " restoreProgressImg2img " ,
inputs = [ dummy_component ] ,
outputs = [
img2img_gallery ,
generation_info ,
html_info ,
html_log ,
] ,
show_progress = False ,
)
2022-09-11 23:48:36 +08:00
img2img_interrogate . click (
2023-01-21 14:14:27 +08:00
fn = lambda * args : process_interrogate ( interrogate , * args ) ,
2023-01-19 01:16:52 +08:00
* * interrogate_args ,
2022-09-11 23:48:36 +08:00
)
2022-11-26 21:10:46 +08:00
img2img_deepbooru . click (
2023-01-21 14:14:27 +08:00
fn = lambda * args : process_interrogate ( interrogate_deepbooru , * args ) ,
2023-01-19 01:16:52 +08:00
* * interrogate_args ,
2022-09-14 22:56:21 +08:00
)
prompts = [ ( txt2img_prompt , txt2img_negative_prompt ) , ( img2img_prompt , img2img_negative_prompt ) ]
2023-01-14 19:56:39 +08:00
style_dropdowns = [ txt2img_prompt_styles , img2img_prompt_styles ]
2022-10-01 00:12:44 +08:00
style_js_funcs = [ " update_txt2img_tokens " , " update_img2img_tokens " ]
2022-09-14 22:56:21 +08:00
for button , ( prompt , negative_prompt ) in zip ( [ txt2img_save_style , img2img_save_style ] , prompts ) :
2022-09-10 04:16:02 +08:00
button . click (
fn = add_style ,
_js = " ask_for_style_name " ,
2022-09-11 22:35:12 +08:00
# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
# the same number of parameters, but we only know the style-name after the JavaScript prompt
inputs = [ dummy_component , prompt , negative_prompt ] ,
2023-01-14 19:56:39 +08:00
outputs = [ txt2img_prompt_styles , img2img_prompt_styles ] ,
2022-09-14 22:56:21 +08:00
)
2023-01-14 19:56:39 +08:00
for button , ( prompt , negative_prompt ) , styles , js_func in zip ( [ txt2img_prompt_style_apply , img2img_prompt_style_apply ] , prompts , style_dropdowns , style_js_funcs ) :
2022-09-14 22:56:21 +08:00
button . click (
fn = apply_styles ,
2022-09-30 02:40:47 +08:00
_js = js_func ,
2023-01-14 19:56:39 +08:00
inputs = [ prompt , negative_prompt , styles ] ,
outputs = [ prompt , negative_prompt , styles ] ,
2022-09-10 04:16:02 +08:00
)
2022-10-27 13:36:11 +08:00
token_button . click ( fn = update_token_counter , inputs = [ img2img_prompt , steps ] , outputs = [ token_counter ] )
2023-02-21 04:39:38 +08:00
negative_token_button . click ( fn = wrap_queued_call ( update_token_counter ) , inputs = [ img2img_negative_prompt , steps ] , outputs = [ negative_token_counter ] )
2022-10-27 13:36:11 +08:00
2023-01-21 13:36:07 +08:00
ui_extra_networks . setup_ui ( extra_networks_ui_img2img , img2img_gallery )
2022-09-25 14:25:28 +08:00
img2img_paste_fields = [
( img2img_prompt , " Prompt " ) ,
( img2img_negative_prompt , " Negative prompt " ) ,
( steps , " Steps " ) ,
( sampler_index , " Sampler " ) ,
( restore_faces , " Face restoration " ) ,
( cfg_scale , " CFG scale " ) ,
2023-02-04 07:19:56 +08:00
( image_cfg_scale , " Image CFG scale " ) ,
2022-09-25 14:25:28 +08:00
( seed , " Seed " ) ,
( width , " Size-1 " ) ,
( height , " Size-2 " ) ,
( batch_size , " Batch size " ) ,
( subseed , " Variation seed " ) ,
( subseed_strength , " Variation seed strength " ) ,
( seed_resize_from_w , " Seed resize from-1 " ) ,
( seed_resize_from_h , " Seed resize from-2 " ) ,
( denoising_strength , " Denoising strength " ) ,
2022-11-27 21:35:35 +08:00
( mask_blur , " Mask blur " ) ,
2022-10-22 17:23:45 +08:00
* modules . scripts . scripts_img2img . infotext_fields
2022-09-25 14:25:28 +08:00
]
2023-02-19 14:30:49 +08:00
parameters_copypaste . add_paste_fields ( " img2img " , init_img , img2img_paste_fields , override_settings )
parameters_copypaste . add_paste_fields ( " inpaint " , init_img_with_mask , img2img_paste_fields , override_settings )
2023-01-30 05:25:30 +08:00
parameters_copypaste . register_paste_params_button ( parameters_copypaste . ParamBinding (
2023-02-19 14:30:49 +08:00
paste_button = img2img_paste , tabname = " img2img " , source_text_component = img2img_prompt , source_image_component = None ,
2023-01-30 05:25:30 +08:00
) )
2022-09-24 03:49:21 +08:00
2022-11-20 00:10:17 +08:00
modules . scripts . scripts_current = None
2022-09-24 03:49:21 +08:00
2022-09-03 17:08:45 +08:00
with gr . Blocks ( analytics_enabled = False ) as extras_interface :
2023-01-23 14:24:43 +08:00
ui_postprocessing . create_ui ( )
2022-09-03 17:08:45 +08:00
2022-09-24 03:49:21 +08:00
with gr . Blocks ( analytics_enabled = False ) as pnginfo_interface :
with gr . Row ( ) . style ( equal_height = False ) :
with gr . Column ( variant = ' panel ' ) :
image = gr . Image ( elem_id = " pnginfo_image " , label = " Source " , source = " upload " , interactive = True , type = " pil " )
with gr . Column ( variant = ' panel ' ) :
html = gr . HTML ( )
2023-01-01 21:51:12 +08:00
generation_info = gr . Textbox ( visible = False , elem_id = " pnginfo_generation_info " )
2022-09-24 03:49:21 +08:00
html2 = gr . HTML ( )
with gr . Row ( ) :
2022-10-27 13:36:11 +08:00
buttons = parameters_copypaste . create_buttons ( [ " txt2img " , " img2img " , " inpaint " , " extras " ] )
2023-01-30 05:25:30 +08:00
for tabname , button in buttons . items ( ) :
parameters_copypaste . register_paste_params_button ( parameters_copypaste . ParamBinding (
paste_button = button , tabname = tabname , source_text_component = generation_info , source_image_component = image ,
) )
2022-09-24 03:49:21 +08:00
image . change (
2022-10-02 20:03:39 +08:00
fn = wrap_gradio_call ( modules . extras . run_pnginfo ) ,
2022-09-24 03:49:21 +08:00
inputs = [ image ] ,
outputs = [ html , generation_info , html2 ] ,
)
2022-10-29 13:28:48 +08:00
2023-01-20 13:48:15 +08:00
def update_interp_description(value):
    """Return the HTML blurb shown above the checkpoint merger for *value*.

    *value* is one of the labels of the "Interpolation Method" radio
    ("No interpolation", "Weighted sum", "Add difference"); any other label
    raises ``KeyError``.
    """
    plain_text = {
        "No interpolation": "No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.",
        "Weighted sum": "A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M",
        "Add difference": "The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M",
    }[value]
    # Bottom margin keeps the description visually separated from the model dropdowns.
    return f"<p style='margin-bottom: 2.5em'>{plain_text}</p>"
2022-11-06 19:39:41 +08:00
with gr . Blocks ( analytics_enabled = False ) as modelmerger_interface :
2022-09-26 07:22:12 +08:00
with gr . Row ( ) . style ( equal_height = False ) :
2023-01-14 18:38:10 +08:00
with gr . Column ( variant = ' compact ' ) :
2023-01-20 13:48:15 +08:00
interp_description = gr . HTML ( value = update_interp_description ( " Weighted sum " ) , elem_id = " modelmerger_interp_description " )
2022-10-02 20:03:39 +08:00
2023-01-19 15:39:51 +08:00
with FormRow ( elem_id = " modelmerger_models " ) :
2022-10-14 14:05:06 +08:00
primary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_primary_model_name " , label = " Primary model (A) " )
2023-01-01 15:35:38 +08:00
create_refresh_button ( primary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_A " )
2022-10-14 14:05:06 +08:00
secondary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_secondary_model_name " , label = " Secondary model (B) " )
2023-01-01 15:35:38 +08:00
create_refresh_button ( secondary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_B " )
2022-10-14 14:05:06 +08:00
tertiary_model_name = gr . Dropdown ( modules . sd_models . checkpoint_tiles ( ) , elem_id = " modelmerger_tertiary_model_name " , label = " Tertiary model (C) " )
2023-01-01 15:35:38 +08:00
create_refresh_button ( tertiary_model_name , modules . sd_models . list_models , lambda : { " choices " : modules . sd_models . checkpoint_tiles ( ) } , " refresh_checkpoint_C " )
2023-01-01 21:51:12 +08:00
custom_name = gr . Textbox ( label = " Custom Name (Optional) " , elem_id = " modelmerger_custom_name " )
interp_amount = gr . Slider ( minimum = 0.0 , maximum = 1.0 , step = 0.05 , label = ' Multiplier (M) - set to 0 to get model A ' , value = 0.3 , elem_id = " modelmerger_interp_amount " )
2023-01-19 15:39:51 +08:00
interp_method = gr . Radio ( choices = [ " No interpolation " , " Weighted sum " , " Add difference " ] , value = " Weighted sum " , label = " Interpolation Method " , elem_id = " modelmerger_interp_method " )
2023-01-20 13:48:15 +08:00
interp_method . change ( fn = update_interp_description , inputs = [ interp_method ] , outputs = [ interp_description ] )
2022-11-27 20:51:29 +08:00
2023-01-11 14:10:07 +08:00
with FormRow ( ) :
2023-04-03 06:41:55 +08:00
checkpoint_format = gr . Radio ( choices = [ " ckpt " , " safetensors " ] , value = " safetensors " , label = " Checkpoint format " , elem_id = " modelmerger_checkpoint_format " )
2023-01-01 21:51:12 +08:00
save_as_half = gr . Checkbox ( value = False , label = " Save as float16 " , elem_id = " modelmerger_save_as_half " )
2023-04-03 06:41:55 +08:00
save_metadata = gr . Checkbox ( value = True , label = " Save metadata (.safetensors only) " , elem_id = " modelmerger_save_metadata " )
2022-11-27 20:51:29 +08:00
2023-01-19 15:39:51 +08:00
with FormRow ( ) :
with gr . Column ( ) :
config_source = gr . Radio ( choices = [ " A, B or C " , " B " , " C " , " Don ' t " ] , value = " A, B or C " , label = " Copy config from " , type = " index " , elem_id = " modelmerger_config_method " )
with gr . Column ( ) :
with FormRow ( ) :
bake_in_vae = gr . Dropdown ( choices = [ " None " ] + list ( sd_vae . vae_dict ) , value = " None " , label = " Bake in VAE " , elem_id = " modelmerger_bake_in_vae " )
create_refresh_button ( bake_in_vae , sd_vae . refresh_vae_list , lambda : { " choices " : [ " None " ] + list ( sd_vae . vae_dict ) } , " modelmerger_refresh_bake_in_vae " )
2023-01-11 14:10:07 +08:00
2023-01-22 15:17:12 +08:00
with FormRow ( ) :
discard_weights = gr . Textbox ( value = " " , label = " Discard weights with matching name " , elem_id = " modelmerger_discard_weights " )
2023-01-14 18:38:10 +08:00
with gr . Row ( ) :
modelmerger_merge = gr . Button ( elem_id = " modelmerger_merge " , value = " Merge " , variant = ' primary ' )
2022-10-02 20:03:39 +08:00
2023-01-19 14:25:37 +08:00
with gr . Column ( variant = ' compact ' , elem_id = " modelmerger_results_container " ) :
with gr . Group ( elem_id = " modelmerger_results_panel " ) :
modelmerger_result = gr . HTML ( elem_id = " modelmerger_result " , show_label = False )
2022-09-26 07:22:12 +08:00
2022-11-06 19:39:41 +08:00
with gr . Blocks ( analytics_enabled = False ) as train_interface :
2022-10-02 20:03:39 +08:00
with gr . Row ( ) . style ( equal_height = False ) :
2022-10-12 16:05:57 +08:00
gr . HTML ( value = " <p style= ' margin-bottom: 0.7em ' >See <b><a href= \" https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion \" >wiki</a></b> for detailed explanation.</p> " )
2022-10-03 03:41:21 +08:00
2023-01-22 04:40:13 +08:00
with gr . Row ( variant = " compact " ) . style ( equal_height = False ) :
2022-10-12 16:05:57 +08:00
with gr . Tabs ( elem_id = " train_tabs " ) :
2022-10-02 20:03:39 +08:00
2023-03-30 02:04:02 +08:00
with gr . Tab ( label = " Create embedding " , id = " create_embedding " ) :
2023-01-01 21:51:12 +08:00
new_embedding_name = gr . Textbox ( label = " Name " , elem_id = " train_new_embedding_name " )
initialization_text = gr . Textbox ( label = " Initialization text " , value = " * " , elem_id = " train_initialization_text " )
nvpt = gr . Slider ( label = " Number of vectors per token " , minimum = 1 , maximum = 75 , step = 1 , value = 1 , elem_id = " train_nvpt " )
overwrite_old_embedding = gr . Checkbox ( value = False , label = " Overwrite Old Embedding " , elem_id = " train_overwrite_old_embedding " )
2022-10-02 20:03:39 +08:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2023-01-01 21:51:12 +08:00
create_embedding = gr . Button ( value = " Create embedding " , variant = ' primary ' , elem_id = " train_create_embedding " )
2022-10-02 20:03:39 +08:00
2023-03-30 02:04:02 +08:00
with gr . Tab ( label = " Create hypernetwork " , id = " create_hypernetwork " ) :
2023-01-01 21:51:12 +08:00
new_hypernetwork_name = gr . Textbox ( label = " Name " , elem_id = " train_new_hypernetwork_name " )
new_hypernetwork_sizes = gr . CheckboxGroup ( label = " Modules " , value = [ " 768 " , " 320 " , " 640 " , " 1280 " ] , choices = [ " 768 " , " 1024 " , " 320 " , " 640 " , " 1280 " ] , elem_id = " train_new_hypernetwork_sizes " )
new_hypernetwork_layer_structure = gr . Textbox ( " 1, 2, 1 " , label = " Enter hypernetwork layer structure " , placeholder = " 1st and last digit must be 1. ex: ' 1, 2, 1 ' " , elem_id = " train_new_hypernetwork_layer_structure " )
new_hypernetwork_activation_func = gr . Dropdown ( value = " linear " , label = " Select activation function of hypernetwork. Recommended : Swish / Linear(none) " , choices = modules . hypernetworks . ui . keys , elem_id = " train_new_hypernetwork_activation_func " )
new_hypernetwork_initialization_option = gr . Dropdown ( value = " Normal " , label = " Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise " , choices = [ " Normal " , " KaimingUniform " , " KaimingNormal " , " XavierUniform " , " XavierNormal " ] , elem_id = " train_new_hypernetwork_initialization_option " )
new_hypernetwork_add_layer_norm = gr . Checkbox ( label = " Add layer normalization " , elem_id = " train_new_hypernetwork_add_layer_norm " )
new_hypernetwork_use_dropout = gr . Checkbox ( label = " Use dropout " , elem_id = " train_new_hypernetwork_use_dropout " )
2023-01-10 13:56:57 +08:00
new_hypernetwork_dropout_structure = gr . Textbox ( " 0, 0, 0 " , label = " Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15 " , placeholder = " 1st and last digit must be 0 and values should be between 0 and 1. ex: ' 0, 0.01, 0 ' " )
2023-01-01 21:51:12 +08:00
overwrite_old_hypernetwork = gr . Checkbox ( value = False , label = " Overwrite Old Hypernetwork " , elem_id = " train_overwrite_old_hypernetwork " )
2022-10-08 04:22:22 +08:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2023-01-01 21:51:12 +08:00
create_hypernetwork = gr . Button ( value = " Create hypernetwork " , variant = ' primary ' , elem_id = " train_create_hypernetwork " )
2022-10-02 20:03:39 +08:00
2023-03-30 02:04:02 +08:00
with gr . Tab ( label = " Preprocess images " , id = " preprocess_images " ) :
2023-01-01 21:51:12 +08:00
process_src = gr . Textbox ( label = ' Source directory ' , elem_id = " train_process_src " )
process_dst = gr . Textbox ( label = ' Destination directory ' , elem_id = " train_process_dst " )
process_width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " train_process_width " )
process_height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " train_process_height " )
preprocess_txt_action = gr . Dropdown ( label = ' Existing Caption txt Action ' , value = " ignore " , choices = [ " ignore " , " copy " , " prepend " , " append " ] , elem_id = " train_preprocess_txt_action " )
2022-10-03 03:41:21 +08:00
with gr . Row ( ) :
2023-03-25 22:45:41 +08:00
process_keep_original_size = gr . Checkbox ( label = ' Keep original size ' , elem_id = " train_process_keep_original_size " )
2023-01-01 21:51:12 +08:00
process_flip = gr . Checkbox ( label = ' Create flipped copies ' , elem_id = " train_process_flip " )
process_split = gr . Checkbox ( label = ' Split oversized images ' , elem_id = " train_process_split " )
process_focal_crop = gr . Checkbox ( label = ' Auto focal point crop ' , elem_id = " train_process_focal_crop " )
2023-01-17 17:16:43 +08:00
process_multicrop = gr . Checkbox ( label = ' Auto-sized crop ' , elem_id = " train_process_multicrop " )
2023-01-01 21:51:12 +08:00
process_caption = gr . Checkbox ( label = ' Use BLIP for caption ' , elem_id = " train_process_caption " )
process_caption_deepbooru = gr . Checkbox ( label = ' Use deepbooru for caption ' , visible = True , elem_id = " train_process_caption_deepbooru " )
2022-10-03 03:41:21 +08:00
2022-10-20 21:56:45 +08:00
with gr . Row ( visible = False ) as process_split_extra_row :
2023-01-01 21:51:12 +08:00
process_split_threshold = gr . Slider ( label = ' Split image threshold ' , value = 0.5 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_split_threshold " )
process_overlap_ratio = gr . Slider ( label = ' Split image overlap ratio ' , value = 0.2 , minimum = 0.0 , maximum = 0.9 , step = 0.05 , elem_id = " train_process_overlap_ratio " )
2022-10-20 21:56:45 +08:00
2022-10-26 06:22:29 +08:00
with gr . Row ( visible = False ) as process_focal_crop_row :
2023-01-01 21:51:12 +08:00
process_focal_crop_face_weight = gr . Slider ( label = ' Focal point face weight ' , value = 0.9 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_face_weight " )
process_focal_crop_entropy_weight = gr . Slider ( label = ' Focal point entropy weight ' , value = 0.15 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_entropy_weight " )
process_focal_crop_edges_weight = gr . Slider ( label = ' Focal point edges weight ' , value = 0.5 , minimum = 0.0 , maximum = 1.0 , step = 0.05 , elem_id = " train_process_focal_crop_edges_weight " )
process_focal_crop_debug = gr . Checkbox ( label = ' Create debug image ' , elem_id = " train_process_focal_crop_debug " )
2023-01-17 17:16:43 +08:00
with gr . Column ( visible = False ) as process_multicrop_col :
gr . Markdown ( ' Each image is center-cropped with an automatically chosen width and height. ' )
with gr . Row ( ) :
process_multicrop_mindim = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Dimension lower bound " , value = 384 , elem_id = " train_process_multicrop_mindim " )
process_multicrop_maxdim = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Dimension upper bound " , value = 768 , elem_id = " train_process_multicrop_maxdim " )
with gr . Row ( ) :
process_multicrop_minarea = gr . Slider ( minimum = 64 * 64 , maximum = 2048 * 2048 , step = 1 , label = " Area lower bound " , value = 64 * 64 , elem_id = " train_process_multicrop_minarea " )
process_multicrop_maxarea = gr . Slider ( minimum = 64 * 64 , maximum = 2048 * 2048 , step = 1 , label = " Area upper bound " , value = 640 * 640 , elem_id = " train_process_multicrop_maxarea " )
with gr . Row ( ) :
process_multicrop_objective = gr . Radio ( [ " Maximize area " , " Minimize error " ] , value = " Maximize area " , label = " Resizing objective " , elem_id = " train_process_multicrop_objective " )
process_multicrop_threshold = gr . Slider ( minimum = 0 , maximum = 1 , step = 0.01 , label = " Error threshold " , value = 0.1 , elem_id = " train_process_multicrop_threshold " )
2022-10-03 03:41:21 +08:00
with gr . Row ( ) :
with gr . Column ( scale = 3 ) :
gr . HTML ( value = " " )
with gr . Column ( ) :
2022-11-18 10:03:57 +08:00
with gr . Row ( ) :
2023-01-01 21:51:12 +08:00
interrupt_preprocessing = gr . Button ( " Interrupt " , elem_id = " train_interrupt_preprocessing " )
run_preprocess = gr . Button ( value = " Preprocess " , variant = ' primary ' , elem_id = " train_run_preprocess " )
2022-10-03 03:41:21 +08:00
2022-10-20 21:56:45 +08:00
process_split . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_split ] ,
outputs = [ process_split_extra_row ] ,
)
2022-10-26 06:22:29 +08:00
process_focal_crop . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_focal_crop ] ,
outputs = [ process_focal_crop_row ] ,
)
2023-01-17 17:16:43 +08:00
process_multicrop . change (
fn = lambda show : gr_show ( show ) ,
inputs = [ process_multicrop ] ,
outputs = [ process_multicrop_col ] ,
)
2023-01-10 04:35:40 +08:00
def get_textual_inversion_template_names():
    """Return the names of the known textual-inversion prompt templates, sorted.

    Iterates ``textual_inversion.textual_inversion_templates`` directly;
    ``sorted`` already builds a new list, so no intermediate copy is needed.
    """
    return sorted(textual_inversion.textual_inversion_templates)
2023-03-30 02:04:02 +08:00
with gr . Tab ( label = " Train " , id = " train " ) :
2022-10-20 03:33:18 +08:00
gr . HTML ( value = " <p style= ' margin-bottom: 0.7em ' >Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href= \" https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion \" style= \" font-weight:bold; \" >[wiki]</a></p> " )
2023-01-05 01:10:40 +08:00
with FormRow ( ) :
2022-10-18 02:15:32 +08:00
train_embedding_name = gr . Dropdown ( label = ' Embedding ' , elem_id = " train_embedding " , choices = sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) )
2022-10-16 12:42:52 +08:00
create_refresh_button ( train_embedding_name , sd_hijack . model_hijack . embedding_db . load_textual_inversion_embeddings , lambda : { " choices " : sorted ( sd_hijack . model_hijack . embedding_db . word_embeddings . keys ( ) ) } , " refresh_train_embedding_name " )
2023-01-05 01:10:40 +08:00
2022-10-18 02:15:32 +08:00
train_hypernetwork_name = gr . Dropdown ( label = ' Hypernetwork ' , elem_id = " train_hypernetwork " , choices = [ x for x in shared . hypernetworks . keys ( ) ] )
2022-10-16 12:42:52 +08:00
create_refresh_button ( train_hypernetwork_name , shared . reload_hypernetworks , lambda : { " choices " : sorted ( [ x for x in shared . hypernetworks . keys ( ) ] ) } , " refresh_train_hypernetwork_name " )
2023-01-05 01:10:40 +08:00
with FormRow ( ) :
2023-01-01 21:51:12 +08:00
embedding_learn_rate = gr . Textbox ( label = ' Embedding Learning rate ' , placeholder = " Embedding Learning rate " , value = " 0.005 " , elem_id = " train_embedding_learn_rate " )
hypernetwork_learn_rate = gr . Textbox ( label = ' Hypernetwork Learning rate ' , placeholder = " Hypernetwork Learning rate " , value = " 0.00001 " , elem_id = " train_hypernetwork_learn_rate " )
2023-01-05 00:56:35 +08:00
2023-01-05 01:10:40 +08:00
with FormRow ( ) :
2022-10-28 18:16:23 +08:00
clip_grad_mode = gr . Dropdown ( value = " disabled " , label = " Gradient Clipping " , choices = [ " disabled " , " value " , " norm " ] )
2022-10-31 14:49:24 +08:00
clip_grad_value = gr . Textbox ( placeholder = " Gradient clip value " , value = " 0.1 " , show_label = False )
2023-01-01 21:51:12 +08:00
2023-01-05 01:10:40 +08:00
with FormRow ( ) :
batch_size = gr . Number ( label = ' Batch size ' , value = 1 , precision = 0 , elem_id = " train_batch_size " )
gradient_step = gr . Number ( label = ' Gradient accumulation steps ' , value = 1 , precision = 0 , elem_id = " train_gradient_step " )
2023-01-01 21:51:12 +08:00
dataset_directory = gr . Textbox ( label = ' Dataset directory ' , placeholder = " Path to directory with input images " , elem_id = " train_dataset_directory " )
log_directory = gr . Textbox ( label = ' Log directory ' , placeholder = " Path to directory where to write outputs " , value = " textual_inversion " , elem_id = " train_log_directory " )
2023-01-10 04:35:40 +08:00
with FormRow ( ) :
template_file = gr . Dropdown ( label = ' Prompt template ' , value = " style_filewords.txt " , elem_id = " train_template_file " , choices = get_textual_inversion_template_names ( ) )
create_refresh_button ( template_file , textual_inversion . list_textual_inversion_templates , lambda : { " choices " : get_textual_inversion_template_names ( ) } , " refrsh_train_template_file " )
2023-01-01 21:51:12 +08:00
training_width = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Width " , value = 512 , elem_id = " train_training_width " )
training_height = gr . Slider ( minimum = 64 , maximum = 2048 , step = 8 , label = " Height " , value = 512 , elem_id = " train_training_height " )
2023-01-10 03:52:23 +08:00
varsize = gr . Checkbox ( label = " Do not resize images " , value = False , elem_id = " train_varsize " )
2023-01-01 21:51:12 +08:00
steps = gr . Number ( label = ' Max steps ' , value = 100000 , precision = 0 , elem_id = " train_steps " )
2023-01-05 01:10:40 +08:00
with FormRow ( ) :
create_image_every = gr . Number ( label = ' Save an image to log directory every N steps, 0 to disable ' , value = 500 , precision = 0 , elem_id = " train_create_image_every " )
save_embedding_every = gr . Number ( label = ' Save a copy of embedding to log directory every N steps, 0 to disable ' , value = 500 , precision = 0 , elem_id = " train_save_embedding_every " )
2023-01-12 23:29:00 +08:00
use_weight = gr . Checkbox ( label = " Use PNG alpha channel as loss weight " , value = False , elem_id = " use_weight " )
2023-01-01 21:51:12 +08:00
save_image_with_stored_embedding = gr . Checkbox ( label = ' Save images with embedding in PNG chunks ' , value = True , elem_id = " train_save_image_with_stored_embedding " )
preview_from_txt2img = gr . Checkbox ( label = ' Read parameters (prompt, etc...) from txt2img tab when making previews ' , value = False , elem_id = " train_preview_from_txt2img " )
2023-01-05 01:10:40 +08:00
shuffle_tags = gr . Checkbox ( label = " Shuffle tags by ' , ' when creating prompts. " , value = False , elem_id = " train_shuffle_tags " )
tag_drop_out = gr . Slider ( minimum = 0 , maximum = 1 , step = 0.1 , label = " Drop out tags when creating prompts. " , value = 0 , elem_id = " train_tag_drop_out " )
latent_sampling_method = gr . Radio ( label = ' Choose latent sampling method ' , value = " once " , choices = [ ' once ' , ' deterministic ' , ' random ' ] , elem_id = " train_latent_sampling_method " )
2022-10-02 20:03:39 +08:00
with gr . Row ( ) :
2023-01-05 01:10:40 +08:00
train_embedding = gr . Button ( value = " Train Embedding " , variant = ' primary ' , elem_id = " train_train_embedding " )
2023-01-01 21:51:12 +08:00
interrupt_training = gr . Button ( value = " Interrupt " , elem_id = " train_interrupt_training " )
train_hypernetwork = gr . Button ( value = " Train Hypernetwork " , variant = ' primary ' , elem_id = " train_train_hypernetwork " )
2022-10-02 20:03:39 +08:00
2022-11-08 13:38:10 +08:00
params = script_callbacks . UiTrainTabParams ( txt2img_preview_params )
script_callbacks . ui_train_tabs_callback ( params )
2023-01-15 23:50:56 +08:00
with gr . Column ( elem_id = ' ti_gallery_container ' ) :
2022-10-02 20:03:39 +08:00
ti_output = gr . Text ( elem_id = " ti_output " , value = " " , show_label = False )
2023-04-29 14:17:35 +08:00
ti_gallery = gr . Gallery ( label = ' Output ' , show_label = False , elem_id = ' ti_gallery ' ) . style ( columns = 4 )
2022-10-02 20:03:39 +08:00
ti_progress = gr . HTML ( elem_id = " ti_progress " , value = " " )
ti_outcome = gr . HTML ( elem_id = " ti_error " , value = " " )
create_embedding . click (
fn = modules . textual_inversion . ui . create_embedding ,
inputs = [
new_embedding_name ,
2022-10-03 00:40:51 +08:00
initialization_text ,
2022-10-02 20:03:39 +08:00
nvpt ,
2022-10-20 03:33:18 +08:00
overwrite_old_embedding ,
2022-10-02 20:03:39 +08:00
] ,
outputs = [
train_embedding_name ,
ti_output ,
ti_outcome ,
]
)
2022-10-08 04:22:22 +08:00
create_hypernetwork . click (
2022-10-11 20:54:34 +08:00
fn = modules . hypernetworks . ui . create_hypernetwork ,
2022-10-08 04:22:22 +08:00
inputs = [
new_hypernetwork_name ,
2022-10-11 23:04:47 +08:00
new_hypernetwork_sizes ,
2022-10-20 07:27:16 +08:00
overwrite_old_hypernetwork ,
2022-10-19 22:30:33 +08:00
new_hypernetwork_layer_structure ,
2022-10-20 08:10:45 +08:00
new_hypernetwork_activation_func ,
2022-10-25 13:48:49 +08:00
new_hypernetwork_initialization_option ,
2022-10-19 22:30:33 +08:00
new_hypernetwork_add_layer_norm ,
2023-01-10 13:56:57 +08:00
new_hypernetwork_use_dropout ,
new_hypernetwork_dropout_structure
2022-10-08 04:22:22 +08:00
] ,
outputs = [
train_hypernetwork_name ,
ti_output ,
ti_outcome ,
]
)
2022-10-03 03:41:21 +08:00
run_preprocess . click (
fn = wrap_gradio_gpu_call ( modules . textual_inversion . ui . preprocess , extra_outputs = [ gr . update ( ) ] ) ,
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 23:50:56 +08:00
dummy_component ,
2022-10-03 03:41:21 +08:00
process_src ,
process_dst ,
2022-10-10 21:35:35 +08:00
process_width ,
process_height ,
2022-10-20 07:48:07 +08:00
preprocess_txt_action ,
2023-03-25 22:45:41 +08:00
process_keep_original_size ,
2022-10-03 03:41:21 +08:00
process_flip ,
process_split ,
process_caption ,
2022-10-20 21:56:45 +08:00
process_caption_deepbooru ,
process_split_threshold ,
process_overlap_ratio ,
2022-10-26 06:22:29 +08:00
process_focal_crop ,
process_focal_crop_face_weight ,
process_focal_crop_entropy_weight ,
process_focal_crop_edges_weight ,
process_focal_crop_debug ,
2023-01-17 17:16:43 +08:00
process_multicrop ,
process_multicrop_mindim ,
process_multicrop_maxdim ,
process_multicrop_minarea ,
process_multicrop_maxarea ,
process_multicrop_objective ,
process_multicrop_threshold ,
2022-10-03 03:41:21 +08:00
] ,
outputs = [
ti_output ,
ti_outcome ,
] ,
)
2022-10-02 20:03:39 +08:00
train_embedding . click (
fn = wrap_gradio_gpu_call ( modules . textual_inversion . ui . train_embedding , extra_outputs = [ gr . update ( ) ] ) ,
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 23:50:56 +08:00
dummy_component ,
2022-10-02 20:03:39 +08:00
train_embedding_name ,
2022-10-20 07:19:40 +08:00
embedding_learn_rate ,
2022-10-15 14:24:59 +08:00
batch_size ,
2022-11-20 11:35:26 +08:00
gradient_step ,
2022-10-02 20:03:39 +08:00
dataset_directory ,
log_directory ,
2022-10-10 21:35:35 +08:00
training_width ,
training_height ,
2023-01-08 01:34:52 +08:00
varsize ,
2022-10-02 20:03:39 +08:00
steps ,
2022-10-28 11:31:27 +08:00
clip_grad_mode ,
clip_grad_value ,
2022-11-20 11:35:26 +08:00
shuffle_tags ,
tag_drop_out ,
latent_sampling_method ,
2023-01-12 23:29:00 +08:00
use_weight ,
2022-10-02 20:03:39 +08:00
create_image_every ,
save_embedding_every ,
template_file ,
2022-10-09 12:40:57 +08:00
save_image_with_stored_embedding ,
2022-10-15 01:31:49 +08:00
preview_from_txt2img ,
* txt2img_preview_params ,
2022-10-02 20:03:39 +08:00
] ,
outputs = [
ti_output ,
ti_outcome ,
]
)
2022-10-08 04:22:22 +08:00
train_hypernetwork . click (
2022-10-11 20:54:34 +08:00
fn = wrap_gradio_gpu_call ( modules . hypernetworks . ui . train_hypernetwork , extra_outputs = [ gr . update ( ) ] ) ,
2022-10-08 04:22:22 +08:00
_js = " start_training_textual_inversion " ,
inputs = [
2023-01-15 23:50:56 +08:00
dummy_component ,
2022-10-08 04:22:22 +08:00
train_hypernetwork_name ,
2022-10-20 07:19:40 +08:00
hypernetwork_learn_rate ,
2022-10-15 14:24:59 +08:00
batch_size ,
2022-11-20 11:35:26 +08:00
gradient_step ,
2022-10-08 04:22:22 +08:00
dataset_directory ,
log_directory ,
2022-10-19 13:44:33 +08:00
training_width ,
training_height ,
2023-01-08 01:34:52 +08:00
varsize ,
2022-10-02 20:03:39 +08:00
steps ,
2022-10-28 10:44:56 +08:00
clip_grad_mode ,
clip_grad_value ,
2022-11-20 11:35:26 +08:00
shuffle_tags ,
tag_drop_out ,
latent_sampling_method ,
2023-01-12 23:29:00 +08:00
use_weight ,
2022-10-02 20:03:39 +08:00
create_image_every ,
save_embedding_every ,
template_file ,
2022-10-15 01:31:49 +08:00
preview_from_txt2img ,
* txt2img_preview_params ,
2022-10-02 20:03:39 +08:00
] ,
outputs = [
ti_output ,
ti_outcome ,
]
)
interrupt_training . click (
fn = lambda : shared . state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2022-11-18 10:03:57 +08:00
interrupt_preprocessing . click (
fn = lambda : shared . state . interrupt ( ) ,
inputs = [ ] ,
outputs = [ ] ,
)
2022-10-14 00:22:41 +08:00
def create_setting_component(key, is_quicksettings=False):
    """Create the Gradio input that edits the setting ``key``.

    The component class is ``info.component`` when the option declares one.
    Otherwise it is picked from the exact type of the option's default
    (str -> Textbox, int -> Number, bool -> Checkbox). The initial value is
    the stored value from ``opts.data`` or, when absent, the option's default.

    If the option has a ``refresh`` callback, a refresh button is created next
    to the component: directly for quicksettings, inside a ``FormRow``
    otherwise.

    Raises:
        Exception: the option declares no component and its default's type
            has no matching one.
    """
    info = opts.data_labels[key]
    default_type = type(info.default)
    extra_kwargs = info.component_args() if callable(info.component_args) else info.component_args

    if info.component is not None:
        component_cls = info.component
    else:
        # Exact-type lookup, so a bool default maps to Checkbox rather than Number.
        component_cls = {str: gr.Textbox, int: gr.Number, bool: gr.Checkbox}.get(default_type)
        if component_cls is None:
            raise Exception(f'bad options item type: {str(default_type)} for key {key}')

    elem_id = "setting_" + key

    def build_component():
        current_value = opts.data[key] if key in opts.data else info.default
        return component_cls(label=info.label, value=current_value, elem_id=elem_id, **(extra_kwargs or {}))

    if info.refresh is None:
        return build_component()

    if is_quicksettings:
        component = build_component()
        create_refresh_button(component, info.refresh, info.component_args, "refresh_" + key)
        return component

    # Outside quicksettings, keep the component and its refresh button on one form row.
    with FormRow():
        component = build_component()
        create_refresh_button(component, info.refresh, info.component_args, "refresh_" + key)
    return component
2022-09-03 17:08:45 +08:00
2022-09-10 16:10:00 +08:00
components = [ ]
2022-09-29 05:59:44 +08:00
component_dict = { }
2023-01-30 05:25:30 +08:00
shared . settings_components = component_dict
2022-09-10 16:10:00 +08:00
2022-10-23 00:18:56 +08:00
script_callbacks . ui_settings_callback ( )
opts . reorder ( )
2022-09-03 17:08:45 +08:00
def run_settings(*args):
    """Validate and apply every value submitted from the Settings tab.

    ``args`` are the current values of ``components``, positionally aligned
    with ``opts.data_labels``. Entries whose component is ``dummy_component``
    (quicksettings and skipped sections) are ignored.

    Returns:
        ``(settings_json, status_message)``, where ``settings_json`` is
        ``opts.dumpjson()`` and the message lists the keys reported as changed.

    Raises:
        AssertionError: a submitted value does not have the same type as the
            setting's default. No setting is modified in that case.
    """
    # Validate everything before applying anything, so one bad value leaves all settings untouched.
    for key, value, comp in zip(opts.data_labels.keys(), args, components):
        if comp == dummy_component or opts.same_type(value, opts.data_labels[key].default):
            continue
        expected = type(opts.data_labels[key].default).__name__
        # Raised explicitly instead of via `assert`, so the check still runs under `python -O`.
        # The exception type is kept as AssertionError for any existing handler.
        raise AssertionError(f"Bad value for setting {key}: {value}; expecting {expected}")

    changed = []
    for key, value, comp in zip(opts.data_labels.keys(), args, components):
        if comp == dummy_component:
            continue
        # A truthy result from opts.set is treated as "this setting changed".
        if opts.set(key, value):
            changed.append(key)

    try:
        opts.save(shared.config_filename)
    except RuntimeError:
        return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
    return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
2022-09-03 17:08:45 +08:00
2022-10-10 03:24:07 +08:00
def run_settings_single(value, key):
    """Apply one quicksetting change coming from the UI.

    Returns ``(component_update, settings_json)``:
    - if ``value`` does not match the type of the setting's default, nothing
      is stored and only ``gr.update(visible=True)`` is sent back;
    - if ``opts.set`` returns a falsy result, the component is resynced to the
      value ``opts`` currently holds;
    - otherwise the settings are saved and the component receives
      ``get_value_for_setting(key)``.
    """
    if opts.same_type(value, opts.data_labels[key].default):
        if opts.set(key, value):
            opts.save(shared.config_filename)
            component_update = get_value_for_setting(key)
        else:
            component_update = gr.update(value=getattr(opts, key))
    else:
        component_update = gr.update(visible=True)

    # Serialize after any set/save above so the JSON reflects the final state.
    return component_update, opts.dumpjson()
2022-10-10 03:24:07 +08:00
2022-09-10 16:10:00 +08:00
with gr . Blocks ( analytics_enabled = False ) as settings_interface :
2023-01-03 12:20:20 +08:00
with gr . Row ( ) :
2023-01-04 01:23:17 +08:00
with gr . Column ( scale = 6 ) :
settings_submit = gr . Button ( value = " Apply settings " , variant = ' primary ' , elem_id = " settings_submit " )
with gr . Column ( ) :
restart_gradio = gr . Button ( value = ' Reload UI ' , variant = ' primary ' , elem_id = " settings_restart_gradio " )
2022-09-10 16:10:00 +08:00
2023-01-03 12:20:20 +08:00
result = gr . HTML ( elem_id = " settings_result " )
2022-09-10 16:10:00 +08:00
2023-05-08 20:38:25 +08:00
quicksettings_names = opts . quicksettings_list
2023-01-03 14:13:35 +08:00
quicksettings_names = { x : i for i , x in enumerate ( quicksettings_names ) if x != ' quicksettings ' }
2022-10-13 21:07:18 +08:00
2022-10-10 03:24:07 +08:00
quicksettings_list = [ ]
2022-09-23 02:32:44 +08:00
previous_section = None
2023-01-03 12:20:20 +08:00
current_tab = None
2023-01-14 19:56:39 +08:00
current_row = None
2023-01-03 12:20:20 +08:00
with gr . Tabs ( elem_id = " settings " ) :
2022-09-23 02:32:44 +08:00
for i , ( k , item ) in enumerate ( opts . data_labels . items ( ) ) :
2022-10-31 22:36:45 +08:00
section_must_be_skipped = item . section [ 0 ] is None
2022-09-23 00:26:26 +08:00
2022-10-31 22:36:45 +08:00
if previous_section != item . section and not section_must_be_skipped :
2023-01-03 12:20:20 +08:00
elem_id , text = item . section
2022-09-23 00:26:26 +08:00
2023-01-03 12:20:20 +08:00
if current_tab is not None :
2023-01-14 19:56:39 +08:00
current_row . __exit__ ( )
2023-01-03 12:20:20 +08:00
current_tab . __exit__ ( )
2022-09-10 16:10:00 +08:00
2023-01-14 19:56:39 +08:00
gr . Group ( )
2023-01-03 12:20:20 +08:00
current_tab = gr . TabItem ( elem_id = " settings_ {} " . format ( elem_id ) , label = text )
current_tab . __enter__ ( )
2023-01-14 19:56:39 +08:00
current_row = gr . Column ( variant = ' compact ' )
current_row . __enter__ ( )
2022-09-23 02:32:44 +08:00
previous_section = item . section
2022-10-23 03:05:22 +08:00
if k in quicksettings_names and not shared . cmd_opts . freeze_settings :
2022-10-10 03:24:07 +08:00
quicksettings_list . append ( ( i , k , item ) )
components . append ( dummy_component )
2022-10-31 22:36:45 +08:00
elif section_must_be_skipped :
components . append ( dummy_component )
2022-10-10 03:24:07 +08:00
else :
component = create_setting_component ( k )
component_dict [ k ] = component
components . append ( component )
2022-09-03 17:08:45 +08:00
2023-01-03 12:20:20 +08:00
if current_tab is not None :
2023-01-14 19:56:39 +08:00
current_row . __exit__ ( )
2023-01-03 12:20:20 +08:00
current_tab . __exit__ ( )
2022-10-18 02:15:32 +08:00
2023-05-08 20:30:32 +08:00
with gr . TabItem ( " Actions " , id = " actions " , elem_id = " settings_tab_actions " ) :
2023-01-03 12:20:20 +08:00
request_notifications = gr . Button ( value = ' Request browser notifications ' , elem_id = " request_notifications " )
download_localization = gr . Button ( value = ' Download localization template ' , elem_id = " download_localization " )
reload_script_bodies = gr . Button ( value = ' Reload custom script bodies (No ui updates, No restart) ' , variant = ' secondary ' , elem_id = " settings_reload_script_bodies " )
2023-03-09 12:56:19 +08:00
with gr . Row ( ) :
unload_sd_model = gr . Button ( value = ' Unload SD checkpoint to free VRAM ' , elem_id = " sett_unload_sd_model " )
reload_sd_model = gr . Button ( value = ' Reload the last SD checkpoint back into VRAM ' , elem_id = " sett_reload_sd_model " )
2022-10-13 21:07:18 +08:00
2023-05-08 20:30:32 +08:00
with gr . TabItem ( " Licenses " , id = " licenses " , elem_id = " settings_tab_licenses " ) :
2023-01-21 13:36:07 +08:00
gr . HTML ( shared . html ( " licenses.html " ) , elem_id = " licenses " )
2023-01-04 01:23:17 +08:00
2023-01-03 15:01:06 +08:00
gr . Button ( value = " Show all pages " , elem_id = " settings_show_all_pages " )
2023-04-30 03:15:20 +08:00
2023-03-09 12:56:19 +08:00
def unload_sd_weights ( ) :
modules . sd_models . unload_model_weights ( )
def reload_sd_weights ( ) :
modules . sd_models . reload_model_weights ( )
unload_sd_model . click (
fn = unload_sd_weights ,
inputs = [ ] ,
outputs = [ ]
)
reload_sd_model . click (
fn = reload_sd_weights ,
inputs = [ ] ,
outputs = [ ]
)
2022-10-13 21:07:18 +08:00
2022-09-19 09:41:57 +08:00
request_notifications . click (
fn = lambda : None ,
inputs = [ ] ,
outputs = [ ] ,
2022-09-22 18:15:33 +08:00
_js = ' function() {} '
2022-09-19 09:41:57 +08:00
)
2022-10-18 02:15:32 +08:00
download_localization . click (
fn = lambda : None ,
inputs = [ ] ,
outputs = [ ] ,
_js = ' download_localization '
)
2022-10-02 08:19:55 +08:00
def reload_scripts ( ) :
2022-10-03 02:26:06 +08:00
modules . scripts . reload_script_body_only ( )
2022-10-23 03:05:22 +08:00
reload_javascript ( ) # need to refresh the html page
2022-10-02 08:19:55 +08:00
reload_script_bodies . click (
fn = reload_scripts ,
inputs = [ ] ,
2022-11-02 14:47:53 +08:00
outputs = [ ]
2022-10-02 08:19:55 +08:00
)
2022-10-02 08:36:30 +08:00
def request_restart ( ) :
2022-10-05 11:43:05 +08:00
shared . state . interrupt ( )
2022-10-31 22:36:45 +08:00
shared . state . need_restart = True
2022-10-02 08:36:30 +08:00
restart_gradio . click (
fn = request_restart ,
2022-11-06 14:02:25 +08:00
_js = ' restart_reload ' ,
2022-10-02 08:36:30 +08:00
inputs = [ ] ,
outputs = [ ] ,
)
2022-10-10 09:26:52 +08:00
2022-09-03 17:08:45 +08:00
interfaces = [
2022-09-10 16:10:00 +08:00
( txt2img_interface , " txt2img " , " txt2img " ) ,
( img2img_interface , " img2img " , " img2img " ) ,
( extras_interface , " Extras " , " extras " ) ,
( pnginfo_interface , " PNG Info " , " pnginfo " ) ,
2022-09-26 07:22:12 +08:00
( modelmerger_interface , " Checkpoint Merger " , " modelmerger " ) ,
2022-10-12 16:05:57 +08:00
( train_interface , " Train " , " ti " ) ,
2022-09-03 17:08:45 +08:00
]
2022-10-29 15:56:19 +08:00
interfaces + = script_callbacks . ui_tabs_callback ( )
interfaces + = [ ( settings_interface , " Settings " , " settings " ) ]
2022-10-31 22:36:45 +08:00
extensions_interface = ui_extensions . create_ui ( )
interfaces + = [ ( extensions_interface , " Extensions " , " extensions " ) ]
2023-02-19 22:21:44 +08:00
shared . tab_names = [ ]
for _interface , label , _ifid in interfaces :
shared . tab_names . append ( label )
2023-04-29 17:45:43 +08:00
with gr . Blocks ( theme = shared . gradio_theme , analytics_enabled = False , title = " Stable Diffusion " ) as demo :
2023-01-18 19:33:09 +08:00
with gr . Row ( elem_id = " quicksettings " , variant = " compact " ) :
2023-01-03 14:13:35 +08:00
for i , k , item in sorted ( quicksettings_list , key = lambda x : quicksettings_names . get ( x [ 1 ] , x [ 0 ] ) ) :
2022-10-14 00:22:41 +08:00
component = create_setting_component ( k , is_quicksettings = True )
2022-10-10 03:24:07 +08:00
component_dict [ k ] = component
2023-01-30 05:25:30 +08:00
parameters_copypaste . connect_paste_params_buttons ( )
2022-10-29 15:56:19 +08:00
2022-10-14 01:42:27 +08:00
with gr . Tabs ( elem_id = " tabs " ) as tabs :
2022-09-10 16:10:00 +08:00
for interface , label , ifid in interfaces :
2023-02-19 22:21:44 +08:00
if label in shared . opts . hidden_tabs :
2023-02-14 09:26:47 +08:00
continue
2022-10-11 13:22:46 +08:00
with gr . TabItem ( label , id = ifid , elem_id = ' tab_ ' + ifid ) :
2022-09-10 16:10:00 +08:00
interface . render ( )
2022-10-10 09:26:52 +08:00
2022-09-27 04:57:31 +08:00
if os . path . exists ( os . path . join ( script_path , " notification.mp3 " ) ) :
audio_notification = gr . Audio ( interactive = False , value = os . path . join ( script_path , " notification.mp3 " ) , elem_id = " audio_notification " , visible = False )
2022-09-10 16:10:00 +08:00
2023-01-21 13:36:07 +08:00
footer = shared . html ( " footer.html " )
footer = footer . format ( versions = versions_html ( ) )
gr . HTML ( footer , elem_id = " footer " )
2023-01-04 01:23:17 +08:00
2022-09-19 22:16:04 +08:00
text_settings = gr . Textbox ( elem_id = " settings_json " , value = lambda : opts . dumpjson ( ) , visible = False )
2022-09-19 03:25:18 +08:00
settings_submit . click (
2022-11-04 15:35:30 +08:00
fn = wrap_gradio_call ( run_settings , extra_outputs = [ gr . update ( ) ] ) ,
2022-09-24 05:13:32 +08:00
inputs = components ,
2022-11-04 15:35:30 +08:00
outputs = [ text_settings , result ] ,
2022-09-19 03:25:18 +08:00
)
2022-10-10 03:24:07 +08:00
for i , k , item in quicksettings_list :
component = component_dict [ k ]
2023-03-21 13:18:14 +08:00
info = opts . data_labels [ k ]
2022-10-10 03:24:07 +08:00
2023-04-30 00:39:22 +08:00
change_handler = component . release if hasattr ( component , ' release ' ) else component . change
change_handler (
2022-10-10 03:24:07 +08:00
fn = lambda value , k = k : run_settings_single ( value , key = k ) ,
inputs = [ component ] ,
outputs = [ component , text_settings ] ,
2023-03-21 13:18:14 +08:00
show_progress = info . refresh is not None ,
2022-10-10 03:24:07 +08:00
)
2023-05-02 14:08:00 +08:00
update_image_cfg_scale_visibility = lambda : gr . update ( visible = shared . sd_model and shared . sd_model . cond_stage_key == " edit " )
text_settings . change ( fn = update_image_cfg_scale_visibility , inputs = [ ] , outputs = [ image_cfg_scale ] )
demo . load ( fn = update_image_cfg_scale_visibility , inputs = [ ] , outputs = [ image_cfg_scale ] )
2023-02-04 16:18:44 +08:00
2023-01-29 03:52:27 +08:00
button_set_checkpoint = gr . Button ( ' Change checkpoint ' , elem_id = ' change_checkpoint ' , visible = False )
button_set_checkpoint . click (
fn = lambda value , _ : run_settings_single ( value , key = ' sd_model_checkpoint ' ) ,
_js = " function(v) { var res = desiredCheckpointName; desiredCheckpointName = ' ' ; return [res || v, null]; } " ,
inputs = [ component_dict [ ' sd_model_checkpoint ' ] , dummy_component ] ,
outputs = [ component_dict [ ' sd_model_checkpoint ' ] , text_settings ] ,
)
2022-11-06 19:39:41 +08:00
component_keys = [ k for k in opts . data_labels . keys ( ) if k in component_dict ]
def get_settings_values ( ) :
2023-01-19 23:07:37 +08:00
return [ get_value_for_setting ( key ) for key in component_keys ]
2022-11-06 19:39:41 +08:00
demo . load (
fn = get_settings_values ,
inputs = [ ] ,
outputs = [ component_dict [ k ] for k in component_keys ] ,
2023-03-29 13:58:29 +08:00
queue = False ,
2022-11-06 19:39:41 +08:00
)
2022-09-29 07:50:34 +08:00
def modelmerger ( * args ) :
try :
2022-10-02 20:03:39 +08:00
results = modules . extras . run_modelmerger ( * args )
2022-09-29 07:50:34 +08:00
except Exception as e :
print ( " Error loading/saving model file: " , file = sys . stderr )
print ( traceback . format_exc ( ) , file = sys . stderr )
2022-10-02 20:03:39 +08:00
modules . sd_models . list_models ( ) # to remove the potentially missing models from the list
2023-01-19 14:25:37 +08:00
return [ * [ gr . Dropdown . update ( choices = modules . sd_models . checkpoint_tiles ( ) ) for _ in range ( 4 ) ] , f " Error merging checkpoints: { e } " ]
2022-09-29 07:50:34 +08:00
return results
2022-09-19 03:25:18 +08:00
2023-01-19 15:39:51 +08:00
modelmerger_merge . click ( fn = lambda : ' ' , inputs = [ ] , outputs = [ modelmerger_result ] )
2022-09-29 05:59:44 +08:00
modelmerger_merge . click (
2023-01-19 14:25:37 +08:00
fn = wrap_gradio_gpu_call ( modelmerger , extra_outputs = lambda : [ gr . update ( ) for _ in range ( 4 ) ] ) ,
_js = ' modelmerger ' ,
2022-09-29 05:59:44 +08:00
inputs = [
2023-01-19 14:25:37 +08:00
dummy_component ,
2022-09-29 05:59:44 +08:00
primary_model_name ,
secondary_model_name ,
2022-10-14 14:05:06 +08:00
tertiary_model_name ,
2022-09-29 05:59:44 +08:00
interp_method ,
interp_amount ,
save_as_half ,
2022-09-29 07:50:34 +08:00
custom_name ,
2022-11-27 20:51:29 +08:00
checkpoint_format ,
2023-01-11 14:10:07 +08:00
config_source ,
2023-01-19 15:39:51 +08:00
bake_in_vae ,
2023-01-22 15:17:12 +08:00
discard_weights ,
2023-04-03 06:41:55 +08:00
save_metadata ,
2022-09-29 05:59:44 +08:00
] ,
outputs = [
primary_model_name ,
secondary_model_name ,
2022-10-14 14:05:06 +08:00
tertiary_model_name ,
2022-09-29 05:59:44 +08:00
component_dict [ ' sd_model_checkpoint ' ] ,
2023-01-19 14:25:37 +08:00
modelmerger_result ,
2022-09-29 05:59:44 +08:00
]
)
2022-09-24 03:49:21 +08:00
2022-09-10 13:18:54 +08:00
ui_config_file = cmd_opts . ui_config_file
2022-09-04 18:52:01 +08:00
ui_settings = { }
settings_count = len ( ui_settings )
error_loading = False
try :
if os . path . exists ( ui_config_file ) :
with open ( ui_config_file , " r " , encoding = " utf8 " ) as file :
ui_settings = json . load ( file )
except Exception :
error_loading = True
print ( " Error loading settings: " , file = sys . stderr )
print ( traceback . format_exc ( ) , file = sys . stderr )
def loadsave ( path , x ) :
2022-10-18 17:51:57 +08:00
def apply_field ( obj , field , condition = None , init_field = None ) :
2022-09-04 18:52:01 +08:00
key = path + " / " + field
2022-09-25 13:56:50 +08:00
2022-10-29 15:56:19 +08:00
if getattr ( obj , ' custom_script_source ' , None ) is not None :
2022-09-25 13:56:50 +08:00
key = ' customscript/ ' + obj . custom_script_source + ' / ' + key
2022-10-10 09:26:52 +08:00
2022-09-26 00:43:42 +08:00
if getattr ( obj , ' do_not_save_to_config ' , False ) :
return
2022-10-10 09:26:52 +08:00
2022-09-04 18:52:01 +08:00
saved_value = ui_settings . get ( key , None )
if saved_value is None :
ui_settings [ key ] = getattr ( obj , field )
2022-10-16 04:09:11 +08:00
elif condition and not condition ( saved_value ) :
2023-01-19 04:04:24 +08:00
pass
# this warning is generally not useful;
# print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
2022-10-16 04:09:11 +08:00
else :
2022-09-04 18:52:01 +08:00
setattr ( obj , field , saved_value )
2022-10-18 17:51:57 +08:00
if init_field is not None :
init_field ( saved_value )
2022-09-04 18:52:01 +08:00
2023-04-18 05:48:28 +08:00
if type ( x ) in [ gr . Slider , gr . Radio , gr . Checkbox , gr . Textbox , gr . Number , gr . Dropdown , ToolButton ] and x . visible :
2022-09-26 00:43:42 +08:00
apply_field ( x , ' visible ' )
2022-09-04 18:52:01 +08:00
if type ( x ) == gr . Slider :
apply_field ( x , ' value ' )
apply_field ( x , ' minimum ' )
apply_field ( x , ' maximum ' )
apply_field ( x , ' step ' )
if type ( x ) == gr . Radio :
2022-09-06 00:11:29 +08:00
apply_field ( x , ' value ' , lambda val : val in x . choices )
2022-09-04 18:52:01 +08:00
2022-09-25 13:31:02 +08:00
if type ( x ) == gr . Checkbox :
2022-09-25 13:40:37 +08:00
apply_field ( x , ' value ' )
2022-09-25 13:31:02 +08:00
if type ( x ) == gr . Textbox :
2022-09-25 13:40:37 +08:00
apply_field ( x , ' value ' )
2022-10-10 09:26:52 +08:00
2022-09-25 13:39:22 +08:00
if type ( x ) == gr . Number :
2022-09-25 13:40:37 +08:00
apply_field ( x , ' value ' )
2022-10-10 09:26:52 +08:00
2023-01-06 21:03:43 +08:00
if type ( x ) == gr . Dropdown :
2023-01-14 19:56:39 +08:00
def check_dropdown ( val ) :
2023-01-22 01:07:14 +08:00
if getattr ( x , ' multiselect ' , False ) :
2023-01-14 19:56:39 +08:00
return all ( [ value in x . choices for value in val ] )
else :
return val in x . choices
apply_field ( x , ' value ' , check_dropdown , getattr ( x , ' init_field ' , None ) )
2022-10-16 03:47:03 +08:00
2023-03-30 02:04:02 +08:00
def check_tab_id ( tab_id ) :
tab_items = list ( filter ( lambda e : isinstance ( e , gr . TabItem ) , x . children ) )
if type ( tab_id ) == str :
tab_ids = [ t . id for t in tab_items ]
return tab_id in tab_ids
elif type ( tab_id ) == int :
return tab_id > = 0 and tab_id < len ( tab_items )
else :
return False
if type ( x ) == gr . Tabs :
apply_field ( x , ' selected ' , check_tab_id )
2022-09-04 18:52:01 +08:00
visit ( txt2img_interface , loadsave , " txt2img " )
visit ( img2img_interface , loadsave , " img2img " )
2022-09-11 16:31:16 +08:00
visit ( extras_interface , loadsave , " extras " )
2022-10-18 00:56:23 +08:00
visit ( modelmerger_interface , loadsave , " modelmerger " )
2023-01-05 01:10:40 +08:00
visit ( train_interface , loadsave , " train " )
2022-09-04 18:52:01 +08:00
2023-03-30 02:04:02 +08:00
loadsave ( f " webui/Tabs@ { tabs . elem_id } " , tabs )
2022-09-04 18:52:01 +08:00
if not error_loading and ( not os . path . exists ( ui_config_file ) or settings_count != len ( ui_settings ) ) :
with open ( ui_config_file , " w " , encoding = " utf8 " ) as file :
json . dump ( ui_settings , file , indent = 4 )
2023-01-20 13:48:15 +08:00
# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description . value = update_interp_description ( interp_method . value )
2022-09-03 17:08:45 +08:00
return demo
2023-03-27 17:59:12 +08:00
def webpath(fn):
    """Return a gradio ``file=`` URL for the file at *fn*.

    Files located inside ``script_path`` are referenced by their path relative to it
    (with forward slashes); any other file is referenced by its absolute path. The
    file's modification time is appended as a query string (presumably as a
    cache-buster so browsers re-fetch changed files).

    Raises:
        OSError: if *fn* does not exist (raised by ``os.path.getmtime``).
    """
    # Compare whole path components instead of raw string prefixes: a sibling directory
    # such as "<script_path>-data" shares script_path's leading characters but is not
    # inside it, and relpath would turn it into a "../..." path.
    try:
        inside_script_path = os.path.commonpath([script_path, fn]) == os.path.normpath(script_path)
    except ValueError:
        # commonpath raises for a mix of absolute and relative paths, or (on Windows)
        # for paths on different drives; in both cases fn is not inside script_path.
        inside_script_path = False

    if inside_script_path:
        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
    else:
        web_path = os.path.abspath(fn)

    return f'file={web_path}?{os.path.getmtime(fn)}'
def javascript_html():
    """Build the ``<script>`` tags that reload_javascript injects before ``</head>``.

    The result contains, in order:
      * a tag for ``script.js`` located in ``script_path``;
      * one classic-script tag per ``.js`` file returned by
        ``modules.scripts.list_scripts("javascript", ".js")``;
      * one ``type="module"`` tag per ``.mjs`` file from the same listing;
      * an inline script with the localization JS for the configured localization,
        followed by a ``set_theme(...)`` call when ``cmd_opts.theme`` is set.
    """
    def classic_tag(src):
        return f'<script type="text/javascript" src="{webpath(src)}"></script>\n'

    def module_tag(src):
        return f'<script type="module" src="{webpath(src)}"></script>\n'

    tags = [classic_tag(os.path.join(script_path, "script.js"))]

    inline_code = f"{localization.localization_js(shared.opts.localization)};"
    if cmd_opts.theme is not None:
        inline_code += f"set_theme('{cmd_opts.theme}');"

    tags.extend(classic_tag(js_file.path) for js_file in modules.scripts.list_scripts("javascript", ".js"))
    tags.extend(module_tag(mjs_file.path) for mjs_file in modules.scripts.list_scripts("javascript", ".mjs"))

    # The inline script goes last so the external scripts above are already loaded.
    tags.append(f'<script type="text/javascript">{inline_code}</script>\n')

    return "".join(tags)
def css_html():
    """Build the ``<link rel="stylesheet">`` tags that reload_javascript injects before ``</body>``.

    Includes every ``style.css`` returned by ``modules.scripts.list_files_with_name``
    (presumably the webui's own and each extension's) that is a regular file, followed
    by ``user.css`` from ``data_path`` when that exists as a regular file.
    """
    def stylesheet(fn):
        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'

    css_files = [cssfile for cssfile in modules.scripts.list_files_with_name("style.css") if os.path.isfile(cssfile)]

    # Use isfile (as for style.css above) rather than exists, so a directory that
    # happens to be named user.css is not emitted as a stylesheet link.
    user_css = os.path.join(data_path, "user.css")
    if os.path.isfile(user_css):
        css_files.append(user_css)

    return "".join(stylesheet(cssfile) for cssfile in css_files)
def reload_javascript():
    """Patch gradio's TemplateResponse so rendered pages include the webui's scripts and styles.

    The output of javascript_html() is inserted before ``</head>`` and the output of
    css_html() before ``</body>``. Both are computed once, when this function runs;
    call it again to pick up changes.
    """
    scripts_markup = javascript_html()
    styles_markup = css_html()

    def template_response(*args, **kwargs):
        # Always delegate to the unpatched gradio implementation saved in shared.
        response = shared.GradioTemplateResponseOriginal(*args, **kwargs)
        patched_body = response.body.replace(b'</head>', f'{scripts_markup}</head>'.encode("utf8"))
        patched_body = patched_body.replace(b'</body>', f'{styles_markup}</body>'.encode("utf8"))
        response.body = patched_body
        # Rebuild headers from the modified body (presumably so Content-Length matches).
        response.init_headers()
        return response

    gradio.routes.templates.TemplateResponse = template_response
2022-10-13 00:19:34 +08:00
2022-10-15 01:04:47 +08:00
2022-11-08 13:35:01 +08:00
# Save gradio's unpatched TemplateResponse on the shared module. reload_javascript()
# replaces gradio.routes.templates.TemplateResponse with a wrapper that calls
# shared.GradioTemplateResponseOriginal. The hasattr guard makes sure only the first,
# unpatched value is stored, presumably so that re-executing this module does not
# capture the wrapper itself as the "original".
if not hasattr(shared, 'GradioTemplateResponseOriginal'):
    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
2023-01-05 16:57:01 +08:00
def versions_html():
    """Return the HTML fragment with version information shown in the page footer.

    Lists the webui version (git tag, linked to its commit on GitHub), the Python,
    torch, xformers and gradio versions, and an ``sd_checkpoint_hash`` element whose
    text is "N/A" here (presumably updated client-side — TODO confirm).
    """
    import torch
    import launch

    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
    commit = launch.commit_hash()
    tag = launch.git_tag()

    if shared.xformers_available:
        import xformers
        xformers_version = xformers.__version__
    else:
        xformers_version = "N/A"

    torch_version = getattr(torch, '__long_version__', torch.__version__)

    # Values from git and from installed packages are interpolated into HTML text and
    # attributes, so escape them. A placeholder such as "<none>" (NOTE(review): check
    # what launch.commit_hash()/git_tag() return when git is unavailable) would otherwise
    # be parsed as a tag and render as nothing.
    def esc(value):
        return html.escape(str(value))

    # Separator: a bullet padded with U+2000 (EN QUAD) spaces, written as escapes so the
    # source contains no invisible characters.
    return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{esc(commit)}">{esc(tag)}</a>
\u2000•\u2000
python: <span title="{esc(sys.version)}">{python_version}</span>
\u2000•\u2000
torch: {esc(torch_version)}
\u2000•\u2000
xformers: {esc(xformers_version)}
\u2000•\u2000
gradio: {esc(gr.__version__)}
\u2000•\u2000
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""
2023-05-08 21:46:35 +08:00
def setup_ui_api(app):
    """Register the internal UI helper endpoints on *app*.

    *app* is presumably the FastAPI application (it must provide ``add_api_route``).

    Routes:
        GET /internal/quicksettings-hint -- name and label of every entry in
            ``opts.data_labels``.
        GET /internal/ping -- returns an empty JSON object.
    """
    from typing import List
    from pydantic import BaseModel, Field

    class QuicksettingsHint(BaseModel):
        name: str = Field(title="Name of the quicksettings field")
        label: str = Field(title="Label of the quicksettings field")

    def quicksettings_hint():
        hints = []
        for setting_name, setting_info in opts.data_labels.items():
            hints.append(QuicksettingsHint(name=setting_name, label=setting_info.label))
        return hints

    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])

    app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])