mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 05:45:05 +08:00
Merge pull request #10089 from AUTOMATIC1111/LoraFix
Fix some Lora's not working
This commit is contained in:
commit
b15bbef798
@ -166,8 +166,10 @@ def load_lora(name, filename):
|
|||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||||
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||||
else:
|
else:
|
||||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
continue
|
continue
|
||||||
@ -233,6 +235,8 @@ def lora_calc_updown(lora, module, target):
|
|||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||||
|
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||||
else:
|
else:
|
||||||
updown = up @ down
|
updown = up @ down
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user