From 310d0e6938e74fd461a4bf5a33d9a5b4d977b29d Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 4 Oct 2024 00:08:44 +0900 Subject: [PATCH] restore org_dtype != compute dtype case --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 66f262f4d..948fa6740 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight): setattr(obj, field, None) return - getattr(obj, field).copy_(weight) + old_weight = getattr(obj, field) + old_weight.copy_(weight.to(dtype=old_weight.dtype)) def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):