mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 03:40:14 +08:00
Fix dtype casting for OFT module
This commit is contained in:
parent
a06dab8d7a
commit
f8f38c7c28
@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
||||||
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||||
|
|
||||||
if self.is_kohya:
|
if self.is_kohya:
|
||||||
@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||||
|
|
||||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
R = oft_blocks.to(orig_weight.device)
|
||||||
|
|
||||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
)
|
)
|
||||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
|
||||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
|
||||||
output_shape = orig_weight.shape
|
output_shape = orig_weight.shape
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||||
|
Loading…
Reference in New Issue
Block a user