From 0550659ce6e1c37d1ab05cb8a2cb31d499fa552f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:13:02 -0700 Subject: [PATCH] style: fix ambiguous variable name --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 0a87958e2..4e8382c18 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -43,8 +43,8 @@ class NetworkModuleOFT(network.NetworkModule): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + m_I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(m_I + block_Q, (m_I - block_Q).inverse()) #block_R_weighted = multiplier * block_R + (1 - multiplier) * I #R = torch.block_diag(*block_R_weighted) R = torch.block_diag(*block_R)