Merge pull request #14547 from AUTOMATIC1111/lyco-forward

Implement general forward method for all modules in built-in lora ext
AUTOMATIC1111 2024-01-06 10:50:06 +03:00 committed by GitHub
commit a4ee64050a
2 changed files with 38 additions and 7 deletions

extensions-builtin/Lora/network.py

@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum
 
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
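
The dispatch above pairs each supported layer type with its torch.nn.functional counterpart and records the extra keyword arguments (stride/padding for convolutions, normalized shape and eps for the norm layers) needed to re-run that operation on delta weights. A minimal sketch of the idea for the Linear case, not part of the commit and using made-up rank-2 factors:

import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Linear(8, 4)
x = torch.randn(2, 8)

# F.linear with the layer's own parameters reproduces the layer itself
assert torch.allclose(layer(x), F.linear(x, layer.weight, layer.bias))

# With a hypothetical low-rank update B @ A (same shape as layer.weight),
# the same functional op computes the extra activation to add on top
A = torch.randn(2, 8) * 0.01   # down projection, rank 2
B = torch.randn(4, 2) * 0.01   # up projection
updown = B @ A
y = layer(x) + F.linear(x, updown)   # original output plus the delta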
@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
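
With ops and extra_kwargs set up in __init__, a network type no longer needs its own forward for this path: calc_updown returns the delta weight (and optional bias), and the base class applies it through the matching functional op. A hypothetical subclass sketch, assuming the base class's existing finalize_updown helper; the class name and the 'diff' key are invented for illustration:

class ExampleDiffModule(NetworkModule):
    """Illustrative only: applies a full weight difference stored under 'diff'."""
    def __init__(self, net, weights):
        super().__init__(net, weights)
        self.weight = weights.w.get("diff")      # assumed key, for illustration

    def calc_updown(self, orig_weight):
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        output_shape = self.weight.shape
        return self.finalize_updown(updown, orig_weight, output_shape)
    # no forward() override needed: NetworkModule.forward covers Linear,
    # Conv2d, LayerNorm and GroupNorm via self.ops / self.extra_kwargs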

extensions-builtin/Lora/networks.py

@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
 
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
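
For context, network_forward is the per-layer fallback path: the layer's original forward runs first, then every loaded network's module registered for that layer adds its contribution on top of y. A rough sketch of how such a hook can be wired onto a layer; the wiring and names here are invented for illustration and are not the extension's actual patching code:

import torch
import torch.nn as nn

def attach_hook(layer: nn.Module, wrapper):
    original_forward = type(layer).forward          # keep the unbound original
    def patched(x):
        return wrapper(layer, x, original_forward)
    layer.forward = patched                         # shadow the bound method

def demo_wrapper(org_module, input, original_forward):
    y = original_forward(org_module, input)         # run the layer unchanged
    # the extension's wrapper then looks up each loaded network's module for
    # this layer and adds its contribution to y before returning
    return y

layer = nn.Linear(4, 4)
attach_hook(layer, demo_wrapper)
out = layer(torch.randn(1, 4))                      # routed through demo_wrapper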