For no constraint

This commit is contained in:
Kohaku-Blueleaf 2024-02-22 00:43:32 +08:00 committed by GitHub
parent 64179c3221
commit c4afdb7895
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -27,7 +27,7 @@ class NetworkModuleOFT(network.NetworkModule):
# kohya-ss/New LyCORIS OFT/BOFT # kohya-ss/New LyCORIS OFT/BOFT
if "oft_blocks" in weights.w.keys(): if "oft_blocks" in weights.w.keys():
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size) self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w.get("alpha", self.alpha) # alpha is constraint self.alpha = weights.w.get("alpha", None) # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim self.dim = self.oft_blocks.shape[0] # lora dim
# Old LyCORIS OFT # Old LyCORIS OFT
elif "oft_diag" in weights.w.keys(): elif "oft_diag" in weights.w.keys():
@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
self.num_blocks = self.dim self.num_blocks = self.dim
self.block_size = self.out_dim // self.dim self.block_size = self.out_dim // self.dim
self.constraint = (1 if self.alpha is None else self.alpha) * self.out_dim self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
if self.is_R: if self.is_R:
self.constraint = None self.constraint = None
self.block_size = self.dim self.block_size = self.dim
@ -73,6 +73,7 @@ class NetworkModuleOFT(network.NetworkModule):
if not self.is_R: if not self.is_R:
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
if self.constraint != 0:
norm_Q = torch.norm(block_Q.flatten()) norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))