2022-10-29 13:42:34 +08:00
import base64
import io
2022-10-09 11:57:19 +08:00
import os
2022-09-24 03:49:21 +08:00
import re
2022-11-28 04:04:42 +08:00
from pathlib import Path
2022-09-24 03:49:21 +08:00
import gradio as gr
2022-10-09 11:57:19 +08:00
from modules . shared import script_path
2022-10-13 17:26:34 +08:00
from modules import shared
2022-10-27 13:36:11 +08:00
import tempfile
2022-10-29 15:56:19 +08:00
from PIL import Image
2022-09-24 03:49:21 +08:00
2022-10-21 21:10:51 +08:00
# One "key: value" pair of an infotext parameter line. The value is either a
# double-quoted string as produced by quote() (backslash-escaped \ and ", may
# contain commas) or a run of non-comma characters; the pair ends at a comma or
# at the end of the string.
# The quoted alternative only admits escaped characters or non-quote characters:
# the previous pattern also admitted a bare '"', so a quoted value could run on
# and swallow every following parameter up to the last quote on the line.
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# A line consisting of three or more such pairs is the parameter line.
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")

# "512x768" -> groups ("512", "768")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")

# Trailing "(hash)" of a hypernetwork key, e.g. "ke-ta-10000(1234abcd)" -> "1234abcd".
# Raw string: "\(" in a normal string literal is an invalid escape sequence.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
2022-09-25 14:25:28 +08:00
# Concrete class of the objects gr.update() returns, so values that already
# are gradio updates can be recognised and passed through unchanged.
type_of_gr_update = type(gr.update())

# tab name -> {"init_img": image component, "fields": [(component, key), ...]},
# filled in by add_paste_fields()
paste_fields = dict()

# [buttons, send_image, send_generate_info] entries queued by bind_buttons()
# and wired up later by run_bind()
bind_list = list()
2022-09-24 03:49:21 +08:00
2022-10-29 13:42:34 +08:00
2022-10-31 22:36:45 +08:00
def reset():
    """Forget every registered paste target and every queued button binding."""
    for registry in (paste_fields, bind_list):
        registry.clear()
2022-10-21 21:10:51 +08:00
def quote(text):
    """Quote a value for an infotext parameter line if it contains a comma.

    A value whose str() has no comma is returned unchanged (the same object,
    not converted to str). Otherwise its str() form is wrapped in double
    quotes, with backslashes and double quotes backslash-escaped, so the comma
    is not mistaken for a parameter separator.
    """
    as_text = str(text)
    if ',' in as_text:
        # Escape backslashes first so the backslashes added for quotes are not doubled.
        escaped = as_text.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    return text
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def image_from_url_text(filedata):
    """Turn a value sent from a gradio image/gallery component into a PIL image.

    Accepted forms:
      * a dict with a true "is_file" and a "name" path: the file is opened,
        but only if it lies inside one of ``shared.demo.temp_dirs``;
      * a list: its first element is handled recursively (empty list -> None);
      * a base64 string, optionally prefixed by a ``data:<mime>;base64,`` header.

    Returns None for None or an empty list.
    Raises AssertionError when a file path points outside the allowed temp dirs.
    """
    if isinstance(filedata, dict) and filedata["is_file"]:
        filename = filedata["name"]
        is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
        # Explicit raise instead of `assert`: this is a security check and must
        # not disappear when Python runs with -O. The exception type is kept.
        if not is_in_right_dir:
            raise AssertionError('trying to open image file outside of allowed directories')
        return Image.open(filename)

    if isinstance(filedata, list):
        if len(filedata) == 0:
            return None
        # The first entry may itself be a file dict or a base64 string.
        return image_from_url_text(filedata[0])

    if filedata is None:
        return None

    # Strip any data-URL header. Only the PNG header used to be recognised;
    # other headers were passed to the lenient base64 decoder, which silently
    # turned their letters into leading garbage bytes.
    header, separator, payload = filedata.partition(",")
    if separator and header.startswith("data:") and header.endswith(";base64"):
        filedata = payload

    image_bytes = base64.decodebytes(filedata.encode('utf-8'))
    return Image.open(io.BytesIO(image_bytes))
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def add_paste_fields(tabname, init_img, fields):
    """Register the paste target of a tab.

    ``init_img`` is the component a sent image is written to; ``fields`` is the
    list of (component, infotext key or callable) pairs that pasted generation
    parameters are written to (or None).
    """
    paste_fields[tabname] = {"init_img": init_img, "fields": fields}

    # backwards compatibility for existing extensions: the txt2img / img2img
    # field lists are also mirrored onto modules.ui
    import modules.ui
    legacy_attribute = {
        'txt2img': 'txt2img_paste_fields',
        'img2img': 'img2img_paste_fields',
    }.get(tabname)
    if legacy_attribute is not None:
        setattr(modules.ui, legacy_attribute, fields)
2022-10-27 13:36:11 +08:00
2022-10-29 13:42:34 +08:00
2022-10-29 15:56:19 +08:00
def integrate_settings_paste_fields(component_dict):
    """Make pasted infotext also update a fixed set of settings.

    For each setting in the map below, a (component, callable) pair is appended
    to the "fields" list of every registered tab whose fields are not None. The
    callable reads the setting's infotext label from the parsed parameters and
    passes it to ``ui.apply_setting``.

    ``component_dict`` is indexed by setting name (presumably the settings-page
    components); a missing setting raises KeyError.
    """
    from modules import ui

    # setting name -> infotext label
    settings_map = {
        'sd_hypernetwork': 'Hypernet',
        'sd_hypernetwork_strength': 'Hypernet strength',
        'CLIP_stop_at_last_layers': 'Clip skip',
        'inpainting_mask_weight': 'Conditional mask weight',
        'sd_model_checkpoint': 'Model hash',
        'eta_noise_seed_delta': 'ENSD',
        'initial_noise_multiplier': 'Noise multiplier',
    }

    def make_setter(setting_name, infotext_label):
        # A factory gives every setter its own pair of names (no late binding).
        return lambda params: ui.apply_setting(setting_name, params.get(infotext_label, None))

    settings_paste_fields = []
    for setting_name, infotext_label in settings_map.items():
        settings_paste_fields.append((component_dict[setting_name], make_setter(setting_name, infotext_label)))

    for info in paste_fields.values():
        if info["fields"] is not None:
            info["fields"] += settings_paste_fields
2022-10-27 13:36:11 +08:00
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button per tab name, keyed by that name."""
    return {tab: gr.Button(f"Send to {tab}") for tab in tabs_list}
2022-10-29 13:42:34 +08:00
2022-10-29 14:01:04 +08:00
def bind_buttons(buttons, send_image, send_generate_info):
    """Queue a button binding; run_bind() wires it up later.

    ``buttons`` maps target tab names to buttons (as returned by create_buttons()).
    ``send_image`` is the component whose image is sent, or a falsy value to send none.
    If ``send_generate_info`` is a tab name, the generation info comes from the
    paste fields registered for that tab; otherwise it is the component whose
    infotext text is parsed and pasted.
    """
    entry = [buttons, send_image, send_generate_info]
    bind_list.append(entry)
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def _bind_send_image(button, send_image, target_image):
    """Make ``button`` copy the image held by ``send_image`` into ``target_image``."""
    if isinstance(send_image, gr.Gallery):
        # The JS helper presumably extracts the selected gallery entry; the
        # Python side then decodes it into a PIL image.
        button.click(
            fn=image_from_url_text,
            _js="extract_image_from_gallery",
            inputs=[send_image],
            outputs=[target_image],
        )
    else:
        button.click(
            fn=lambda x: x,
            inputs=[send_image],
            outputs=[target_image],
        )


def _bind_send_generate_info(button, tab, send_generate_info):
    """Make ``button`` fill the paste fields of ``tab`` from ``send_generate_info``."""
    target_fields = paste_fields[tab]["fields"]

    if send_generate_info not in paste_fields:
        # A component holding infotext: parse it and paste the values.
        connect_paste(button, target_fields, send_generate_info)
        return

    # Another tab's name: copy selected field values component-to-component.
    paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else [])

    def fields_by_name(fields):
        # Callable keys never equal a field name, so only named fields remain.
        return {name: field for field, name in fields if name in paste_field_names}

    source = fields_by_name(paste_fields[send_generate_info]["fields"])
    target = fields_by_name(target_fields)
    # Pair inputs with outputs by name: matching by position meant a field that
    # was missing from (or ordered differently in) one tab shifted every later
    # value into the wrong component.
    names = [name for name in paste_field_names if name in source and name in target]

    button.click(
        fn=lambda *x: x,
        inputs=[source[name] for name in names],
        outputs=[target[name] for name in names],
    )


def run_bind():
    """Wire up every binding queued by bind_buttons().

    For each target tab, the button may send an image to the tab's "init_img"
    component, send generation parameters to the tab's paste fields, and then
    switch the UI to that tab via the ``switch_to_<tab>`` JS function.
    """
    for buttons, send_image, send_generate_info in bind_list:
        for tab, button in buttons.items():
            # paste_fields[tab] is only looked up when something is actually sent.
            if send_image and paste_fields[tab]["init_img"]:
                _bind_send_image(button, send_image, paste_fields[tab]["init_img"])

            if send_generate_info and paste_fields[tab]["fields"] is not None:
                _bind_send_generate_info(button, tab, send_generate_info)

            button.click(
                fn=None,
                _js=f"switch_to_{tab}",
                inputs=None,
                outputs=None,
            )
2022-10-29 13:42:34 +08:00
2022-12-14 06:25:16 +08:00
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.

    Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
    parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.

    The hash is tried first, compared case-insensitively with the "(hash)" suffix of each key in
    shared.hypernetworks. If the infotext has no hash, or no key carries that hash, the first key whose
    lowercased form starts with the lowercased infotext name is selected instead.
    Returns None when nothing matches.
    """
    hypernet_name = hypernet_name.lower()

    if hypernet_hash is not None:
        # Try to match the hash in the name. re_hypernet_hash only recognises
        # lowercase hex, so compare against a lowercased hash.
        wanted_hash = hypernet_hash.lower()
        for hypernet_key in shared.hypernetworks.keys():
            result = re_hypernet_hash.search(hypernet_key)
            if result is not None and result[1] == wanted_hash:
                return hypernet_key

    # Fall back to a hypernet with the same name. This used to run only when no
    # hash was given, so a hash that matched nothing selected no hypernet at all.
    for hypernet_key in shared.hypernetworks.keys():
        if hypernet_key.lower().startswith(hypernet_name):
            return hypernet_key

    return None
2022-09-24 03:49:21 +08:00
def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    returns a dict with field values:
      * "Prompt" / "Negative prompt": the text before / after the "Negative prompt:" marker;
      * one entry per "key: value" pair of the last line. The last line counts as the
        parameter line only if it holds three or more pairs. Values written by quote()
        are unquoted and unescaped. "WxH" values are split into "<key>-1" / "<key>-2";
      * "Clip skip" and "Hypernet strength" default to "1" when absent;
      * "Hypernet" is replaced by the matching key from find_hypernetwork_key() (may be None).
    """
    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if not re_params.match(lastline):
        lines.append(lastline)
        lastline = ''

    negative_marker = "Negative prompt:"
    for line in lines:
        line = line.strip()
        if line.startswith(negative_marker):
            done_with_prompt = True
            line = line[len(negative_marker):].strip()
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        # Undo quote(): strip the surrounding double quotes and the backslash
        # escapes, so pasted values equal what was originally saved.
        if len(v) >= 2 and v[0] == '"' and v[-1] == '"':
            v = re.sub(r'\\(.)', r'\1', v[1:-1])

        m = re_imagesize.match(v)
        if m is not None:
            res[k + "-1"] = m.group(1)
            res[k + "-2"] = m.group(2)
        else:
            res[k] = v

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    if "Hypernet strength" not in res:
        res["Hypernet strength"] = "1"

    if "Hypernet" in res:
        hypernet_name = res["Hypernet"]
        hypernet_hash = res.get("Hypernet hash", None)
        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)

    return res
2022-10-29 14:01:04 +08:00
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
    """Bind ``button`` so a click parses the infotext in ``input_comp`` and pastes it.

    ``paste_fields`` is a list of (component, key) pairs. ``key`` is either a
    key looked up in the parsed parameters, or a callable that receives the
    parsed-parameters dict and returns the value (possibly already a gradio
    update, which is passed through). ``jsfunc`` is forwarded as the click's ``_js``.

    When the input is empty and ``hide_ui_dir_config`` is off, the contents of
    ``params.txt`` in ``script_path`` are used instead, if that file exists.
    """

    def make_update(component, key, params):
        value = key(params) if callable(key) else params.get(key, None)
        if value is None:
            return gr.update()
        if isinstance(value, type_of_gr_update):
            return value
        try:
            # Convert to the type of the component's current value; the string
            # "False" must map to False because bool("False") is True.
            target_type = type(component.value)
            converted = False if (target_type == bool and value == "False") else target_type(value)
            return gr.update(value=converted)
        except Exception:
            # Best effort: an unconvertible value leaves the component unchanged.
            return gr.update()

    def paste_func(prompt):
        use_saved = not prompt and not shared.cmd_opts.hide_ui_dir_config
        if use_saved:
            saved_params = os.path.join(script_path, "params.txt")
            if os.path.exists(saved_params):
                with open(saved_params, "r", encoding="utf8") as fh:
                    prompt = fh.read()

        params = parse_generation_parameters(prompt)
        return [make_update(component, key, params) for component, key in paste_fields]

    button.click(
        fn=paste_func,
        _js=jsfunc,
        inputs=[input_comp],
        outputs=[entry[0] for entry in paste_fields],
    )
2022-10-27 13:36:11 +08:00