diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 66f262f4d..948fa6740 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -391,7 +391,8 @@ def restore_weights_backup(obj, field, weight): setattr(obj, field, None) return - getattr(obj, field).copy_(weight) + old_weight = getattr(obj, field) + old_weight.copy_(weight.to(dtype=old_weight.dtype)) def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention], cleanup=False):