mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-10 23:52:54 +08:00
support loading .yaml config with same name as model
support EMA weights in processing (????)
This commit is contained in:
parent
432782163a
commit
050a6a798c
@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(all_prompts, all_seeds, all_subseeds)
|
p.init(all_prompts, all_seeds, all_subseeds)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from modules.paths import models_path
|
|||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -63,14 +63,20 @@ def list_models():
|
|||||||
if os.path.exists(cmd_ckpt):
|
if os.path.exists(cmd_ckpt):
|
||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
|
||||||
shared.opts.data['sd_model_checkpoint'] = title
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
for filename in model_list:
|
for filename in model_list:
|
||||||
h = model_hash(filename)
|
h = model_hash(filename)
|
||||||
title, short_model_name = modeltitle(filename, h)
|
title, short_model_name = modeltitle(filename, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
|
|
||||||
|
basename, _ = os.path.splitext(filename)
|
||||||
|
config = basename + ".yaml"
|
||||||
|
if not os.path.exists(config):
|
||||||
|
config = shared.cmd_opts.config
|
||||||
|
|
||||||
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(searchString):
|
||||||
@ -116,7 +122,10 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_file, sd_model_hash):
|
def load_model_weights(model, checkpoint_info):
|
||||||
|
checkpoint_file = checkpoint_info.filename
|
||||||
|
sd_model_hash = checkpoint_info.hash
|
||||||
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||||
|
|
||||||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||||
@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
|||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpoint = checkpoint_file
|
model.sd_model_checkpoint = checkpoint_file
|
||||||
|
model.sd_checkpoint_info = checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = select_checkpoint()
|
checkpoint_info = select_checkpoint()
|
||||||
|
|
||||||
sd_config = OmegaConf.load(shared.cmd_opts.config)
|
if checkpoint_info.config != shared.cmd_opts.config:
|
||||||
|
print(f"Loading config from: {shared.cmd_opts.config}")
|
||||||
|
|
||||||
|
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||||
@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None):
|
|||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||||
|
return load_model()
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None):
|
|||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user