mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-03-04 21:14:54 +08:00
support copy option to reduce ram usage
This commit is contained in:
parent
2060886450
commit
2f72fd89ff
@ -1,5 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from copy import deepcopy
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -169,7 +170,7 @@ patch_module_list = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def manual_cast_forward(target_dtype, target_device=None):
|
def manual_cast_forward(target_dtype, target_device=None, copy=False):
|
||||||
params = dict()
|
params = dict()
|
||||||
if supports_non_blocking():
|
if supports_non_blocking():
|
||||||
params['non_blocking'] = True
|
params['non_blocking'] = True
|
||||||
@ -193,8 +194,17 @@ def manual_cast_forward(target_dtype, target_device=None):
|
|||||||
org_dtype = param.dtype
|
org_dtype = param.dtype
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if copy:
|
||||||
|
copied = deepcopy(self)
|
||||||
|
if org_dtype != target_dtype:
|
||||||
|
copied.to(**params)
|
||||||
|
|
||||||
|
result = copied.org_forward(*args, **kwargs)
|
||||||
|
del copied
|
||||||
|
else:
|
||||||
if org_dtype != target_dtype:
|
if org_dtype != target_dtype:
|
||||||
self.to(**params)
|
self.to(**params)
|
||||||
|
|
||||||
result = self.org_forward(*args, **kwargs)
|
result = self.org_forward(*args, **kwargs)
|
||||||
|
|
||||||
if org_dtype != target_dtype:
|
if org_dtype != target_dtype:
|
||||||
@ -220,15 +230,17 @@ def manual_cast_forward(target_dtype, target_device=None):
|
|||||||
def manual_cast(target_dtype, target_device=None):
|
def manual_cast(target_dtype, target_device=None):
|
||||||
applied = False
|
applied = False
|
||||||
|
|
||||||
|
copy = shared.opts.lora_without_backup_weight
|
||||||
|
|
||||||
for module_type in patch_module_list:
|
for module_type in patch_module_list:
|
||||||
if hasattr(module_type, "org_forward"):
|
if hasattr(module_type, "org_forward"):
|
||||||
continue
|
continue
|
||||||
applied = True
|
applied = True
|
||||||
org_forward = module_type.forward
|
org_forward = module_type.forward
|
||||||
if module_type == torch.nn.MultiheadAttention:
|
if module_type == torch.nn.MultiheadAttention:
|
||||||
module_type.forward = manual_cast_forward(torch.float32, target_device)
|
module_type.forward = manual_cast_forward(torch.float32, target_device, copy)
|
||||||
else:
|
else:
|
||||||
module_type.forward = manual_cast_forward(target_dtype, target_device)
|
module_type.forward = manual_cast_forward(target_dtype, target_device, copy)
|
||||||
module_type.org_forward = org_forward
|
module_type.org_forward = org_forward
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
|
Loading…
Reference in New Issue
Block a user