mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
reduce backup_weight size for float8 freeze model
This commit is contained in:
parent
1d3dae1471
commit
0ab4d7992c
@ -377,13 +377,13 @@ def allowed_layer_without_weight(layer):
|
||||
return False
|
||||
|
||||
|
||||
def store_weights_backup(weight):
|
||||
def store_weights_backup(weight, dtype):
|
||||
if weight is None:
|
||||
return None
|
||||
|
||||
if shared.opts.lora_without_backup_weight:
|
||||
return True
|
||||
return weight.to(devices.cpu, copy=True)
|
||||
return weight.to(devices.cpu, dtype=dtype, copy=True)
|
||||
|
||||
|
||||
def restore_weights_backup(obj, field, weight):
|
||||
@ -437,18 +437,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
|
||||
weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
|
||||
else:
|
||||
weights_backup = store_weights_backup(self.weight)
|
||||
weights_backup = store_weights_backup(self.weight, self.org_dtype)
|
||||
|
||||
self.network_weights_backup = weights_backup
|
||||
|
||||
bias_backup = getattr(self, "network_bias_backup", None)
|
||||
if bias_backup is None and wanted_names != ():
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||
bias_backup = store_weights_backup(self.out_proj.bias)
|
||||
bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
|
||||
elif getattr(self, 'bias', None) is not None:
|
||||
bias_backup = store_weights_backup(self.bias)
|
||||
bias_backup = store_weights_backup(self.bias, self.org_dtype)
|
||||
else:
|
||||
bias_backup = None
|
||||
|
||||
@ -487,6 +487,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
||||
else:
|
||||
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
||||
del weight, bias, updown, ex_bias
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
@ -538,6 +539,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
self.weight += updown_qkv
|
||||
del updown_qkv
|
||||
del updown_q, updown_k, updown_v
|
||||
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
@ -560,6 +562,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
|
||||
self.weight += updown_qkv_mlp
|
||||
del updown_qkv_mlp
|
||||
del updown_q, updown_k, updown_v, updown_mlp
|
||||
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
|
@ -1,4 +1,5 @@
|
||||
import re
|
||||
import torch
|
||||
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
@ -9,7 +10,7 @@ import lora # noqa:F401
|
||||
import lora_patches
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices
|
||||
|
||||
|
||||
def unload():
|
||||
@ -97,6 +98,64 @@ def infotext_pasted(infotext, d):
|
||||
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
|
||||
|
||||
|
||||
class ScriptLora(scripts.Script):
|
||||
name = "Lora"
|
||||
|
||||
def title(self):
|
||||
return self.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def after_extra_networks_activate(self, p, *args, **kwargs):
|
||||
# check modules and setup org_dtype
|
||||
modules = []
|
||||
if shared.sd_model.is_sdxl:
|
||||
for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||
if not hasattr(embedder, 'wrapped'):
|
||||
continue
|
||||
|
||||
for _name, module in embedder.wrapped.named_modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||
if hasattr(module, 'weight'):
|
||||
modules.append(module)
|
||||
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||
modules.append(module)
|
||||
|
||||
else:
|
||||
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
|
||||
|
||||
for _name, module in cond_stage_model.named_modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||
if hasattr(module, 'weight'):
|
||||
modules.append(module)
|
||||
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||
modules.append(module)
|
||||
|
||||
for _name, module in shared.sd_model.model.named_modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||
if hasattr(module, 'weight'):
|
||||
modules.append(module)
|
||||
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||
modules.append(module)
|
||||
|
||||
print("Total lora modules after_extra_networks_activate() =", len(modules))
|
||||
|
||||
target_dtype = devices.dtype_inference
|
||||
for module in modules:
|
||||
if isinstance(module, torch.nn.MultiheadAttention):
|
||||
org_dtype = torch.float32
|
||||
else:
|
||||
org_dtype = None
|
||||
for _name, param in module.named_parameters():
|
||||
if param.dtype != target_dtype:
|
||||
org_dtype = param.dtype
|
||||
break
|
||||
|
||||
# set org_dtype
|
||||
module.org_dtype = org_dtype
|
||||
|
||||
|
||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||
|
||||
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||
|
Loading…
Reference in New Issue
Block a user