misc fixes to support float8 dtype_unet

* devices.dtype_unet, dtype_vae could be considered as storage dtypes (current_dtype)
 * use devices.dtype_inference as computational dtype (taget_dtype)
 * misc fixes to support float8 unet storage
This commit is contained in:
Won-Kyu Park 2024-09-07 01:06:05 +09:00
parent fcd609f4b4
commit 537d9dd71c
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
3 changed files with 10 additions and 10 deletions

View File

@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
sd_models.apply_alpha_schedule_override(p.sd_model, p) sd_models.apply_alpha_schedule_override(p.sd_model, p)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet):
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if p.scripts is not None: if p.scripts is not None:

View File

@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict): if isinstance(cond, dict):
for y in cond.keys(): for y in cond.keys():
if isinstance(cond[y], list): if isinstance(cond[y], list):
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]]
else: else:
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast(): with devices.autocast():
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs)
if devices.unet_needs_upcast: if devices.unet_needs_upcast:
return result.float() return result.float()
else: else:
@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
torch.nn.GELU.__init__(self, *args, **kwargs) torch.nn.GELU.__init__(self, *args, **kwargs)
def forward(self, x): def forward(self, x):
if devices.unet_needs_upcast: if devices.unet_needs_upcast:
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference)
else: else:
return torch.nn.GELU.forward(self, x) return torch.nn.GELU.forward(self, x)
@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
if devices.unet_needs_upcast and timesteps.dtype == torch.int64: if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
dtype = torch.float32 dtype = torch.float32
else: else:
dtype = devices.dtype_unet dtype = devices.dtype_inference
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)

View File

@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else: else:
if model is None: if model is None:
model = shared.sd_model model = shared.sd_model
with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 with torch.no_grad(), devices.manual_cast(devices.dtype_vae): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
return x_sample return x_sample
@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample) return Image.fromarray(x_sample)