refactor: use forward hook instead of custom forward

This commit is contained in:
v0xie 2023-10-21 13:43:31 -07:00
parent 0550659ce6
commit 2d8c894b27

View File

@ -36,9 +36,11 @@ class NetworkModuleOFT(network.NetworkModule):
# how do we revert this to unload the weights? # how do we revert this to unload the weights?
def apply_to(self): def apply_to(self):
self.org_forward = self.org_module[0].forward self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward #self.org_module[0].forward = self.forward
self.org_module[0].register_forward_hook(self.forward_hook)
def get_weight(self, oft_blocks, multiplier=None): def get_weight(self, oft_blocks, multiplier=None):
self.constraint = self.constraint.to(oft_blocks.device, dtype=oft_blocks.dtype)
block_Q = oft_blocks - oft_blocks.transpose(1, 2) block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten()) norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint) new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
@ -67,13 +69,9 @@ class NetworkModuleOFT(network.NetworkModule):
return self.finalize_updown(updown, orig_weight, output_shape) return self.finalize_updown(updown, orig_weight, output_shape)
def forward(self, x, y=None): def forward_hook(self, module, args, output):
x = self.org_forward(x) #print(f'Forward hook in {self.network_key} called')
if self.multiplier() == 0.0: x = output
return x
# calculating R here is excruciatingly slow
#R = self.get_weight().to(x.device, dtype=x.dtype)
R = self.R.to(x.device, dtype=x.dtype) R = self.R.to(x.device, dtype=x.dtype)
if x.dim() == 4: if x.dim() == 4:
@ -83,3 +81,20 @@ class NetworkModuleOFT(network.NetworkModule):
else: else:
x = torch.matmul(x, R) x = torch.matmul(x, R)
return x return x
# def forward(self, x, y=None):
# x = self.org_forward(x)
# if self.multiplier() == 0.0:
# return x
# # calculating R here is excruciatingly slow
# #R = self.get_weight().to(x.device, dtype=x.dtype)
# R = self.R.to(x.device, dtype=x.dtype)
# if x.dim() == 4:
# x = x.permute(0, 2, 3, 1)
# x = torch.matmul(x, R)
# x = x.permute(0, 3, 1, 2)
# else:
# x = torch.matmul(x, R)
# return x