mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
For no constraint
This commit is contained in:
parent
64179c3221
commit
c4afdb7895
@ -27,7 +27,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
# kohya-ss/New LyCORIS OFT/BOFT
|
# kohya-ss/New LyCORIS OFT/BOFT
|
||||||
if "oft_blocks" in weights.w.keys():
|
if "oft_blocks" in weights.w.keys():
|
||||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||||
self.alpha = weights.w.get("alpha", self.alpha) # alpha is constraint
|
self.alpha = weights.w.get("alpha", None) # alpha is constraint
|
||||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||||
# Old LyCORIS OFT
|
# Old LyCORIS OFT
|
||||||
elif "oft_diag" in weights.w.keys():
|
elif "oft_diag" in weights.w.keys():
|
||||||
@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
|
|
||||||
self.num_blocks = self.dim
|
self.num_blocks = self.dim
|
||||||
self.block_size = self.out_dim // self.dim
|
self.block_size = self.out_dim // self.dim
|
||||||
self.constraint = (1 if self.alpha is None else self.alpha) * self.out_dim
|
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
||||||
if self.is_R:
|
if self.is_R:
|
||||||
self.constraint = None
|
self.constraint = None
|
||||||
self.block_size = self.dim
|
self.block_size = self.dim
|
||||||
@ -73,9 +73,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
|
|
||||||
if not self.is_R:
|
if not self.is_R:
|
||||||
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
|
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
|
||||||
norm_Q = torch.norm(block_Q.flatten())
|
if self.constraint != 0:
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||||
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||||
|
|
||||||
R = oft_blocks.to(orig_weight.device)
|
R = oft_blocks.to(orig_weight.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user