mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
misc fixes to support float8 dtype_unet
* devices.dtype_unet, dtype_vae could be considered as storage dtypes (current_dtype) * use devices.dtype_inference as computational dtype (taget_dtype) * misc fixes to support float8 unet storage
This commit is contained in:
parent
fcd609f4b4
commit
537d9dd71c
@ -984,7 +984,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(target_dtype=devices.dtype_inference, current_dtype=devices.dtype_unet):
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
|
@ -42,12 +42,12 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
||||
if isinstance(cond, dict):
|
||||
for y in cond.keys():
|
||||
if isinstance(cond[y], list):
|
||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
cond[y] = [x.to(devices.dtype_inference) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||
else:
|
||||
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
cond[y] = cond[y].to(devices.dtype_inference) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||
|
||||
with devices.autocast():
|
||||
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
|
||||
result = orig_func(self, x_noisy.to(devices.dtype_inference), t.to(devices.dtype_inference), cond, **kwargs)
|
||||
if devices.unet_needs_upcast:
|
||||
return result.float()
|
||||
else:
|
||||
@ -107,7 +107,7 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||
def forward(self, x):
|
||||
if devices.unet_needs_upcast:
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
|
||||
return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_inference)
|
||||
else:
|
||||
return torch.nn.GELU.forward(self, x)
|
||||
|
||||
@ -125,11 +125,11 @@ unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
||||
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_inference), unet_needs_upcast)
|
||||
|
||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
|
||||
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_inference), unet_needs_upcast)
|
||||
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
|
||||
|
||||
first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
|
||||
@ -146,7 +146,7 @@ def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
|
||||
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = devices.dtype_unet
|
||||
dtype = devices.dtype_inference
|
||||
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
|
||||
|
||||
|
||||
|
@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
with torch.no_grad(), devices.manual_cast(devices.dtype_vae): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||
|
||||
return x_sample
|
||||
@ -64,7 +64,7 @@ def single_sample_to_image(sample, approximation=None):
|
||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
x_sample = 255. * np.moveaxis(x_sample.to(dtype=devices.dtype).cpu().numpy(), 0, 2)
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
|
||||
return Image.fromarray(x_sample)
|
||||
|
Loading…
Reference in New Issue
Block a user