mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
Update network_oft.py
This commit is contained in:
parent
0a271938d8
commit
a5436a3ac0
@ -22,24 +22,24 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.org_module: list[torch.Module] = [self.sd_module]
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
|
|
||||||
self.scale = 1.0
|
self.scale = 1.0
|
||||||
self.is_kohya = False
|
self.is_R = False
|
||||||
self.is_boft = False
|
self.is_boft = False
|
||||||
|
|
||||||
# kohya-ss
|
# kohya-ss/New LyCORIS OFT/BOFT
|
||||||
if "oft_blocks" in weights.w.keys():
|
if "oft_blocks" in weights.w.keys():
|
||||||
self.is_kohya = True
|
|
||||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||||
self.alpha = weights.w["alpha"] # alpha is constraint
|
self.alpha = weights.w.get("alpha", self.alpha) # alpha is constraint
|
||||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||||
# LyCORIS OFT
|
# Old LyCORIS OFT
|
||||||
elif "oft_diag" in weights.w.keys():
|
elif "oft_diag" in weights.w.keys():
|
||||||
|
self.is_R = True
|
||||||
self.oft_blocks = weights.w["oft_diag"]
|
self.oft_blocks = weights.w["oft_diag"]
|
||||||
# self.alpha is unused
|
# self.alpha is unused
|
||||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
# LyCORIS BOFT
|
# LyCORIS BOFT
|
||||||
if weights.w["oft_diag"].dim() == 4:
|
if self.oft_blocks.dim() == 4:
|
||||||
self.is_boft = True
|
self.is_boft = True
|
||||||
self.rescale = weights.w.get('rescale', None)
|
self.rescale = weights.w.get('rescale', None)
|
||||||
if self.rescale is not None:
|
if self.rescale is not None:
|
||||||
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
||||||
@ -55,26 +55,24 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
elif is_other_linear:
|
elif is_other_linear:
|
||||||
self.out_dim = self.sd_module.embed_dim
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
if self.is_kohya:
|
self.num_blocks = self.dim
|
||||||
self.constraint = self.alpha * self.out_dim
|
self.block_size = self.out_dim // self.dim
|
||||||
self.num_blocks = self.dim
|
self.constraint = (1 if self.alpha is None else self.alpha) * self.out_dim
|
||||||
self.block_size = self.out_dim // self.dim
|
if self.is_R:
|
||||||
|
self.constraint = None
|
||||||
|
self.block_size = self.dim
|
||||||
|
self.num_blocks = self.out_dim // self.dim
|
||||||
elif self.is_boft:
|
elif self.is_boft:
|
||||||
self.constraint = None
|
self.boft_m = self.oft_blocks.shape[0]
|
||||||
self.boft_m = weights.w["oft_diag"].shape[0]
|
self.num_blocks = self.oft_blocks.shape[1]
|
||||||
self.block_num = weights.w["oft_diag"].shape[1]
|
self.block_size = self.oft_blocks.shape[2]
|
||||||
self.block_size = weights.w["oft_diag"].shape[2]
|
|
||||||
self.boft_b = self.block_size
|
self.boft_b = self.block_size
|
||||||
#self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
|
|
||||||
else:
|
|
||||||
self.constraint = None
|
|
||||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
||||||
eye = torch.eye(self.block_size, device=oft_blocks.device)
|
eye = torch.eye(self.block_size, device=oft_blocks.device)
|
||||||
|
|
||||||
if self.is_kohya:
|
if not self.is_R:
|
||||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
norm_Q = torch.norm(block_Q.flatten())
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||||
|
Loading…
Reference in New Issue
Block a user