mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
init job and add info to model merge
This commit is contained in:
parent
e9fb9bb0c2
commit
1d9dc48efd
@ -242,6 +242,9 @@ def run_pnginfo(image):
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
|
||||
shared.state.begin()
|
||||
shared.state.job = 'model-merge'
|
||||
|
||||
def weighted_sum(theta0, theta1, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_func1, theta_func2 = theta_funcs[interp_method]
|
||||
|
||||
if theta_func1 and not tertiary_model_info:
|
||||
shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
|
||||
shared.state.end()
|
||||
return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
||||
shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
|
||||
print(f"Loading {secondary_model_info.filename}...")
|
||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||
|
||||
@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_1[key] = torch.zeros_like(theta_1[key])
|
||||
del theta_2
|
||||
|
||||
shared.state.textinfo = f"Loading {primary_model_info.filename}..."
|
||||
print(f"Loading {primary_model_info.filename}...")
|
||||
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
|
||||
|
||||
@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
a = theta_0[key]
|
||||
b = theta_1[key]
|
||||
|
||||
shared.state.textinfo = f'Merging layer {key}'
|
||||
# this enables merging an inpainting model (A) with another one (B);
|
||||
# where normal model would have 4 channels, for latenst space, inpainting model would
|
||||
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
|
||||
@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
|
||||
result_is_inpainting_model = True
|
||||
else:
|
||||
assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
|
||||
|
||||
theta_0[key] = theta_func2(a, b, multiplier)
|
||||
|
||||
if save_as_half:
|
||||
@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
shared.state.textinfo = f"Saving to {output_modelname}..."
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
|
||||
sd_models.list_models()
|
||||
|
||||
print("Checkpoint saved.")
|
||||
shared.state.textinfo = "Checkpoint saved to " + output_modelname
|
||||
shared.state.end()
|
||||
|
||||
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
|
||||
|
Loading…
Reference in New Issue
Block a user