Skip multi-head attention (torch.nn.MultiheadAttention) modules for now

This commit is contained in:
v0xie 2023-11-03 17:52:55 -07:00
parent d727ddfccd
commit fe1967a4c4

View File

@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
super().__init__(net, weights) super().__init__(net, weights)
self.lin_module = None self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module]
# kohya-ss # kohya-ss
if "oft_blocks" in weights.w.keys(): if "oft_blocks" in weights.w.keys():
self.is_kohya = True self.is_kohya = True
@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
# alpha is rank if alpha is 0 or None # alpha is rank if alpha is 0 or None
if self.alpha is None: if self.alpha is None:
pass pass
self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
else: else:
raise ValueError("oft_blocks or oft_diag must be in weights dict") raise ValueError("oft_blocks or oft_diag must be in weights dict")
@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
# raise ValueError("Linear sd_module must have out_features or embed_dim") # raise ValueError("Linear sd_module must have out_features or embed_dim")
elif is_other_linear: elif is_other_linear:
self.out_dim = self.sd_module.embed_dim self.out_dim = self.sd_module.embed_dim
#self.org_weight = self.org_module[0].weight
# if hasattr(self.sd_module, "in_proj_weight"):
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
# if hasattr(self.sd_module, "out_proj_weight"):
# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
elif is_conv: elif is_conv:
self.out_dim = self.sd_module.out_channels self.out_dim = self.sd_module.out_channels
else: else:
@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
self.constraint = self.alpha * self.out_dim self.constraint = self.alpha * self.out_dim
#elif is_linear or is_conv: #elif is_linear or is_conv:
else: else:
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
self.constraint = None self.constraint = None
self.org_module: list[torch.Module] = [self.sd_module]
# if is_other_linear: # if is_other_linear:
# weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):
def calc_updown(self, orig_weight): def calc_updown(self, orig_weight):
multiplier = self.multiplier() * self.calc_scale() multiplier = self.multiplier() * self.calc_scale()
is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
if self.is_kohya and not is_other_linear:
R = self.get_weight(self.oft_blocks, multiplier) R = self.get_weight(self.oft_blocks, multiplier)
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
merged_weight = self.merge_weight(R, orig_weight) merged_weight = self.merge_weight(R, orig_weight)
elif not self.is_kohya and not is_other_linear:
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
orig_weight=orig_weight.permute(1, 0)
R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
#orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
merged_weight = torch.einsum(
'k n m, k n ... -> k m ...',
R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
merged_weight
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
orig_weight=orig_weight.permute(1, 0)
#merged_weight=merged_weight.permute(1, 0)
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
else:
# skip for now
updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
output_shape = (orig_weight.shape[1], orig_weight.shape[1])
#if self.lin_module is not None: #if self.lin_module is not None:
# R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
# weight = torch.mul(torch.mul(R, multiplier), orig_weight) # weight = torch.mul(torch.mul(R, multiplier), orig_weight)
#else: #else:
# orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
# weight = torch.einsum(
# 'k n m, k n ... -> k m ...',
# R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
# orig_weight
# )
# weight = rearrange(weight, 'k m ... -> (k m) ...')
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
orig_weight = orig_weight orig_weight = orig_weight
return self.finalize_updown(updown, orig_weight, output_shape) return self.finalize_updown(updown, orig_weight, output_shape)