2022-09-03 17:08:45 +08:00
2022-10-22 19:07:00 +08:00
import gradio as gr
2022-09-03 17:08:45 +08:00
2023-08-01 12:43:43 +08:00
from modules import sd_models , sd_vae , errors , extras , call_queue
from modules . ui_components import FormRow
2023-06-01 03:40:09 +08:00
from modules . ui_common import create_refresh_button
2022-10-17 03:06:21 +08:00
2023-08-01 12:43:43 +08:00
def update_interp_description(value):
    """Return the HTML paragraph describing the given interpolation method.

    Args:
        value: one of "No interpolation", "Weighted sum" or "Add difference".

    Returns:
        The description text wrapped in a ``<p>`` element with a bottom margin.

    Raises:
        KeyError: if ``value`` is not one of the three known method names.
    """
    paragraph_template = "<p style='margin-bottom: 2.5em'>{}</p>"
    text_for_method = {
        "No interpolation": "No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.",
        "Weighted sum": "A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M",
        "Add difference": "The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M",
    }
    # Look up first, so an unknown method raises KeyError before any formatting happens.
    chosen_text = text_for_method[value]
    return paragraph_template.format(chosen_text)
2022-10-14 16:56:41 +08:00
2022-09-03 17:08:45 +08:00
2023-08-01 12:43:43 +08:00
def modelmerger(*args):
    """Run the checkpoint merge and turn failures into UI-friendly output.

    All positional arguments are forwarded unchanged to
    ``extras.run_modelmerger``.

    Returns:
        On success, whatever ``extras.run_modelmerger`` returns. On failure, a
        five-element list: four ``gr.Dropdown.update`` objects whose choices are
        the current ``sd_models.checkpoint_tiles()``, followed by an error
        message string.
    """
    try:
        merge_output = extras.run_modelmerger(*args)
    except Exception as exc:
        errors.report("Error loading/saving model file", exc_info=True)
        # Rescan checkpoints so any that may have gone missing drop out of the dropdowns.
        sd_models.list_models()
        refreshed_dropdowns = [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
        return refreshed_dropdowns + [f"Error merging checkpoints: {exc}"]
    else:
        return merge_output
2023-03-29 03:23:40 +08:00
2023-08-01 12:43:43 +08:00
class UiCheckpointMerger:
    """Gradio UI for the checkpoint merger.

    ``__init__`` builds the component layout and stores the resulting
    ``gr.Blocks`` in ``self.blocks``. The Merge button gets no click handlers
    until ``setup_ui`` is called.
    """

    def __init__(self):
        """Build the merger layout.

        The layout is a row with two columns: an inputs column (model
        selectors, merge options and the Merge button) and a results column
        holding the result HTML. Components are exposed as attributes so that
        ``setup_ui`` can wire them to event handlers.
        """
        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
            with gr.Row().style(equal_height=False):
                # Inputs column.
                with gr.Column(variant='compact'):
                    # Initial text matches the interp_method radio's default value ("Weighted sum") below.
                    self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")

                    # Models A, B and C. Each dropdown is paired with a refresh button that calls
                    # sd_models.list_models and reloads the choices from sd_models.checkpoint_tiles().
                    with FormRow(elem_id="modelmerger_models"):
                        self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

                    self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
                    # M in the formulas shown by update_interp_description.
                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
                    self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
                    # Keep the description paragraph in sync with the selected method.
                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])

                    # Output file options.
                    with FormRow():
                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
                        self.save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")

                    with FormRow():
                        with gr.Column():
                            # type="index": handlers receive the position of the selected choice (0-3), not its label.
                            self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")

                        with gr.Column():
                            with FormRow():
                                # Choices are "None" plus the keys of sd_vae.vae_dict; the refresh button
                                # calls sd_vae.refresh_vae_list and rebuilds the same list.
                                self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")

                    with FormRow():
                        # NOTE(review): how this text is matched against weight names is decided in
                        # extras.run_modelmerger, which is not visible here — verify before documenting further.
                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

                    with gr.Row():
                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

                # Results column.
                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
                    with gr.Group(elem_id="modelmerger_results_panel"):
                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)

        self.blocks = modelmerger_interface

    def setup_ui(self, dummy_component, sd_model_checkpoint_component):
        """Attach the Merge button handlers.

        Args:
            dummy_component: component passed as the first input of the merge
                handler. Its value is presumably supplied by the 'modelmerger'
                JS function named in ``_js`` — verify.
            sd_model_checkpoint_component: externally created checkpoint
                component, listed as the fourth output after the three model
                dropdowns.
        """
        # First handler: reset the result HTML to an empty string on click.
        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
        # Second handler: run the merge through call_queue.wrap_gradio_gpu_call. extra_outputs
        # supplies four no-op gr.update() values for the non-result outputs (presumably used
        # when the wrapper produces its own output — verify in call_queue).
        self.modelmerger_merge.click(
            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
            _js='modelmerger',
            # Order matters: modelmerger forwards *args unchanged to extras.run_modelmerger, so this
            # list must match its positional parameters (whether the wrapper consumes the leading
            # dummy value is not visible here — verify).
            inputs=[
                dummy_component,
                self.primary_model_name,
                self.secondary_model_name,
                self.tertiary_model_name,
                self.interp_method,
                self.interp_amount,
                self.save_as_half,
                self.custom_name,
                self.checkpoint_format,
                self.config_source,
                self.bake_in_vae,
                self.discard_weights,
                self.save_metadata,
            ],
            # Five outputs. On failure, modelmerger returns four dropdown updates followed by the
            # error message, which lines up with this order.
            outputs=[
                self.primary_model_name,
                self.secondary_model_name,
                self.tertiary_model_name,
                sd_model_checkpoint_component,
                self.modelmerger_result,
            ]
        )

        # Required as a workaround for change() event not triggering when loading values from ui-config.json
        self.interp_description.value = update_interp_description(self.interp_method.value)
2023-06-03 18:55:35 +08:00