Support copy option to reduce RAM usage

This commit is contained in:
Won-Kyu Park 2024-09-07 12:23:03 +09:00
parent 2060886450
commit 2f72fd89ff
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@ -1,5 +1,6 @@
import sys import sys
import contextlib import contextlib
from copy import deepcopy
from functools import lru_cache from functools import lru_cache
import torch import torch
@ -169,7 +170,7 @@ patch_module_list = [
] ]
def manual_cast_forward(target_dtype, target_device=None): def manual_cast_forward(target_dtype, target_device=None, copy=False):
params = dict() params = dict()
if supports_non_blocking(): if supports_non_blocking():
params['non_blocking'] = True params['non_blocking'] = True
@ -193,8 +194,17 @@ def manual_cast_forward(target_dtype, target_device=None):
org_dtype = param.dtype org_dtype = param.dtype
break break
if copy:
copied = deepcopy(self)
if org_dtype != target_dtype:
copied.to(**params)
result = copied.org_forward(*args, **kwargs)
del copied
else:
if org_dtype != target_dtype: if org_dtype != target_dtype:
self.to(**params) self.to(**params)
result = self.org_forward(*args, **kwargs) result = self.org_forward(*args, **kwargs)
if org_dtype != target_dtype: if org_dtype != target_dtype:
@ -220,15 +230,17 @@ def manual_cast_forward(target_dtype, target_device=None):
def manual_cast(target_dtype, target_device=None): def manual_cast(target_dtype, target_device=None):
applied = False applied = False
copy = shared.opts.lora_without_backup_weight
for module_type in patch_module_list: for module_type in patch_module_list:
if hasattr(module_type, "org_forward"): if hasattr(module_type, "org_forward"):
continue continue
applied = True applied = True
org_forward = module_type.forward org_forward = module_type.forward
if module_type == torch.nn.MultiheadAttention: if module_type == torch.nn.MultiheadAttention:
module_type.forward = manual_cast_forward(torch.float32, target_device) module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
else: else:
module_type.forward = manual_cast_forward(target_dtype, target_device) module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
module_type.org_forward = org_forward module_type.org_forward = org_forward
try: try:
yield None yield None