support copy option to reduce RAM usage

commit 2f72fd89ff
parent 2060886450
Author: Won-Kyu Park
Date:   2024-09-07 12:23:03 +09:00


@@ -1,5 +1,6 @@
 import sys
 import contextlib
+from copy import deepcopy
 from functools import lru_cache
 import torch
@@ -169,7 +170,7 @@ patch_module_list = [
 ]
-def manual_cast_forward(target_dtype, target_device=None):
+def manual_cast_forward(target_dtype, target_device=None, copy=False):
     params = dict()
     if supports_non_blocking():
         params['non_blocking'] = True
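
The params dict built here is later splatted into self.to(**params). Setting non_blocking=True lets host-to-device copies overlap with computation when the source memory is pinned. A minimal sketch of the same keyword-splat pattern, outside the webui codebase (the dtype and device values below are illustrative assumptions, not taken from the diff):

import torch

params = dict(dtype=torch.float16)
if torch.cuda.is_available():
    # non_blocking only helps when the source is in pinned (page-locked)
    # host memory; otherwise the copy silently falls back to blocking.
    params.update(device='cuda', non_blocking=True)

m = torch.nn.Linear(8, 8)
m.to(**params)  # same splat pattern as self.to(**params) in the diff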
@@ -193,8 +194,17 @@ def manual_cast_forward(target_dtype, target_device=None):
                 org_dtype = param.dtype
                 break
-        if org_dtype != target_dtype:
-            self.to(**params)
-        result = self.org_forward(*args, **kwargs)
+        if copy:
+            copied = deepcopy(self)
+            if org_dtype != target_dtype:
+                copied.to(**params)
+            result = copied.org_forward(*args, **kwargs)
+            del copied
+        else:
+            if org_dtype != target_dtype:
+                self.to(**params)
+            result = self.org_forward(*args, **kwargs)
         if org_dtype != target_dtype:
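
The new copy branch above never casts the live module's parameters in place, so nothing has to be backed up or restored around the forward call. A hedged, self-contained sketch of that idea, with a bare nn.Linear standing in for a patched module and its plain forward standing in for org_forward:

from copy import deepcopy
import torch

def cast_and_forward(module, x, target_dtype, copy=False):
    org_dtype = next(module.parameters()).dtype
    if copy:
        # Cast a throwaway deep copy; the live module's weights are
        # never touched, so no backup/restore pass is needed.
        copied = deepcopy(module).to(target_dtype)
        result = copied(x.to(target_dtype))
        del copied  # release the temporary copy right away
    else:
        # The pre-existing path: cast in place, run, cast back.
        module.to(target_dtype)
        result = module(x.to(target_dtype))
        module.to(org_dtype)
    return result

lin = torch.nn.Linear(4, 4).to(torch.float16)
y = cast_and_forward(lin, torch.randn(2, 4), torch.float32, copy=True)
assert next(lin.parameters()).dtype == torch.float16  # original untouched

The deepcopy briefly doubles that one module's footprint during its forward call; in exchange, no long-lived backup of the original weights has to be kept, which is presumably the RAM saving the commit title refers to.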
@@ -220,15 +230,17 @@ def manual_cast_forward(target_dtype, target_device=None):
 def manual_cast(target_dtype, target_device=None):
     applied = False
+    copy = shared.opts.lora_without_backup_weight
     for module_type in patch_module_list:
         if hasattr(module_type, "org_forward"):
             continue
         applied = True
         org_forward = module_type.forward
         if module_type == torch.nn.MultiheadAttention:
-            module_type.forward = manual_cast_forward(torch.float32, target_device)
+            module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
         else:
-            module_type.forward = manual_cast_forward(target_dtype, target_device)
+            module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
         module_type.org_forward = org_forward
     try:
         yield None
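
For context, manual_cast appears to be used as a context manager (note the yield None) that monkey-patches the forward of every class in patch_module_list and restores it afterwards. A simplified sketch of that patching pattern (the names mirror the diff, but the wrapper body and restore logic are condensed assumptions, not the full implementation; unlike the real code, this sketch leaves the module cast to target_dtype):

import contextlib
import torch

patch_module_list = [torch.nn.Linear, torch.nn.Conv2d]  # abridged stand-in

def casting_forward(target_dtype):
    def forward(self, *args, **kwargs):
        # Cast tensor arguments and the module, then call the saved
        # original forward through the class-level org_forward attribute.
        args = [a.to(target_dtype) if torch.is_tensor(a) else a for a in args]
        self.to(target_dtype)
        return self.org_forward(*args, **kwargs)
    return forward

@contextlib.contextmanager
def manual_cast(target_dtype):
    for module_type in patch_module_list:
        module_type.org_forward = module_type.forward
        module_type.forward = casting_forward(target_dtype)
    try:
        yield
    finally:
        # Undo the class-level patch so modules behave normally again.
        for module_type in patch_module_list:
            module_type.forward = module_type.org_forward
            del module_type.org_forward

lin = torch.nn.Linear(4, 4).to(torch.float16)
with manual_cast(torch.float32):
    out = lin(torch.randn(2, 4))  # fp16 module transparently run in fp32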