From 537d9dd71c92fd97def33a55150c70e1d7d80e27 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 7 Sep 2024 01:06:05 +0900 Subject: [PATCH] misc fixes to support float8 dtype_unet * devices.dtype_unet, dtype_vae could be considered as storage dtypes (current_dtype) * use devices.dtype_inference as computational dtype (taget_dtype) * misc fixes to support float8 unet storage --- modules/processing.py | 2 +- modules/sd_hijack_unet.py | 14 +++++++------- modules/sd_samplers_common.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92c3582cc..b1107e757 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: sd_models.apply_alpha_schedule_override(p.sd_model, p) - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if p.scripts is not None: diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index b4f03b138..842030be8 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): if isinstance(cond[y], list): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]] else: - cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] + cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): - result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs) if devices.unet_needs_upcast: return result.float() else: @@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): torch.nn.GELU.__init__(self, *args, **kwargs) def forward(self, x): if devices.unet_needs_upcast: - return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference) else: return torch.nn.GELU.forward(self, x) @@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) - CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast) CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 @@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): if devices.unet_needs_upcast and timesteps.dtype == torch.int64: dtype = torch.float32 else: - dtype = devices.dtype_unet + dtype = devices.dtype_inference return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index c060cccb2..28b8bd820 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 + with torch.no_grad(), devices.manual_cast(devices.dtype_vae): # fixes an issue with unstable VAEs that are flaky even in fp32 x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample @@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None): x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample)