diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index fc7132651..d7b317029 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -29,13 +29,14 @@ class NetworkModuleOFT(network.NetworkModule):
             self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
             self.alpha = weights.w["alpha"] # alpha is constraint
             self.dim = self.oft_blocks.shape[0] # lora dim
-        # LyCORIS
+        # LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
             self.is_kohya = False
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
 
+            # LyCORIS BOFT
             self.is_boft = False
             if weights.w["oft_diag"].dim() == 4:
                 self.is_boft = True
@@ -89,6 +90,7 @@ class NetworkModuleOFT(network.NetworkModule):
             )
             merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
         else:
+            # TODO: determine correct value for scale
             scale = 1.0
             m = self.boft_m
             b = self.boft_b
@@ -99,8 +101,6 @@ class NetworkModuleOFT(network.NetworkModule):
                 if i == 0:
                     # Apply multiplier/scale and rescale into first weight
                     bi = bi * scale + (1 - scale) * eye
-                    #if self.rescaled:
-                    #    bi = bi * self.rescale
                 inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
                 inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
                 inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
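
For context, here is a minimal, self-contained sketch (not part of this patch) of what the BOFT `else:` branch computes: `m` butterfly stages, each permuting the output rows and applying a block-diagonal orthogonal factor. The function name `boft_merge`, the `(m, num_blocks, b, b)` block layout, and the toy sizes below are assumptions for illustration, not values taken from this PR or from LyCORIS.

```python
import torch
from einops import rearrange


def boft_merge(orig_weight, oft_blocks, scale=1.0):
    # oft_blocks: (m, num_blocks, b, b) butterfly factors -- an assumed layout
    # mirroring the (num_blocks, block_size, block_size) comment in the diff,
    # with a leading stage dimension.
    m, num_blocks, b, _ = oft_blocks.shape
    r_b = b // 2
    eye = torch.eye(b, device=oft_blocks.device)
    inp = orig_weight
    for i in range(m):
        bi = oft_blocks[i]  # (num_blocks, b, b)
        if i == 0:
            # blend the first factor toward identity by `scale`,
            # as in the `bi * scale + (1 - scale) * eye` line of the patch
            bi = bi * scale + (1 - scale) * eye
        # butterfly permutation: regroup rows so stage i mixes entries at
        # stride 2**i * r_b before the block-diagonal multiply
        inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
        inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
        inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
        # undo the grouping and the permutation before the next stage
        inp = rearrange(inp, "d b ... -> (d b) ...")
        inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
    return inp


# toy check: with orthogonal blocks, every stage is a permutation followed by
# a block-diagonal orthogonal multiply, so the weight's Frobenius norm is
# preserved by the merge
m, num_blocks, b = 2, 4, 4
blocks, _ = torch.linalg.qr(torch.randn(m, num_blocks, b, b))
w = torch.randn(num_blocks * b, 8)  # out_features = num_blocks * b
assert torch.allclose(boft_merge(w, blocks).norm(), w.norm(), atol=1e-4)
```

The norm check also shows why the `scale = 1.0` default is a safe placeholder for the TODO above: at `scale == 1.0` the first factor is used unmodified, and any other value interpolates it toward the identity.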