diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff61b3699..e102eafc1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,11 +1,12 @@ import torch import network from einops import rearrange +from modules import devices class ModuleTypeOFT(network.ModuleType): def create_module(self, net: network.Network, weights: network.NetworkWeights): - if all(x in weights.w for x in ["oft_blocks"]): + if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]): return NetworkModuleOFT(net, weights) return None @@ -16,66 +17,117 @@ class NetworkModuleOFT(network.NetworkModule): super().__init__(net, weights) - self.oft_blocks = weights.w["oft_blocks"] - self.alpha = weights.w["alpha"] - self.dim = self.oft_blocks.shape[0] - self.num_blocks = self.dim + self.lin_module = None + # kohya-ss + if "oft_blocks" in weights.w.keys(): + self.is_kohya = True + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + self.dim = self.oft_blocks.shape[0] + elif "oft_diag" in weights.w.keys(): + self.is_kohya = False + self.oft_blocks = weights.w["oft_diag"] + # alpha is rank if alpha is 0 or None + if self.alpha is None: + pass + self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n] + else: + raise ValueError("oft_blocks or oft_diag must be in weights dict") - if "Linear" in self.sd_module.__class__.__name__: + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention] + #if "Linear" in self.sd_module.__class__.__name__ or is_linear: + if is_linear: self.out_dim = self.sd_module.out_features - elif "Conv" in self.sd_module.__class__.__name__: + #elif hasattr(self.sd_module, "embed_dim"): + # self.out_dim = self.sd_module.embed_dim + #else: + # raise ValueError("Linear sd_module must have out_features or embed_dim") + elif is_other_linear: + self.out_dim = self.sd_module.embed_dim + elif is_conv: self.out_dim = self.sd_module.out_channels + else: + raise ValueError("sd_module must be Linear or Conv") - self.constraint = self.alpha * self.out_dim - self.block_size = self.out_dim // self.num_blocks + + if self.is_kohya: + self.num_blocks = self.dim + self.block_size = self.out_dim // self.num_blocks + self.constraint = self.alpha * self.out_dim + #elif is_linear or is_conv: + else: + self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) + self.constraint = None self.org_module: list[torch.Module] = [self.sd_module] - # def merge_weight(self, R_weight, org_weight): - # R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) - # if org_weight.dim() == 4: - # weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) - # else: - # weight = torch.einsum("oi, op -> pi", org_weight, R_weight) - # weight = torch.einsum( - # "k n m, k n ... -> k m ...", - # self.oft_diag * scale + torch.eye(self.block_size, device=device), - # org_weight - # ) - # return weight + # if is_other_linear: + # weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1) + # module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + # with torch.no_grad(): + # if weight.shape != module.weight.shape: + # weight = weight.reshape(module.weight.shape) + # module.weight.copy_(weight) + # module.to(device=devices.cpu, dtype=devices.dtype) + # module.weight.requires_grad_(False) + # self.lin_module = module + #return module + + def merge_weight(self, R_weight, org_weight): + R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R_weight) + #weight = torch.einsum( + # "k n m, k n ... -> k m ...", + # self.oft_diag * scale + torch.eye(self.block_size, device=device), + # org_weight + #) + return weight def get_weight(self, oft_blocks, multiplier=None): - # constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) + if self.constraint is not None: + constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype) - # block_Q = oft_blocks - oft_blocks.transpose(1, 2) - # norm_Q = torch.norm(block_Q.flatten()) - # new_norm_Q = torch.clamp(norm_Q, max=constraint) - # block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - # m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - # block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + if self.constraint is not None: + new_norm_Q = torch.clamp(norm_Q, max=constraint) + else: + new_norm_Q = norm_Q + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + m_I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) - # block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I - # R = torch.block_diag(*block_R_weighted) - #return R - return self.oft_blocks + block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I + R = torch.block_diag(*block_R_weighted) + return R + #return self.oft_blocks def calc_updown(self, orig_weight): multiplier = self.multiplier() * self.calc_scale() - #R = self.get_weight(self.oft_blocks, multiplier) - R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - #merged_weight = self.merge_weight(R, orig_weight) + R = self.get_weight(self.oft_blocks, multiplier) + #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + merged_weight = self.merge_weight(R, orig_weight) - orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R * multiplier + torch.eye(self.block_size, device=orig_weight.device), - orig_weight - ) - weight = rearrange(weight, 'k m ... -> (k m) ...') + #if self.lin_module is not None: + # R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype) + # weight = torch.mul(torch.mul(R, multiplier), orig_weight) + #else: + # orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + # weight = torch.einsum( + # 'k n m, k n ... -> k m ...', + # R * multiplier + torch.eye(self.block_size, device=orig_weight.device), + # orig_weight + # ) + # weight = rearrange(weight, 'k m ... -> (k m) ...') - #updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight output_shape = orig_weight.shape orig_weight = orig_weight @@ -100,3 +152,49 @@ class NetworkModuleOFT(network.NetworkModule): ex_bias = ex_bias * self.multiplier() return updown, ex_bias + +# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py +def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: + ''' + return a tuple of two value of input dimension decomposed by the number closest to factor + second value is higher or equal than first value. + + In LoRA with Kroneckor Product, first value is a value for weight scale. + secon value is a value for weight. + + Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different. + + examples) + factor + -1 2 4 8 16 ... + 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 + 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 + 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 + 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 + 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 + 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 + ''' + + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m length or new_m>factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n +