Add support for checkpoint merging

This commit is contained in:
William Moorehouse 2022-09-25 19:22:12 -04:00
parent ca3e5519e8
commit 91643f651d
3 changed files with 53 additions and 2 deletions

View File

@ -3,6 +3,8 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
@ -135,3 +137,25 @@ def run_pnginfo(image):
info = f"<div><p>{message}<p></div>" info = f"<div><p>{message}<p></div>"
return '', geninfo, info return '', geninfo, info
def run_modelmerger(modelname_0, modelname_1, alpha):
model_0 = torch.load('models/' + modelname_0 + '.ckpt')
model_1 = torch.load('models/' + modelname_1 + '.ckpt')
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
for key in theta_0.keys():
if 'model' in key and key in theta_1:
theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt';
torch.save(model_0, output_modelname)
return "<p>Model saved to " + output_modelname + "</p>"

View File

@ -393,7 +393,7 @@ def setup_progressbar(progressbar, preview, id_part):
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -853,6 +853,31 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)")
modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)")
alpha = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Alpha', value=0.3)
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.HTML(elem_id="modelmerger_result")
submit.click(
fn=run_modelmerger,
inputs=[
modelname_0,
modelname_1,
alpha
],
outputs=[
submit_result,
]
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
@ -950,6 +975,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]

View File

@ -85,7 +85,8 @@ def webui():
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img), img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
) )
demo.launch( demo.launch(