diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index d1c46a4b2..8a37828cc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -1,6 +1,6 @@ import torch import network -from lyco_helpers import factorization +from lyco_helpers import factorization, butterfly_factor from einops import rearrange @@ -36,6 +36,12 @@ class NetworkModuleOFT(network.NetworkModule): # self.alpha is unused self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size) + self.is_boft = False + if "boft" in weights.w.keys(): + self.is_boft = True + self.boft_b = weights.w["boft_b"] + self.boft_m = weights.w["boft_m"] + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported @@ -68,14 +74,34 @@ class NetworkModuleOFT(network.NetworkModule): R = oft_blocks.to(orig_weight.device) - # This errors out for MultiheadAttention, might need to be handled up-stream - merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) - merged_weight = torch.einsum( - 'k n m, k n ... -> k m ...', - R, - merged_weight - ) - merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + if not self.is_boft: + # This errors out for MultiheadAttention, might need to be handled up-stream + merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) + merged_weight = torch.einsum( + 'k n m, k n ... -> k m ...', + R, + merged_weight + ) + merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') + else: + scale = 1.0 + m = self.boft_m.to(device=oft_blocks.device, dtype=oft_blocks.dtype) + b = self.boft_b.to(device=oft_blocks.device, dtype=oft_blocks.dtype) + r_b = b // 2 + inp = orig_weight + for i in range(m): + bi = R[i] # b_num, b_size, b_size + if i == 0: + # Apply multiplier/scale and rescale into first weight + bi = bi * scale + (1 - scale) * eye + #if self.rescaled: + # bi = bi * self.rescale + inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b) + inp = rearrange(inp, "(d b) ... -> d b ...", b=b) + inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp) + inp = rearrange(inp, "d b ... -> (d b) ...") + inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) + merged_weight = inp updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape