support loading clip/t5 from the main model checkpoint

This commit is contained in:
AUTOMATIC1111 2024-06-29 00:38:52 +03:00
parent d67348a0a5
commit 7e4b06fcd0
3 changed files with 24 additions and 25 deletions

View File

@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g) self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
self.model_t5 = Sd3T5(self.t5xxl) self.model_t5 = Sd3T5(self.t5xxl)
self.weights_loaded = False
def forward(self, prompts: list[str]): def forward(self, prompts: list[str]):
with devices.without_autocast(): with devices.without_autocast():
lg_out, vector_out = self.model_lg(prompts) lg_out, vector_out = self.model_lg(prompts)
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
token_count = lg_out.shape[1]
t5_out = self.model_t5(prompts, token_count=token_count)
lgt_out = torch.cat([lg_out, t5_out], dim=-2) lgt_out = torch.cat([lg_out, t5_out], dim=-2)
return { return {
@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
'vector': vector_out, 'vector': vector_out,
} }
def load_weights(self): def before_load_weights(self, state_dict):
if self.weights_loaded:
return
clip_path = os.path.join(shared.models_path, "CLIP") clip_path = os.path.join(shared.models_path, "CLIP")
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors") clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
with safetensors.safe_open(clip_g_file, framework="pt") as file: with safetensors.safe_open(clip_g_file, framework="pt") as file:
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file)) self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file: with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl: if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file: with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
self.weights_loaded = True
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX return torch.tensor([[0]], device=devices.device) # XXX

View File

@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.cond_stage_model = SD3Cond() self.text_encoders = SD3Cond()
self.cond_stage_key = 'txt' self.cond_stage_key = 'txt'
self.parameterization = "eps" self.parameterization = "eps"
@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
self.latent_format = SD3LatentFormat() self.latent_format = SD3LatentFormat()
self.latent_channels = 16 self.latent_channels = 16
def after_load_weights(self): @property
self.cond_stage_model.load_weights() def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self): def ema_scope(self):
return contextlib.nullcontext() return contextlib.nullcontext()

View File

@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
# cache newly loaded model # cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy() checkpoints_loaded[checkpoint_info] = state_dict.copy()
if hasattr(model, "before_load_weights"):
model.before_load_weights(state_dict)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
timer.record("apply weights to model") timer.record("apply weights to model")
if hasattr(model, "after_load_weights"):
model.after_load_weights(state_dict)
del state_dict del state_dict
# Set is_sdxl_inpaint flag. # Set is_sdxl_inpaint flag.
@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if hasattr(sd_model, "after_load_weights"):
sd_model.after_load_weights()
timer.record("load weights from state dict") timer.record("load weights from state dict")
send_model_to_device(sd_model) send_model_to_device(sd_model)