From 44a8480f0c9979dc143d9cb0f5c92aa9960474ac Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Tue, 10 Sep 2024 19:05:41 +0900 Subject: [PATCH] minor update * use dtype_inference --- modules/models/flux/flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index 14fa4e255..46fd568a0 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -78,10 +78,10 @@ class FluxCond(torch.nn.Module): self.tokenizer = FluxTokenizer() with torch.no_grad(): - self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) if shared.opts.flux_enable_t5: - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference) else: self.t5xxl = None @@ -202,8 +202,8 @@ class BaseModel(torch.nn.Module): def apply_model(self, x, sigma, c_crossattn=None, y=None): dtype = self.get_dtype() timestep = self.model_sampling.timestep(sigma).float() - guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype) - model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float() + guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32) + model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype) return self.model_sampling.calculate_denoised(sigma, model_output, x) def forward(self, *args, **kwargs): @@ -268,7 +268,7 @@ class FLUX1Inferencer(torch.nn.Module): diffusion_model_prefix = "" with torch.no_grad(): - self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype) + self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) self.first_stage_model.dtype = devices.dtype_vae self.vae = self.first_stage_model # real vae