From 12bcacf41393ff9368836514af641d835b8a3b02 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Thu, 7 Mar 2024 13:29:40 +0800
Subject: [PATCH] Initial implementation

---
 extensions-builtin/Lora/network.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index b8fd91941..2268b0f7e 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -146,6 +146,9 @@ class NetworkModule:
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
 
+        self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
+        self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
+
     def multiplier(self):
         if 'transformer' in self.sd_key[:20]:
             return self.network.te_multiplier
@@ -160,6 +163,15 @@ class NetworkModule:
 
         return 1.0
 
+    def apply_weight_decompose(self, updown, orig_weight):
+        orig_weight = orig_weight.to(updown)
+        merged_scale1 = updown + orig_weight
+        dora_merged = (
+            merged_scale1 / merged_scale1.mean(dim=self.dora_mean_dim, keepdim=True) * self.dora_scale
+        )
+        final_updown = dora_merged - orig_weight
+        return final_updown
+
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
@@ -175,6 +187,9 @@ class NetworkModule:
         if ex_bias is not None:
             ex_bias = ex_bias * self.multiplier()
 
+        if self.dora_scale is not None:
+            updown = self.apply_weight_decompose(updown, orig_weight)
+
         return updown * self.calc_scale() * self.multiplier(), ex_bias
 
     def calc_updown(self, target):
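
Note (not part of the patch): the snippet below is a minimal standalone sketch of the weight-decomposition step that apply_weight_decompose() introduces, written against plain PyTorch. The function name apply_dora, the mean_dim argument, and the example shapes are illustrative assumptions; only updown, orig_weight, and dora_scale mirror names from the diff.

# Illustrative sketch only -- not part of the patch.
import torch

def apply_dora(updown, orig_weight, dora_scale, mean_dim):
    # Merge the base weight with the LoRA delta, rescale the merged
    # weight by the learned DoRA magnitude, then return the adjusted
    # delta so callers can keep applying orig_weight + updown.
    orig_weight = orig_weight.to(updown)             # match dtype/device
    merged = updown + orig_weight
    denom = merged.mean(dim=mean_dim, keepdim=True)  # statistic the patch divides by
    dora_merged = merged / denom * dora_scale
    return dora_merged - orig_weight

# Example: a Linear weight of shape (out_features=4, in_features=8).
# dora_mean_dim in the patch is every dim except dim 1, i.e. (0,) here.
w0 = torch.randn(4, 8)
delta = 0.01 * torch.randn(4, 8)
scale = torch.ones(1, 8)         # one magnitude per input column
new_delta = apply_dora(delta, w0, scale, mean_dim=(0,))
print(new_delta.shape)           # torch.Size([4, 8])

For reference, the DoRA paper normalizes the merged weight by its per-column L2 norm rather than a mean, so a stricter variant of the sketch would replace merged.mean(...) with merged.norm(dim=mean_dim, keepdim=True).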