From 5159edbf0e0e1d5a25fbd588e000487746790117 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:44:37 +0900 Subject: [PATCH] Store base_vae and loaded_vae_file in sd_model --- modules/sd_models.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 2c976561e..150d550b6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,7 +462,6 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] - self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -489,24 +488,19 @@ class SdModelData: def set_sd_model(self, v, already_loaded=False): self.sd_model = v if already_loaded: - sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(v, "base_vae", None) + sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None) + sd_vae.checkpoint_info = v.sd_checkpoint_info try: self.loaded_sd_models.remove(v) - self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: + setattr(v, "base_vae", sd_vae.base_vae) + setattr(v, "loaded_vae_file", sd_vae.loaded_vae_file) self.loaded_sd_models.insert(0, v) - self.loaded_vae_states[v.sd_model_hash] = dict( - base_vae=sd_vae.base_vae, - loaded_vae_file=sd_vae.loaded_vae_file, - checkpoint_info=sd_vae.checkpoint_info, - ) model_data = SdModelData() @@ -661,7 +655,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() - model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -691,10 +684,9 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model - sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(sd_model, "base_vae", None) + sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None) + sd_vae.checkpoint_info = sd_model.sd_checkpoint_info print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model