Merge pull request #12503 from AUTOMATIC1111/extra-norm-module

Add norm module to the Lora extension and add "bias" support
AUTOMATIC1111 2023-08-13 08:28:48 +03:00 committed by GitHub
commit da80d649fd
Signature: no known key found for this signature in database (GPG key ID: 4AEE18F83AFDEB23)
4 changed files with 104 additions and 11 deletions
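In effect, a network file can now carry per-layer "w_norm"/"b_norm" tensors, and the extension adds them to a norm layer's weight and bias. A rough sketch of the merged result (ignoring the optional alpha/dim scale, which defaults to 1.0; the function name and shapes below are invented for illustration, not webui code):

import torch

def merged_norm_params(layer: torch.nn.LayerNorm, w_norm: torch.Tensor,
                       b_norm: torch.Tensor, multiplier: float = 1.0):
    # What the diff below achieves for a norm layer: weight and bias each get
    # their delta scaled by the network multiplier.
    weight = layer.weight.detach() + w_norm * multiplier
    bias = layer.bias.detach() + b_norm * multiplier
    return weight, bias

layer = torch.nn.LayerNorm(320)
w, b = merged_norm_params(layer, torch.zeros(320), torch.zeros(320), multiplier=0.8)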

Changed file 1 of 4:

@@ -133,7 +133,7 @@ class NetworkModule:
         return 1.0

-    def finalize_updown(self, updown, orig_weight, output_shape):
+    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
             updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -145,7 +145,10 @@
         if orig_weight.size().numel() == updown.size().numel():
             updown = updown.reshape(orig_weight.shape)

-        return updown * self.calc_scale() * self.multiplier()
+        if ex_bias is not None:
+            ex_bias = ex_bias * self.multiplier()
+
+        return updown * self.calc_scale() * self.multiplier(), ex_bias

     def calc_updown(self, target):
         raise NotImplementedError()
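With this change, calc_updown()/finalize_updown() return a pair instead of a single tensor, so callers always unpack a weight delta and an optional bias delta. A minimal sketch of the new contract, with made-up tensors rather than webui modules:

import torch

def finalize_updown_sketch(updown, multiplier=1.0, scale=1.0, ex_bias=None):
    # Mirrors the new return shape: (weight_delta, bias_delta_or_None).
    if ex_bias is not None:
        ex_bias = ex_bias * multiplier           # bias delta is scaled by the multiplier only
    return updown * scale * multiplier, ex_bias  # weight delta gets scale and multiplier

# modules without a bias delta still return a pair, with None in the second slot
updown, ex_bias = finalize_updown_sketch(torch.ones(4), multiplier=0.8)
assert ex_bias is None

# norm-style modules return both deltas
updown, ex_bias = finalize_updown_sketch(torch.ones(4), multiplier=0.8, ex_bias=torch.ones(4))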

Changed file 2 of 4 (new file):

@@ -0,0 +1,28 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
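The new module type is selected purely from the per-layer weight names in the network file: a layer whose weights include both "w_norm" and "b_norm" becomes a NetworkModuleNorm. A rough sketch of that matching step over a flattened state dict; the layer key below is invented and real files may use different naming:

import torch

# hypothetical flattened state dict of a network file carrying norm deltas
state_dict = {
    "unet_down_blocks_0_resnets_0_norm1.w_norm": torch.zeros(320),
    "unet_down_blocks_0_resnets_0_norm1.b_norm": torch.zeros(320),
}

# group tensors by layer, roughly what the loader hands to create_module()
per_layer = {}
for key, tensor in state_dict.items():
    layer, _, weight_name = key.rpartition(".")
    per_layer.setdefault(layer, {})[weight_name] = tensor

for layer, w in per_layer.items():
    # same membership test as ModuleTypeNorm.create_module() above
    if all(x in w for x in ["w_norm", "b_norm"]):
        print(f"{layer}: norm module, shape {tuple(w['w_norm'].shape)}")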

Changed file 3 of 4:

@@ -7,6 +7,7 @@
 import network_hada
 import network_ia3
 import network_lokr
 import network_full
+import network_norm
 import torch
 from typing import Union
@@ -19,6 +20,7 @@ module_types = [
     network_ia3.ModuleTypeIa3(),
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
+    network_norm.ModuleTypeNorm(),
 ]
@@ -31,6 +33,8 @@ suffix_conversion = {
     "resnets": {
         "conv1": "in_layers_2",
         "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
         "time_emb_proj": "emb_layers_1",
         "conv_shortcut": "skip_connection",
     }
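The two new suffix_conversion entries extend the diffusers-to-compvis key mapping so that resnet norm layers resolve to the right module in the loaded UNet (in_layers_0 and out_layers_0 are the GroupNorms that precede conv1 and conv2). A small sketch of the lookup with an invented suffix-only helper; the real conversion also rewrites the surrounding block path:

# resnet sub-module name mapping, as extended above
resnet_suffixes = {
    "conv1": "in_layers_2",
    "conv2": "out_layers_3",
    "norm1": "in_layers_0",   # new: GroupNorm before conv1
    "norm2": "out_layers_0",  # new: GroupNorm before conv2
    "time_emb_proj": "emb_layers_1",
    "conv_shortcut": "skip_connection",
}

def convert_resnet_suffix(suffix: str) -> str:
    # unknown suffixes pass through unchanged
    return resnet_suffixes.get(suffix, suffix)

assert convert_resnet_suffix("norm1") == "in_layers_0"
assert convert_resnet_suffix("norm2") == "out_layers_0"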
@@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()


-def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
+    bias_backup = getattr(self, "network_bias_backup", None)

-    if weights_backup is None:
+    if weights_backup is None and bias_backup is None:
         return

-    if isinstance(self, torch.nn.MultiheadAttention):
-        self.in_proj_weight.copy_(weights_backup[0])
-        self.out_proj.weight.copy_(weights_backup[1])
-    else:
-        self.weight.copy_(weights_backup)
+    if weights_backup is not None:
+        if isinstance(self, torch.nn.MultiheadAttention):
+            self.in_proj_weight.copy_(weights_backup[0])
+            self.out_proj.weight.copy_(weights_backup[1])
+        else:
+            self.weight.copy_(weights_backup)

+    if bias_backup is not None:
+        self.bias.copy_(bias_backup)


-def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of networks to the weights of torch layer self.
     If weights already have this particular set of networks applied, does nothing.
@@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         self.network_weights_backup = weights_backup

+    bias_backup = getattr(self, "network_bias_backup", None)
+    if bias_backup is None and getattr(self, 'bias', None) is not None:
+        bias_backup = self.bias.to(devices.cpu, copy=True)
+        self.network_bias_backup = bias_backup
+
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
@@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         module = net.modules.get(network_layer_name, None)
         if module is not None and hasattr(self, 'weight'):
             with torch.no_grad():
-                updown = module.calc_updown(self.weight)
+                updown, ex_bias = module.calc_updown(self.weight)

                 if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                     # inpainting model. zero pad updown to make channel[1] 4 to 9
                     updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                 self.weight += updown
+                if ex_bias is not None and getattr(self, 'bias', None) is not None:
+                    self.bias += ex_bias

             continue

         module_q = net.modules.get(network_layer_name + "_q_proj", None)
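Bias handling follows the same backup/apply/restore pattern already used for weights: the untouched bias is stashed on CPU once, the ex_bias delta is added in place, and the stash is copied back when the selected networks change. A stand-alone sketch of that dance on a bare LayerNorm (attribute names kept from the diff; this is not webui code):

import torch

layer = torch.nn.LayerNorm(4)

# back up the untouched bias once, on CPU, as network_apply_weights() now does
if getattr(layer, "network_bias_backup", None) is None and layer.bias is not None:
    layer.network_bias_backup = layer.bias.to("cpu", copy=True)

# apply a bias delta (the ex_bias produced by a norm module)
with torch.no_grad():
    layer.bias += torch.full((4,), 0.5)

# restore, as network_restore_weights_from_backup() now does
with torch.no_grad():
    layer.bias.copy_(layer.network_bias_backup)

assert torch.allclose(layer.bias, torch.zeros(4))  # LayerNorm bias starts at zero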
@@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
     return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


+def network_GroupNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.GroupNorm_forward_before_network(self, input)
+
+
+def network_GroupNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
+def network_LayerNorm_forward(self, input):
+    if shared.opts.lora_functional:
+        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+
+    network_apply_weights(self)
+
+    return torch.nn.LayerNorm_forward_before_network(self, input)
+
+
+def network_LayerNorm_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)
+
+    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+
+
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)

Changed file 4 of 4:

@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
     torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
 if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
     torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@
 torch.nn.Linear.forward = networks.network_Linear_forward
 torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
 torch.nn.Conv2d.forward = networks.network_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
 torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
 torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
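Taken together with the wrappers added in file 3, the pattern is: stash each original torch.nn method once (the hasattr guards keep reloads from wrapping an already-wrapped function), then install a wrapper that merges pending deltas before delegating to the stashed original. A self-contained sketch of that pattern with made-up "_demo" names, omitting the lora_functional branch:

import torch

def install_demo_hook():
    # stash the original only on the first run, exactly like the hasattr guards above
    if not hasattr(torch.nn, "LayerNorm_forward_before_demo"):
        torch.nn.LayerNorm_forward_before_demo = torch.nn.LayerNorm.forward

    def demo_LayerNorm_forward(self, input):
        # the extension calls networks.network_apply_weights(self) here
        return torch.nn.LayerNorm_forward_before_demo(self, input)

    torch.nn.LayerNorm.forward = demo_LayerNorm_forward

install_demo_hook()
install_demo_hook()  # second run: the stash still holds the true original, so no double wrapping

layer = torch.nn.LayerNorm(3)
layer(torch.randn(2, 3))  # goes through one wrapper, then the original forward

# undo the patch so this sketch leaves torch untouched
torch.nn.LayerNorm.forward = torch.nn.LayerNorm_forward_before_demo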