mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 03:40:14 +08:00
Merge pull request #13610 from v0xie/network-glora
Support inference with LyCORIS GLora networks
This commit is contained in:
commit
26500b8c1b
33
extensions-builtin/Lora/network_glora.py
Normal file
33
extensions-builtin/Lora/network_glora.py
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
import network
|
||||
|
||||
class ModuleTypeGLora(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
|
||||
return NetworkModuleGLora(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
|
||||
class NetworkModuleGLora(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.w1a = weights.w["a1.weight"]
|
||||
self.w1b = weights.w["b1.weight"]
|
||||
self.w2a = weights.w["a2.weight"]
|
||||
self.w2b = weights.w["b2.weight"]
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
@ -5,6 +5,7 @@ import re
|
||||
import lora_patches
|
||||
import network
|
||||
import network_lora
|
||||
import network_glora
|
||||
import network_hada
|
||||
import network_ia3
|
||||
import network_lokr
|
||||
@ -23,6 +24,7 @@ module_types = [
|
||||
network_lokr.ModuleTypeLokr(),
|
||||
network_full.ModuleTypeFull(),
|
||||
network_norm.ModuleTypeNorm(),
|
||||
network_glora.ModuleTypeGLora(),
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user