mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
skip multihead attn for now
This commit is contained in:
parent
d727ddfccd
commit
fe1967a4c4
@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
super().__init__(net, weights)
|
||||
|
||||
self.lin_module = None
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
# kohya-ss
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
self.is_kohya = True
|
||||
@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
# alpha is rank if alpha is 0 or None
|
||||
if self.alpha is None:
|
||||
pass
|
||||
self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
|
||||
self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
|
||||
else:
|
||||
raise ValueError("oft_blocks or oft_diag must be in weights dict")
|
||||
|
||||
@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
# raise ValueError("Linear sd_module must have out_features or embed_dim")
|
||||
elif is_other_linear:
|
||||
self.out_dim = self.sd_module.embed_dim
|
||||
#self.org_weight = self.org_module[0].weight
|
||||
# if hasattr(self.sd_module, "in_proj_weight"):
|
||||
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
|
||||
# if hasattr(self.sd_module, "out_proj_weight"):
|
||||
# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
|
||||
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
|
||||
elif is_conv:
|
||||
self.out_dim = self.sd_module.out_channels
|
||||
else:
|
||||
@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.constraint = self.alpha * self.out_dim
|
||||
#elif is_linear or is_conv:
|
||||
else:
|
||||
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
|
||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||
self.constraint = None
|
||||
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
|
||||
# if is_other_linear:
|
||||
# weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
|
||||
@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
multiplier = self.multiplier() * self.calc_scale()
|
||||
R = self.get_weight(self.oft_blocks, multiplier)
|
||||
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
merged_weight = self.merge_weight(R, orig_weight)
|
||||
is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
|
||||
if self.is_kohya and not is_other_linear:
|
||||
R = self.get_weight(self.oft_blocks, multiplier)
|
||||
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
merged_weight = self.merge_weight(R, orig_weight)
|
||||
elif not self.is_kohya and not is_other_linear:
|
||||
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||
orig_weight=orig_weight.permute(1, 0)
|
||||
R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
#orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
|
||||
merged_weight = torch.einsum(
|
||||
'k n m, k n ... -> k m ...',
|
||||
R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
|
||||
merged_weight
|
||||
)
|
||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||
orig_weight=orig_weight.permute(1, 0)
|
||||
#merged_weight=merged_weight.permute(1, 0)
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
output_shape = orig_weight.shape
|
||||
else:
|
||||
# skip for now
|
||||
updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
|
||||
output_shape = (orig_weight.shape[1], orig_weight.shape[1])
|
||||
|
||||
#if self.lin_module is not None:
|
||||
# R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
# weight = torch.mul(torch.mul(R, multiplier), orig_weight)
|
||||
#else:
|
||||
# orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
# weight = torch.einsum(
|
||||
# 'k n m, k n ... -> k m ...',
|
||||
# R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
|
||||
# orig_weight
|
||||
# )
|
||||
# weight = rearrange(weight, 'k m ... -> (k m) ...')
|
||||
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
output_shape = orig_weight.shape
|
||||
orig_weight = orig_weight
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
Loading…
Reference in New Issue
Block a user