mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
Use correct implementation, fix device error
This commit is contained in:
parent
851c3d51ed
commit
1792e193b1
@ -153,7 +153,7 @@ class NetworkModule:
|
|||||||
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
||||||
|
|
||||||
self.dora_scale = weights.w.get("dora_scale", None)
|
self.dora_scale = weights.w.get("dora_scale", None)
|
||||||
self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
|
self.dora_norm_dims = len(self.shape) - 1
|
||||||
|
|
||||||
def multiplier(self):
|
def multiplier(self):
|
||||||
if 'transformer' in self.sd_key[:20]:
|
if 'transformer' in self.sd_key[:20]:
|
||||||
@ -170,10 +170,22 @@ class NetworkModule:
|
|||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def apply_weight_decompose(self, updown, orig_weight):
|
def apply_weight_decompose(self, updown, orig_weight):
|
||||||
orig_weight = orig_weight.to(updown)
|
# Match the device/dtype
|
||||||
|
orig_weight = orig_weight.to(updown.dtype)
|
||||||
|
dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
|
||||||
|
updown = updown.to(orig_weight.device)
|
||||||
|
|
||||||
merged_scale1 = updown + orig_weight
|
merged_scale1 = updown + orig_weight
|
||||||
|
merged_scale1_norm = (
|
||||||
|
merged_scale1.transpose(0, 1)
|
||||||
|
.reshape(merged_scale1.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
dora_merged = (
|
dora_merged = (
|
||||||
merged_scale1 / merged_scale1(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
|
merged_scale1 * (dora_scale / merged_scale1_norm)
|
||||||
)
|
)
|
||||||
final_updown = dora_merged - orig_weight
|
final_updown = dora_merged - orig_weight
|
||||||
return final_updown
|
return final_updown
|
||||||
|
Loading…
Reference in New Issue
Block a user