diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py
index c61ae0fe6..e25ba1b63 100644
--- a/modules/models/sd3/sd3_cond.py
+++ b/modules/models/sd3/sd3_cond.py
@@ -177,12 +177,13 @@ class SD3Cond(torch.nn.Module):
         self.weights_loaded = False
 
     def forward(self, prompts: list[str]):
-        lg_out, vector_out = self.model_lg(prompts)
+        with devices.without_autocast():
+            lg_out, vector_out = self.model_lg(prompts)
 
-        token_count = lg_out.shape[1]
+            token_count = lg_out.shape[1]
 
-        t5_out = self.model_t5(prompts, token_count=token_count)
-        lgt_out = torch.cat([lg_out, t5_out], dim=-2)
+            t5_out = self.model_t5(prompts, token_count=token_count)
+            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
 
         return {
             'crossattn': lgt_out,
diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py
index c17fd97e9..336e8d2d4 100644
--- a/modules/models/sd3/sd3_model.py
+++ b/modules/models/sd3/sd3_model.py
@@ -47,8 +47,7 @@ class SD3Inferencer(torch.nn.Module):
         return contextlib.nullcontext()
 
     def get_learned_conditioning(self, batch: list[str]):
-        with devices.without_autocast():
-            return self.cond_stage_model(batch)
+        return self.cond_stage_model(batch)
 
     def apply_model(self, x, t, cond):
         return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 45575e440..681030442 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -718,16 +718,15 @@ def get_empty_cond(sd_model):
     p = processing.StableDiffusionProcessingTxt2Img()
     extra_networks.activate(p, {})
 
-    if hasattr(sd_model, 'conditioner'):
+    if hasattr(sd_model, 'get_learned_conditioning'):
         d = sd_model.get_learned_conditioning([""])
-        return d['crossattn']
     else:
         d = sd_model.cond_stage_model([""])
 
-        if isinstance(d, dict):
-            d = d['crossattn']
+    if isinstance(d, dict):
+        d = d['crossattn']
 
-        return d
+    return d
 
 
 def send_model_to_cpu(m):