Option for using fp16 weight when applying LoRA

Kohaku-Blueleaf 2023-11-21 19:59:34 +08:00
parent b2e039d07b
commit 370a77f8e7
4 changed files with 25 additions and 7 deletions

@@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             if module is not None and hasattr(self, 'weight'):
                 try:
                     with torch.no_grad():
-                        updown, ex_bias = module.calc_updown(self.weight)
+                        if getattr(self, 'fp16_weight', None) is None:
+                            weight = self.weight
+                            bias = self.bias
+                        else:
+                            weight = self.fp16_weight.clone().to(self.weight.device)
+                            bias = getattr(self, 'fp16_bias', None)
+                            if bias is not None:
+                                bias = bias.clone().to(self.bias.device)
+                        updown, ex_bias = module.calc_updown(weight)
 
-                        if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                             # inpainting model. zero pad updown to make channel[1] 4 to 9
                             updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
 
-                        self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
+                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                         if ex_bias is not None and hasattr(self, 'bias'):
                             if self.bias is None:
                                 self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
                             else:
-                                self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype))
+                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                 except RuntimeError as e:
                     logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                     extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
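The point of the change above: with FP8 storage enabled, self.weight only holds an fp8-quantized copy, so computing and merging the LoRA delta against it stacks quantization error; reading from the cached fp16_weight keeps the merge in higher precision. A minimal self-contained sketch of that idea, with a toy Linear layer and a random lora_delta standing in for the real module and module.calc_updown():

import torch

# toy layer standing in for a UNet Linear module
layer = torch.nn.Linear(16, 16).half()
layer.fp16_weight = layer.weight.detach().clone()        # cached fp16 master copy

if hasattr(torch, "float8_e4m3fn"):                       # fp8 storage needs pytorch >= 2.1
    layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)

# stand-in for module.calc_updown(weight): a small fp16 delta
lora_delta = (torch.randn(16, 16) * 1e-3).half()

with torch.no_grad():
    # merge against the fp16 copy, then cast once into the live (possibly fp8) weight
    merged = layer.fp16_weight.to(dtype=lora_delta.dtype) + lora_delta
    layer.weight.copy_(merged.to(dtype=layer.weight.dtype))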

@@ -178,6 +178,7 @@ def configure_opts_onchange():
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
     shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
     startup_timer.record("opts onchange")
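For context on this hunk: shared.opts.onchange ties a handler to an option key, so flipping "Cache FP16 weight for LoRA" in settings triggers a model reload. A rough sketch of that pattern, as an assumption for illustration rather than webui's actual Options class; call=False is taken to mean "register only, do not fire the handler immediately":

# MiniOptions is illustrative only; the real handler would be a queued
# sd_models.reload_model_weights() call.
class MiniOptions:
    def __init__(self):
        self.values = {}
        self.handlers = {}

    def onchange(self, key, func, call=True):
        self.handlers[key] = func
        if call:
            func()

    def set(self, key, value):
        self.values[key] = value
        handler = self.handlers.get(key)
        if handler is not None:
            handler()


opts = MiniOptions()
opts.onchange("cache_fp16_weight", lambda: print("reload model weights"), call=False)
opts.set("cache_fp16_weight", True)   # toggling the option fires the reload callback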

@@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")
 
+    for module in model.modules():
+        if hasattr(module, 'fp16_weight'):
+            del module.fp16_weight
+        if hasattr(module, 'fp16_bias'):
+            del module.fp16_bias
+
     if check_fp8(model):
         devices.fp8 = True
         first_stage = model.first_stage_model
         model.first_stage_model = None
         for module in model.modules():
-            if isinstance(module, torch.nn.Conv2d):
-                module.to(torch.float8_e4m3fn)
-            elif isinstance(module, torch.nn.Linear):
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+                if shared.opts.cache_fp16_weight:
+                    module.fp16_weight = module.weight.clone().half()
+                    if module.bias is not None:
+                        module.fp16_bias = module.bias.clone().half()
                 module.to(torch.float8_e4m3fn)
         model.first_stage_model = first_stage
         timer.record("apply fp8")
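The loop added above first clears any stale fp16 cache from a previous load, then, when the new option is on, snapshots every Linear/Conv2d weight and bias in fp16 before the module is cast down to fp8 storage. A standalone sketch of the caching step, assuming a toy CPU model and a local cache_fp16_weight flag in place of shared.opts.cache_fp16_weight:

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Conv2d(3, 3, 1)).half()
cache_fp16_weight = True   # stands in for shared.opts.cache_fp16_weight

for module in model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        if cache_fp16_weight:
            # keep an fp16 master copy before precision is lost to fp8 storage
            module.fp16_weight = module.weight.detach().clone().half()
            if module.bias is not None:
                module.fp16_bias = module.bias.detach().clone().half()
        if hasattr(torch, "float8_e4m3fn"):   # fp8 storage needs pytorch >= 2.1
            module.to(torch.float8_e4m3fn)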

@@ -201,6 +201,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+    "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
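The "Use more system ram" note can be made concrete: the cache keeps one extra fp16 copy (2 bytes per element) of each Linear/Conv2d weight and bias, so the overhead is roughly twice the parameter count of those layers, in bytes. A small illustrative helper, not part of the commit, to estimate it:

import torch

def fp16_cache_overhead_bytes(model: torch.nn.Module) -> int:
    """Approximate extra memory used by caching fp16 copies of Linear/Conv2d weights."""
    total = 0
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            total += module.weight.numel() * 2          # fp16 = 2 bytes per element
            if module.bias is not None:
                total += module.bias.numel() * 2
    return total

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Conv2d(64, 64, 3))
print(f"{fp16_cache_overhead_bytes(model) / 1024**2:.1f} MiB of extra fp16 cache")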