mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-08 06:32:54 +08:00
minor update
* use dtype_inference
This commit is contained in:
parent
789bfc7db4
commit
44a8480f0c
@ -78,10 +78,10 @@ class FluxCond(torch.nn.Module):
|
||||
self.tokenizer = FluxTokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype_inference, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
|
||||
if shared.opts.flux_enable_t5:
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype_inference)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
|
||||
@ -202,8 +202,8 @@ class BaseModel(torch.nn.Module):
|
||||
def apply_model(self, x, sigma, c_crossattn=None, y=None):
|
||||
dtype = self.get_dtype()
|
||||
timestep = self.model_sampling.timestep(sigma).float()
|
||||
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype)
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float()
|
||||
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=torch.float32)
|
||||
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).to(x.dtype)
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@ -268,7 +268,7 @@ class FLUX1Inferencer(torch.nn.Module):
|
||||
diffusion_model_prefix = ""
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype)
|
||||
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference)
|
||||
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
|
||||
self.first_stage_model.dtype = devices.dtype_vae
|
||||
self.vae = self.first_stage_model # real vae
|
||||
|
Loading…
Reference in New Issue
Block a user