mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-11 08:02:53 +08:00
make main model loading and model merger use the same code
This commit is contained in:
parent
050a6a798c
commit
c77c89cc83
@ -170,8 +170,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
theta_0 = primary_model['state_dict']
|
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
|
||||||
theta_1 = secondary_model['state_dict']
|
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
|
||||||
|
|
||||||
theta_funcs = {
|
theta_funcs = {
|
||||||
"Weighted Sum": weighted_sum,
|
"Weighted Sum": weighted_sum,
|
||||||
|
@ -122,6 +122,13 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_dict_from_checkpoint(pl_sd):
|
||||||
|
if "state_dict" in pl_sd:
|
||||||
|
return pl_sd["state_dict"]
|
||||||
|
|
||||||
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info):
|
def load_model_weights(model, checkpoint_info):
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
sd_model_hash = checkpoint_info.hash
|
sd_model_hash = checkpoint_info.hash
|
||||||
@ -132,10 +139,7 @@ def load_model_weights(model, checkpoint_info):
|
|||||||
if "global_step" in pl_sd:
|
if "global_step" in pl_sd:
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
|
||||||
if "state_dict" in pl_sd:
|
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
else:
|
|
||||||
sd = pl_sd
|
|
||||||
|
|
||||||
model.load_state_dict(sd, strict=False)
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user