reduce backup_weight size for float8 freeze model

This commit is contained in:
Won-Kyu Park 2024-10-02 20:02:30 +09:00
parent 1d3dae1471
commit 0ab4d7992c
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 69 additions and 7 deletions

View File

@ -377,13 +377,13 @@ def allowed_layer_without_weight(layer):
return False return False
def store_weights_backup(weight): def store_weights_backup(weight, dtype):
if weight is None: if weight is None:
return None return None
if shared.opts.lora_without_backup_weight: if shared.opts.lora_without_backup_weight:
return True return True
return weight.to(devices.cpu, copy=True) return weight.to(devices.cpu, dtype=dtype, copy=True)
def restore_weights_backup(obj, field, weight): def restore_weights_backup(obj, field, weight):
@ -437,18 +437,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged") raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention): if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight)) weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
else: else:
weights_backup = store_weights_backup(self.weight) weights_backup = store_weights_backup(self.weight, self.org_dtype)
self.network_weights_backup = weights_backup self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None) bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != (): if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = store_weights_backup(self.out_proj.bias) bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
elif getattr(self, 'bias', None) is not None: elif getattr(self, 'bias', None) is not None:
bias_backup = store_weights_backup(self.bias) bias_backup = store_weights_backup(self.bias, self.org_dtype)
else: else:
bias_backup = None bias_backup = None
@ -487,6 +487,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else: else:
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
del weight, bias, updown, ex_bias
except RuntimeError as e: except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@ -538,6 +539,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.weight += updown_qkv self.weight += updown_qkv
del updown_qkv del updown_qkv
del updown_q, updown_k, updown_v
except RuntimeError as e: except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
@ -560,6 +562,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp]) updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
self.weight += updown_qkv_mlp self.weight += updown_qkv_mlp
del updown_qkv_mlp del updown_qkv_mlp
del updown_q, updown_k, updown_v, updown_mlp
except RuntimeError as e: except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")

View File

@ -1,4 +1,5 @@
import re import re
import torch
import gradio as gr import gradio as gr
from fastapi import FastAPI from fastapi import FastAPI
@ -9,7 +10,7 @@ import lora # noqa:F401
import lora_patches import lora_patches
import extra_networks_lora import extra_networks_lora
import ui_extra_networks_lora import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices
def unload(): def unload():
@ -97,6 +98,64 @@ def infotext_pasted(infotext, d):
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
class ScriptLora(scripts.Script):
name = "Lora"
def title(self):
return self.name
def show(self, is_img2img):
return scripts.AlwaysVisible
def after_extra_networks_activate(self, p, *args, **kwargs):
# check modules and setup org_dtype
modules = []
if shared.sd_model.is_sdxl:
for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
if not hasattr(embedder, 'wrapped'):
continue
for _name, module in embedder.wrapped.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
else:
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
for _name, module in cond_stage_model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
for _name, module in shared.sd_model.model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
if hasattr(module, 'weight'):
modules.append(module)
elif isinstance(module, torch.nn.MultiheadAttention):
modules.append(module)
print("Total lora modules after_extra_networks_activate() =", len(modules))
target_dtype = devices.dtype_inference
for module in modules:
if isinstance(module, torch.nn.MultiheadAttention):
org_dtype = torch.float32
else:
org_dtype = None
for _name, param in module.named_parameters():
if param.dtype != target_dtype:
org_dtype = param.dtype
break
# set org_dtype
module.org_dtype = org_dtype
script_callbacks.on_infotext_pasted(infotext_pasted) script_callbacks.on_infotext_pasted(infotext_pasted)
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)