minor update

* use dtype_inference
This commit is contained in:
Won-Kyu Park 2024-09-10 19:05:41 +09:00
parent 789bfc7db4
commit 44a8480f0c
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -78,10 +78,10 @@ class FluxCond(torch.nn.Module):
self.tokenizer = FluxTokenizer()
with torch.no_grad():
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.flux_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
else:
self.t5xxl = None
@ -202,8 +202,8 @@ class BaseModel(torch.nn.Module):
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype)
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float()
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32)
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype)
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
@ -268,7 +268,7 @@ class FLUX1Inferencer(torch.nn.Module):
diffusion_model_prefix = ""
with torch.no_grad():
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype)
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = devices.dtype_vae
self.vae = self.first_stage_model # real vae