vae fix for flux

This commit is contained in:
Won-Kyu Park 2024-09-08 20:21:48 +09:00
parent 7e2d51965f
commit 9c0fd83b5e
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 4 additions and 1 deletions

View File

@ -270,7 +270,8 @@ class FLUX1Inferencer(torch.nn.Module):
with torch.no_grad():
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = self.model.diffusion_model.dtype
self.first_stage_model.dtype = devices.dtype_vae
self.vae = self.first_stage_model # real vae
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

View File

@ -956,6 +956,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
else:
weight_dtype_conversion = {
'first_stage_model': None,
'text_encoders': None,
'vae': None,
'alphas_cumprod': None,
'': torch.float16 if loadable_unet_dtype in (torch.float16, torch.float32, torch.bfloat16) else None,
}