Merge pull request #15160 from AUTOMATIC1111/dora-weight-decompose

Add DoRA (weight-decompose) support for LoRA/LoHa/LoKr
This commit is contained in:
AUTOMATIC1111 2024-03-08 08:02:17 +03:00 committed by GitHub
commit 9409419afb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -146,6 +146,9 @@ class NetworkModule:
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
self.scale = weights.w["scale"].item() if "scale" in weights.w else None self.scale = weights.w["scale"].item() if "scale" in weights.w else None
self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
def multiplier(self): def multiplier(self):
if 'transformer' in self.sd_key[:20]: if 'transformer' in self.sd_key[:20]:
return self.network.te_multiplier return self.network.te_multiplier
@ -160,6 +163,15 @@ class NetworkModule:
return 1.0 return 1.0
def apply_weight_decompose(self, updown, orig_weight):
orig_weight = orig_weight.to(updown)
merged_scale1 = updown + orig_weight
dora_merged = (
merged_scale1 / merged_scale1(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
)
final_updown = dora_merged - orig_weight
return final_updown
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None: if self.bias is not None:
updown = updown.reshape(self.bias.shape) updown = updown.reshape(self.bias.shape)
@ -175,6 +187,9 @@ class NetworkModule:
if ex_bias is not None: if ex_bias is not None:
ex_bias = ex_bias * self.multiplier() ex_bias = ex_bias * self.multiplier()
if self.dora_scale is not None:
updown = self.apply_weight_decompose(updown, orig_weight)
return updown * self.calc_scale() * self.multiplier(), ex_bias return updown * self.calc_scale() * self.multiplier(), ex_bias
def calc_updown(self, target): def calc_updown(self, target):