From 613b0d9548a859408433bff7a6dca7fd0f2eae7e Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Thu, 8 Feb 2024 21:58:59 -0800
Subject: [PATCH] doc: add boft comment

---
 extensions-builtin/Lora/network_oft.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index fc7132651..d7b317029 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -29,13 +29,14 @@ class NetworkModuleOFT(network.NetworkModule):
             self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
             self.alpha = weights.w["alpha"] # alpha is constraint
             self.dim = self.oft_blocks.shape[0] # lora dim
-        # LyCORIS
+        # LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
             self.is_kohya = False
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
 
+            # LyCORIS BOFT
             self.is_boft = False
             if weights.w["oft_diag"].dim() == 4:
                 self.is_boft = True
@@ -89,6 +90,7 @@ class NetworkModuleOFT(network.NetworkModule):
             )
             merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
         else:
+            # TODO: determine correct value for scale
             scale = 1.0
             m = self.boft_m
             b = self.boft_b
@@ -99,8 +101,6 @@ class NetworkModuleOFT(network.NetworkModule):
                 if i == 0:
                     # Apply multiplier/scale and rescale into first weight
                     bi = bi * scale + (1 - scale) * eye
-                    #if self.rescaled:
-                    #    bi = bi * self.rescale
                 inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
                 inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
                 inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)