mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
skip multihead attn for now
This commit is contained in:
parent
d727ddfccd
commit
fe1967a4c4
@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
super().__init__(net, weights)
|
super().__init__(net, weights)
|
||||||
|
|
||||||
self.lin_module = None
|
self.lin_module = None
|
||||||
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
# kohya-ss
|
# kohya-ss
|
||||||
if "oft_blocks" in weights.w.keys():
|
if "oft_blocks" in weights.w.keys():
|
||||||
self.is_kohya = True
|
self.is_kohya = True
|
||||||
@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
# alpha is rank if alpha is 0 or None
|
# alpha is rank if alpha is 0 or None
|
||||||
if self.alpha is None:
|
if self.alpha is None:
|
||||||
pass
|
pass
|
||||||
self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
|
self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
|
||||||
else:
|
else:
|
||||||
raise ValueError("oft_blocks or oft_diag must be in weights dict")
|
raise ValueError("oft_blocks or oft_diag must be in weights dict")
|
||||||
|
|
||||||
@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
# raise ValueError("Linear sd_module must have out_features or embed_dim")
|
# raise ValueError("Linear sd_module must have out_features or embed_dim")
|
||||||
elif is_other_linear:
|
elif is_other_linear:
|
||||||
self.out_dim = self.sd_module.embed_dim
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
#self.org_weight = self.org_module[0].weight
|
||||||
|
# if hasattr(self.sd_module, "in_proj_weight"):
|
||||||
|
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
|
||||||
|
# if hasattr(self.sd_module, "out_proj_weight"):
|
||||||
|
# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
|
||||||
|
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
|
||||||
elif is_conv:
|
elif is_conv:
|
||||||
self.out_dim = self.sd_module.out_channels
|
self.out_dim = self.sd_module.out_channels
|
||||||
else:
|
else:
|
||||||
@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.constraint = self.alpha * self.out_dim
|
self.constraint = self.alpha * self.out_dim
|
||||||
#elif is_linear or is_conv:
|
#elif is_linear or is_conv:
|
||||||
else:
|
else:
|
||||||
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
self.constraint = None
|
self.constraint = None
|
||||||
|
|
||||||
self.org_module: list[torch.Module] = [self.sd_module]
|
|
||||||
|
|
||||||
# if is_other_linear:
|
# if is_other_linear:
|
||||||
# weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
|
# weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
|
||||||
@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
multiplier = self.multiplier() * self.calc_scale()
|
multiplier = self.multiplier() * self.calc_scale()
|
||||||
|
is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
|
||||||
|
if self.is_kohya and not is_other_linear:
|
||||||
R = self.get_weight(self.oft_blocks, multiplier)
|
R = self.get_weight(self.oft_blocks, multiplier)
|
||||||
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
merged_weight = self.merge_weight(R, orig_weight)
|
merged_weight = self.merge_weight(R, orig_weight)
|
||||||
|
elif not self.is_kohya and not is_other_linear:
|
||||||
|
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||||
|
orig_weight=orig_weight.permute(1, 0)
|
||||||
|
R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
|
#orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
|
||||||
|
merged_weight = torch.einsum(
|
||||||
|
'k n m, k n ... -> k m ...',
|
||||||
|
R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
|
||||||
|
merged_weight
|
||||||
|
)
|
||||||
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
|
||||||
|
orig_weight=orig_weight.permute(1, 0)
|
||||||
|
#merged_weight=merged_weight.permute(1, 0)
|
||||||
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
|
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
|
output_shape = orig_weight.shape
|
||||||
|
else:
|
||||||
|
# skip for now
|
||||||
|
updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
output_shape = (orig_weight.shape[1], orig_weight.shape[1])
|
||||||
|
|
||||||
#if self.lin_module is not None:
|
#if self.lin_module is not None:
|
||||||
# R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
# R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
# weight = torch.mul(torch.mul(R, multiplier), orig_weight)
|
# weight = torch.mul(torch.mul(R, multiplier), orig_weight)
|
||||||
#else:
|
#else:
|
||||||
# orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
|
||||||
# weight = torch.einsum(
|
|
||||||
# 'k n m, k n ... -> k m ...',
|
|
||||||
# R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
|
|
||||||
# orig_weight
|
|
||||||
# )
|
|
||||||
# weight = rearrange(weight, 'k m ... -> (k m) ...')
|
|
||||||
|
|
||||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
|
||||||
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
|
||||||
output_shape = orig_weight.shape
|
|
||||||
orig_weight = orig_weight
|
orig_weight = orig_weight
|
||||||
|
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user