mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
reduce backup_weight size for float8 freeze model
This commit is contained in:
parent
1d3dae1471
commit
0ab4d7992c
@ -377,13 +377,13 @@ def allowed_layer_without_weight(layer):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def store_weights_backup(weight):
|
def store_weights_backup(weight, dtype):
|
||||||
if weight is None:
|
if weight is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if shared.opts.lora_without_backup_weight:
|
if shared.opts.lora_without_backup_weight:
|
||||||
return True
|
return True
|
||||||
return weight.to(devices.cpu, copy=True)
|
return weight.to(devices.cpu, dtype=dtype, copy=True)
|
||||||
|
|
||||||
|
|
||||||
def restore_weights_backup(obj, field, weight):
|
def restore_weights_backup(obj, field, weight):
|
||||||
@ -437,18 +437,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
|
raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
|
weights_backup = (store_weights_backup(self.in_proj_weight, self.org_dtype), store_weights_backup(self.out_proj.weight, self.org_dtype))
|
||||||
else:
|
else:
|
||||||
weights_backup = store_weights_backup(self.weight)
|
weights_backup = store_weights_backup(self.weight, self.org_dtype)
|
||||||
|
|
||||||
self.network_weights_backup = weights_backup
|
self.network_weights_backup = weights_backup
|
||||||
|
|
||||||
bias_backup = getattr(self, "network_bias_backup", None)
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
if bias_backup is None and wanted_names != ():
|
if bias_backup is None and wanted_names != ():
|
||||||
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||||
bias_backup = store_weights_backup(self.out_proj.bias)
|
bias_backup = store_weights_backup(self.out_proj.bias, self.org_dtype)
|
||||||
elif getattr(self, 'bias', None) is not None:
|
elif getattr(self, 'bias', None) is not None:
|
||||||
bias_backup = store_weights_backup(self.bias)
|
bias_backup = store_weights_backup(self.bias, self.org_dtype)
|
||||||
else:
|
else:
|
||||||
bias_backup = None
|
bias_backup = None
|
||||||
|
|
||||||
@ -487,6 +487,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
||||||
else:
|
else:
|
||||||
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
||||||
|
del weight, bias, updown, ex_bias
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
@ -538,6 +539,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
self.weight += updown_qkv
|
self.weight += updown_qkv
|
||||||
del updown_qkv
|
del updown_qkv
|
||||||
|
del updown_q, updown_k, updown_v
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
@ -560,6 +562,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
|
updown_qkv_mlp = torch.vstack([updown_q, updown_k, updown_v, updown_mlp])
|
||||||
self.weight += updown_qkv_mlp
|
self.weight += updown_qkv_mlp
|
||||||
del updown_qkv_mlp
|
del updown_qkv_mlp
|
||||||
|
del updown_q, updown_k, updown_v, updown_mlp
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
import torch
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
@ -9,7 +10,7 @@ import lora # noqa:F401
|
|||||||
import lora_patches
|
import lora_patches
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared, scripts, devices
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
@ -97,6 +98,64 @@ def infotext_pasted(infotext, d):
|
|||||||
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
|
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptLora(scripts.Script):
|
||||||
|
name = "Lora"
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def after_extra_networks_activate(self, p, *args, **kwargs):
|
||||||
|
# check modules and setup org_dtype
|
||||||
|
modules = []
|
||||||
|
if shared.sd_model.is_sdxl:
|
||||||
|
for _i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||||
|
if not hasattr(embedder, 'wrapped'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for _name, module in embedder.wrapped.named_modules():
|
||||||
|
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||||
|
if hasattr(module, 'weight'):
|
||||||
|
modules.append(module)
|
||||||
|
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||||
|
modules.append(module)
|
||||||
|
|
||||||
|
else:
|
||||||
|
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
|
||||||
|
|
||||||
|
for _name, module in cond_stage_model.named_modules():
|
||||||
|
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||||
|
if hasattr(module, 'weight'):
|
||||||
|
modules.append(module)
|
||||||
|
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||||
|
modules.append(module)
|
||||||
|
|
||||||
|
for _name, module in shared.sd_model.model.named_modules():
|
||||||
|
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention)):
|
||||||
|
if hasattr(module, 'weight'):
|
||||||
|
modules.append(module)
|
||||||
|
elif isinstance(module, torch.nn.MultiheadAttention):
|
||||||
|
modules.append(module)
|
||||||
|
|
||||||
|
print("Total lora modules after_extra_networks_activate() =", len(modules))
|
||||||
|
|
||||||
|
target_dtype = devices.dtype_inference
|
||||||
|
for module in modules:
|
||||||
|
if isinstance(module, torch.nn.MultiheadAttention):
|
||||||
|
org_dtype = torch.float32
|
||||||
|
else:
|
||||||
|
org_dtype = None
|
||||||
|
for _name, param in module.named_parameters():
|
||||||
|
if param.dtype != target_dtype:
|
||||||
|
org_dtype = param.dtype
|
||||||
|
break
|
||||||
|
|
||||||
|
# set org_dtype
|
||||||
|
module.org_dtype = org_dtype
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||||
|
|
||||||
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||||
|
Loading…
Reference in New Issue
Block a user