mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
Add rescale mechanism
LyCORIS will support save oft_blocks instead of oft_diag in the near future (for both OFT and BOFT) But this means we need to store the rescale if user enable it.
This commit is contained in:
parent
eb6f2df826
commit
90441294db
@ -40,6 +40,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.is_boft = False
|
self.is_boft = False
|
||||||
if weights.w["oft_diag"].dim() == 4:
|
if weights.w["oft_diag"].dim() == 4:
|
||||||
self.is_boft = True
|
self.is_boft = True
|
||||||
|
self.rescale = weight.w.get('rescale', None)
|
||||||
|
|
||||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
@ -108,6 +109,10 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
|
inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
|
||||||
merged_weight = inp
|
merged_weight = inp
|
||||||
|
|
||||||
|
# Rescale mechanism
|
||||||
|
if self.rescale is not None:
|
||||||
|
merged_weight = self.rescale.to(merged_weight) * merged_weight
|
||||||
|
|
||||||
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
|
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
|
||||||
output_shape = orig_weight.shape
|
output_shape = orig_weight.shape
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user