mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
Add extra norm module into built-in lora ext
refer to LyCORIS 1.9.0.dev6 add new option and module for training norm layer (Which is reported to be good for style)
This commit is contained in:
parent
b2080756fc
commit
bd4da4474b
@ -133,7 +133,7 @@ class NetworkModule:
|
|||||||
|
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
updown = updown.reshape(self.bias.shape)
|
updown = updown.reshape(self.bias.shape)
|
||||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
@ -145,7 +145,10 @@ class NetworkModule:
|
|||||||
if orig_weight.size().numel() == updown.size().numel():
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
updown = updown.reshape(orig_weight.shape)
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
return updown * self.calc_scale() * self.multiplier()
|
if ex_bias is None:
|
||||||
|
ex_bias = 0
|
||||||
|
|
||||||
|
return updown * self.calc_scale() * self.multiplier(), ex_bias * self.multiplier()
|
||||||
|
|
||||||
def calc_updown(self, target):
|
def calc_updown(self, target):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
29
extensions-builtin/Lora/network_norm.py
Normal file
29
extensions-builtin/Lora/network_norm.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeNorm(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["w_norm", "b_norm"]):
|
||||||
|
return NetworkModuleNorm(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleNorm(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
print("NetworkModuleNorm")
|
||||||
|
|
||||||
|
self.w_norm = weights.w.get("w_norm")
|
||||||
|
self.b_norm = weights.w.get("b_norm")
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
output_shape = self.w_norm.shape
|
||||||
|
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
if self.b_norm is not None:
|
||||||
|
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
@ -7,6 +7,7 @@ import network_hada
|
|||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
|
import network_norm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -19,6 +20,7 @@ module_types = [
|
|||||||
network_ia3.ModuleTypeIa3(),
|
network_ia3.ModuleTypeIa3(),
|
||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
|
network_norm.ModuleTypeNorm(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +33,8 @@ suffix_conversion = {
|
|||||||
"resnets": {
|
"resnets": {
|
||||||
"conv1": "in_layers_2",
|
"conv1": "in_layers_2",
|
||||||
"conv2": "out_layers_3",
|
"conv2": "out_layers_3",
|
||||||
|
"norm1": "in_layers_0",
|
||||||
|
"norm2": "out_layers_0",
|
||||||
"time_emb_proj": "emb_layers_1",
|
"time_emb_proj": "emb_layers_1",
|
||||||
"conv_shortcut": "skip_connection",
|
"conv_shortcut": "skip_connection",
|
||||||
}
|
}
|
||||||
@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
purge_networks_from_memory()
|
purge_networks_from_memory()
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
|
||||||
if weights_backup is None:
|
if weights_backup is None and bias_backup is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if weights_backup is not None:
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
else:
|
else:
|
||||||
self.weight.copy_(weights_backup)
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
if bias_backup is not None:
|
||||||
|
self.bias.copy_(bias_backup)
|
||||||
|
|
||||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
|
||||||
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of networks to the weights of torch layer self.
|
Applies the currently selected set of networks to the weights of torch layer self.
|
||||||
If weights already have this particular set of networks applied, does nothing.
|
If weights already have this particular set of networks applied, does nothing.
|
||||||
@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
|
|
||||||
self.network_weights_backup = weights_backup
|
self.network_weights_backup = weights_backup
|
||||||
|
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
if bias_backup is None and getattr(self, 'bias', None) is not None:
|
||||||
|
bias_backup = self.bias.to(devices.cpu, copy=True)
|
||||||
|
self.network_bias_backup = bias_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
network_restore_weights_from_backup(self)
|
network_restore_weights_from_backup(self)
|
||||||
|
|
||||||
@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
module = net.modules.get(network_layer_name, None)
|
module = net.modules.get(network_layer_name, None)
|
||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
updown = module.calc_updown(self.weight)
|
updown, ex_bias = module.calc_updown(self.weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight += updown
|
||||||
|
if getattr(self, 'bias', None) is not None:
|
||||||
|
self.bias += ex_bias
|
||||||
continue
|
continue
|
||||||
|
|
||||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||||
@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
|||||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.GroupNorm_forward_before_network(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.LayerNorm_forward_before_network(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
|||||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
||||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
|
||||||
|
torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
|
||||||
|
torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
|
||||||
|
torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
|
||||||
|
torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
||||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
||||||
|
|
||||||
@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
|
|||||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
||||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
||||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
||||||
|
torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
|
||||||
|
torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
|
||||||
|
torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
|
||||||
|
torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
|
||||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user