Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-04-25 14:28:59 +08:00)
fix: get boft params from weight shape
parent 2f1073dc6e
commit 325eaeb584
@@ -1,6 +1,6 @@
 import torch
 import network
-from lyco_helpers import factorization, butterfly_factor
+from lyco_helpers import factorization
 from einops import rearrange
 
 
@@ -37,10 +37,8 @@ class NetworkModuleOFT(network.NetworkModule):
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
 
         self.is_boft = False
-        if "boft" in weights.w.keys():
+        if weights.w["oft_diag"].dim() == 4:
             self.is_boft = True
-            self.boft_b = weights.w["boft_b"]
-            self.boft_m = weights.w["boft_m"]
 
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
@@ -59,7 +57,11 @@ class NetworkModuleOFT(network.NetworkModule):
             self.block_size = self.out_dim // self.dim
         elif self.is_boft:
             self.constraint = None
-            self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
+            self.boft_m = weights.w["oft_diag"].shape[0]
+            self.block_num = weights.w["oft_diag"].shape[1]
+            self.block_size = weights.w["oft_diag"].shape[2]
+            self.boft_b = self.block_size
+            #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
         else:
             self.constraint = None
             self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
@@ -88,8 +90,8 @@ class NetworkModuleOFT(network.NetworkModule):
             merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
         else:
             scale = 1.0
-            m = self.boft_m.to(device=oft_blocks.device, dtype=oft_blocks.dtype)
-            b = self.boft_b.to(device=oft_blocks.device, dtype=oft_blocks.dtype)
+            m = self.boft_m
+            b = self.boft_b
             r_b = b // 2
             inp = orig_weight
             for i in range(m):
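
For context: the fix stops reading separate boft_b / boft_m tensors from the checkpoint and instead derives the BOFT parameters from the shape of the 4-D oft_diag weight, which the new code treats as (m, num_blocks, block_size, block_size). Below is a minimal sketch of that derivation, using an arbitrary dummy tensor in place of a real checkpoint weight; the shape values are made up for illustration.

import torch

# Stand-in for a LyCORIS BOFT "oft_diag" weight, laid out as
# (m, num_blocks, block_size, block_size); the shape values are made up.
oft_diag = torch.zeros(3, 8, 4, 4)

is_boft = oft_diag.dim() == 4      # a 4-D tensor marks the module as BOFT
boft_m = oft_diag.shape[0]         # number of butterfly stages
block_num = oft_diag.shape[1]      # orthogonal blocks per stage
block_size = oft_diag.shape[2]     # size of each square block
boft_b = block_size                # boft_b now follows from the block size

print(is_boft, boft_m, block_num, block_size, boft_b)  # True 3 8 4 4

Since boft_m and boft_b are now plain Python ints rather than tensors loaded from the network weights, the .to(device=..., dtype=...) calls on them in calc_updown are dropped in the last hunk as well.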