Add extra norm module into built-in lora ext

Refers to LyCORIS 1.9.0.dev6: adds a new option and module for training norm layers
(reported to work well for style).
Author: Kohaku-Blueleaf
Date: 2023-08-13 02:27:39 +08:00
Parent: b2080756fc
Commit: bd4da4474b
4 changed files with 105 additions and 11 deletions
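The change covers both norm flavours used by Stable Diffusion models, GroupNorm (resnet blocks) and LayerNorm (transformer blocks and the text encoder), which is why the type hints and torch.nn hooks below gain exactly those two classes. A hypothetical standalone snippet (not part of the diff) to see how many layers that covers in a given model:

    import torch

    # Hypothetical snippet, not part of the diff: count the layers the new hooks
    # start covering in a model, i.e. every GroupNorm and LayerNorm it contains.
    def count_norm_layers(model: torch.nn.Module) -> int:
        return sum(isinstance(m, (torch.nn.GroupNorm, torch.nn.LayerNorm))
                   for m in model.modules())

    toy = torch.nn.Sequential(
        torch.nn.Conv2d(4, 320, 3, padding=1),
        torch.nn.GroupNorm(32, 320),
        torch.nn.LayerNorm(320),
    )
    print(count_norm_layers(toy))  # 2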

network.py

@@ -133,7 +133,7 @@ class NetworkModule:
         return 1.0

-    def finalize_updown(self, updown, orig_weight, output_shape):
+    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
             updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@
         if orig_weight.size().numel() == updown.size().numel():
             updown = updown.reshape(orig_weight.shape)

-        return updown * self.calc_scale() * self.multiplier()
+        if ex_bias is None:
+            ex_bias = 0
+
+        return updown * self.calc_scale() * self.multiplier(), ex_bias * self.multiplier()

     def calc_updown(self, target):
         raise NotImplementedError()
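finalize_updown now returns a pair instead of a bare tensor, so every call site unpacks both values (see the networks.py hunk further down). A simplified standalone sketch of the new contract, with a missing bias delta normalized to 0 so the multiplication always works:

    import torch

    # Simplified standalone sketch of the new return contract (not the real
    # NetworkModule method): a missing bias delta becomes 0 so callers can
    # always unpack two values and scale them safely.
    def finalize_updown_sketch(updown, multiplier, scale, ex_bias=None):
        if ex_bias is None:
            ex_bias = 0
        return updown * scale * multiplier, ex_bias * multiplier

    updown, ex_bias = finalize_updown_sketch(torch.ones(320), multiplier=0.5, scale=1.0)
    assert ex_bias == 0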

network_norm.py (new file)

@@ -0,0 +1,29 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+        print("NetworkModuleNorm")
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
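NetworkModuleNorm is deliberately thin: calc_updown only moves w_norm/b_norm onto the target layer's device and dtype and defers to finalize_updown, so the effective update is just the raw deltas times the multiplier (assuming calc_scale() falls back to 1.0, since norm weights carry no alpha). A hypothetical end-to-end sketch of that effect on a GroupNorm target:

    import torch

    # Hypothetical standalone sketch of the net effect on a GroupNorm target,
    # assuming calc_scale() falls back to 1.0 because norm weights carry no alpha.
    def norm_updown(w_norm, b_norm, orig_weight, multiplier=1.0):
        updown = w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
        ex_bias = b_norm.to(orig_weight.device, dtype=orig_weight.dtype) if b_norm is not None else 0
        return updown * multiplier, ex_bias * multiplier

    gn = torch.nn.GroupNorm(32, 320)
    updown, ex_bias = norm_updown(torch.full((320,), 0.1), torch.zeros(320), gn.weight, multiplier=0.5)
    with torch.no_grad():
        gn.weight += updown
        gn.bias += ex_bias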

networks.py

@@ -7,6 +7,7 @@ import network_hada
 import network_ia3
 import network_lokr
 import network_full
+import network_norm

 import torch
 from typing import Union
@@ -19,6 +20,7 @@ module_types = [
     network_ia3.ModuleTypeIa3(),
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
+    network_norm.ModuleTypeNorm(),
 ]
@@ -31,6 +33,8 @@ suffix_conversion = {
     "resnets": {
         "conv1": "in_layers_2",
         "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
         "time_emb_proj": "emb_layers_1",
         "conv_shortcut": "skip_connection",
     }
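The two new suffix_conversion entries let diffusers-style resnet norm names resolve to their CompVis counterparts (the resnet's first and second GroupNorm live at in_layers[0] and out_layers[0] in the LDM UNet), so norm weights keyed against either naming scheme find the same layer. A simplified sketch of just the suffix swap; the real convert_diffusers_name_to_compvis in networks.py also remaps block indices:

    # Simplified sketch of the suffix translation only; block-index remapping
    # done by the real convert_diffusers_name_to_compvis is omitted here.
    suffix_conversion = {
        "resnets": {
            "conv1": "in_layers_2",
            "conv2": "out_layers_3",
            "norm1": "in_layers_0",
            "norm2": "out_layers_0",
            "time_emb_proj": "emb_layers_1",
            "conv_shortcut": "skip_connection",
        }
    }

    def convert_resnet_suffix(suffix: str) -> str:
        return suffix_conversion["resnets"].get(suffix, suffix)

    print(convert_resnet_suffix("norm1"))  # in_layers_0
    print(convert_resnet_suffix("norm2"))  # out_layers_0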
@@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     purge_networks_from_memory()


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
+    bias_backup = getattr(self, "network_bias_backup", None)

-    if weights_backup is None:
+    if weights_backup is None and bias_backup is None:
         return

-    if isinstance(self, torch.nn.MultiheadAttention):
-        self.in_proj_weight.copy_(weights_backup[0])
-        self.out_proj.weight.copy_(weights_backup[1])
-    else:
-        self.weight.copy_(weights_backup)
+    if weights_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.in_proj_weight.copy_(weights_backup[0])
+            self.out_proj.weight.copy_(weights_backup[1])
+        else:
+            self.weight.copy_(weights_backup)
+
+    if bias_backup is not None:
+        self.bias.copy_(bias_backup)


-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
@@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
         self.network_weights_backup = weights_backup

+    bias_backup = getattr(self, "network_bias_backup", None)
+    if bias_backup is None and getattr(self, 'bias', None) is not None:
+        bias_backup = self.bias.to(devices.cpu, copy=True)
+        self.network_bias_backup = bias_backup
+
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
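Bias handling mirrors the existing weight handling: the first time any network touches a layer, a CPU copy of the bias is stashed, and restoring from backup copies it back before a different network set is applied. A standalone sketch of that pattern; the attribute name mirrors networks.py but this is not the real code path:

    import torch

    # Standalone sketch of the backup/restore pattern for the bias; the attribute
    # name mirrors networks.py but this is not the real code path.
    layer = torch.nn.LayerNorm(8)

    if getattr(layer, "network_bias_backup", None) is None and layer.bias is not None:
        layer.network_bias_backup = layer.bias.to("cpu", copy=True)  # back up once

    with torch.no_grad():
        layer.bias += 0.1                               # apply some ex_bias
        layer.bias.copy_(layer.network_bias_backup)     # restore the original bias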
@@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
         module = net.modules.get(network_layer_name, None)
         if module is not None and hasattr(self, 'weight'):
             with torch.no_grad():
-                updown = module.calc_updown(self.weight)
+                updown, ex_bias = module.calc_updown(self.weight)

                 if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                     # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                 self.weight += updown
+                if getattr(self, 'bias', None) is not None:
+                    self.bias += ex_bias

             continue

         module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
     return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


+def network_GroupNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)
def network_MultiheadAttention_forward(self, *args, **kwargs): def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self) network_apply_weights(self)

scripts/lora_script.py

@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
     torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
 if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
     torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
 torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
 torch.nn.Conv2d.forward = networks.network_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
 torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
 torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
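GroupNorm and LayerNorm follow the same save-once/patch pattern already used for Linear and Conv2d: the pristine forward and _load_from_state_dict are parked on torch.nn under *_before_network names (guarded by hasattr so reloading the script does not overwrite them with the already-patched versions), then replaced by the network-aware wrappers. A minimal standalone sketch of the pattern, using a clearly hypothetical attribute name:

    import torch

    # Minimal standalone sketch of the patching pattern; the _demo attribute name
    # is hypothetical and only avoids clashing with the real hook.
    if not hasattr(torch.nn, 'LayerNorm_forward_before_network_demo'):
        torch.nn.LayerNorm_forward_before_network_demo = torch.nn.LayerNorm.forward

    def demo_LayerNorm_forward(self, input):
        # the real wrapper calls network_apply_weights(self) here first
        return torch.nn.LayerNorm_forward_before_network_demo(self, input)

    torch.nn.LayerNorm.forward = demo_LayerNorm_forward

    out = torch.nn.LayerNorm(4)(torch.randn(2, 4))  # now routed through the wrapper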