mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
ManualCast for 10/16 series gpu
This commit is contained in:
parent
0beb131c7f
commit
d4d3134f6d
@ -16,6 +16,23 @@ def has_mps() -> bool:
|
|||||||
return mac_specific.has_mps
|
return mac_specific.has_mps
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_no_autocast(device_id=None) -> bool:
|
||||||
|
if device_id is None:
|
||||||
|
device_id = get_cuda_device_id()
|
||||||
|
return (
|
||||||
|
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||||
|
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_device_id():
|
||||||
|
return (
|
||||||
|
int(shared.cmd_opts.device_id)
|
||||||
|
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||||
|
else 0
|
||||||
|
) or torch.cuda.current_device()
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_device_string():
|
def get_cuda_device_string():
|
||||||
if shared.cmd_opts.device_id is not None:
|
if shared.cmd_opts.device_id is not None:
|
||||||
return f"cuda:{shared.cmd_opts.device_id}"
|
return f"cuda:{shared.cmd_opts.device_id}"
|
||||||
@ -60,8 +77,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||||
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
if cuda_no_autocast():
|
||||||
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@ -92,15 +108,44 @@ def cond_cast_float(input):
|
|||||||
|
|
||||||
|
|
||||||
nv_rng = None
|
nv_rng = None
|
||||||
|
patch_module_list = [
|
||||||
|
torch.nn.Linear,
|
||||||
|
torch.nn.Conv2d,
|
||||||
|
torch.nn.MultiheadAttention,
|
||||||
|
torch.nn.GroupNorm,
|
||||||
|
torch.nn.LayerNorm,
|
||||||
|
]
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def manual_autocast():
|
||||||
|
def manual_cast_forward(self, *args, **kwargs):
|
||||||
|
org_dtype = next(self.parameters()).dtype
|
||||||
|
self.to(dtype)
|
||||||
|
result = self.org_forward(*args, **kwargs)
|
||||||
|
self.to(org_dtype)
|
||||||
|
return result
|
||||||
|
for module_type in patch_module_list:
|
||||||
|
org_forward = module_type.forward
|
||||||
|
module_type.forward = manual_cast_forward
|
||||||
|
module_type.org_forward = org_forward
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
for module_type in patch_module_list:
|
||||||
|
module_type.forward = module_type.org_forward
|
||||||
|
|
||||||
|
|
||||||
def autocast(disable=False, unet=False):
|
def autocast(disable=False):
|
||||||
|
print(fp8, dtype, shared.cmd_opts.precision, device)
|
||||||
if disable:
|
if disable:
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
if unet and fp8 and device==cpu:
|
if fp8 and device==cpu:
|
||||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||||
|
|
||||||
|
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
|
||||||
|
return manual_autocast()
|
||||||
|
|
||||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
|
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
|
|
||||||
if getattr(samples_ddim, 'already_decoded', False):
|
if getattr(samples_ddim, 'already_decoded', False):
|
||||||
|
@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
|
|
||||||
if enable_fp8:
|
if enable_fp8:
|
||||||
devices.fp8 = True
|
devices.fp8 = True
|
||||||
|
if model.is_sdxl:
|
||||||
|
cond_stage = model.conditioner
|
||||||
|
else:
|
||||||
|
cond_stage = model.cond_stage_model
|
||||||
|
|
||||||
|
for module in cond_stage.modules():
|
||||||
|
if isinstance(module, torch.nn.Linear):
|
||||||
|
module.to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
if devices.device == devices.cpu:
|
if devices.device == devices.cpu:
|
||||||
for module in model.model.diffusion_model.modules():
|
for module in model.model.diffusion_model.modules():
|
||||||
if isinstance(module, torch.nn.Conv2d):
|
if isinstance(module, torch.nn.Conv2d):
|
||||||
module.to(torch.float8_e4m3fn)
|
module.to(torch.float8_e4m3fn)
|
||||||
elif isinstance(module, torch.nn.Linear):
|
elif isinstance(module, torch.nn.Linear):
|
||||||
module.to(torch.float8_e4m3fn)
|
module.to(torch.float8_e4m3fn)
|
||||||
timer.record("apply fp8 unet for cpu")
|
|
||||||
else:
|
else:
|
||||||
if model.is_sdxl:
|
|
||||||
cond_stage = model.conditioner
|
|
||||||
else:
|
|
||||||
cond_stage = model.cond_stage_model
|
|
||||||
for module in cond_stage.modules():
|
|
||||||
if isinstance(module, torch.nn.Linear):
|
|
||||||
module.to(torch.float8_e4m3fn)
|
|
||||||
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
|
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
|
||||||
timer.record("apply fp8 unet")
|
timer.record("apply fp8")
|
||||||
|
else:
|
||||||
|
devices.fp8 = False
|
||||||
|
|
||||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user