mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-03-09 23:44:55 +08:00
Use less RAM when creating models
This commit is contained in:
parent
f451994053
commit
0a89cd1a58
@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
|
|||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||||
|
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||||
|
@ -3,8 +3,31 @@ import open_clip
|
|||||||
import torch
|
import torch
|
||||||
import transformers.utils.hub
|
import transformers.utils.hub
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
class DisableInitialization:
|
|
||||||
|
class ReplaceHelper:
|
||||||
|
def __init__(self):
|
||||||
|
self.replaced = []
|
||||||
|
|
||||||
|
def replace(self, obj, field, func):
|
||||||
|
original = getattr(obj, field, None)
|
||||||
|
if original is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.replaced.append((obj, field, original))
|
||||||
|
setattr(obj, field, func)
|
||||||
|
|
||||||
|
return original
|
||||||
|
|
||||||
|
def restore(self):
|
||||||
|
for obj, field, original in self.replaced:
|
||||||
|
setattr(obj, field, original)
|
||||||
|
|
||||||
|
self.replaced.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class DisableInitialization(ReplaceHelper):
|
||||||
"""
|
"""
|
||||||
When an object of this class enters a `with` block, it starts:
|
When an object of this class enters a `with` block, it starts:
|
||||||
- preventing torch's layer initialization functions from working
|
- preventing torch's layer initialization functions from working
|
||||||
@ -21,7 +44,7 @@ class DisableInitialization:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, disable_clip=True):
|
def __init__(self, disable_clip=True):
|
||||||
self.replaced = []
|
super().__init__()
|
||||||
self.disable_clip = disable_clip
|
self.disable_clip = disable_clip
|
||||||
|
|
||||||
def replace(self, obj, field, func):
|
def replace(self, obj, field, func):
|
||||||
@ -86,8 +109,81 @@ class DisableInitialization:
|
|||||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
for obj, field, original in self.replaced:
|
self.restore()
|
||||||
setattr(obj, field, original)
|
|
||||||
|
|
||||||
self.replaced.clear()
|
|
||||||
|
|
||||||
|
class InitializeOnMeta(ReplaceHelper):
|
||||||
|
"""
|
||||||
|
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||||
|
which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||||
|
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_device(x):
|
||||||
|
x["device"] = "meta"
|
||||||
|
return x
|
||||||
|
|
||||||
|
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||||
|
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||||
|
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||||
|
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.restore()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStateDictOnMeta(ReplaceHelper):
|
||||||
|
"""
|
||||||
|
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||||
|
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||||
|
Meant to be used together with InitializeOnMeta above.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_dict, device):
|
||||||
|
super().__init__()
|
||||||
|
self.state_dict = state_dict
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
|
return
|
||||||
|
|
||||||
|
sd = self.state_dict
|
||||||
|
device = self.device
|
||||||
|
|
||||||
|
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
|
||||||
|
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
|
||||||
|
|
||||||
|
for name, param in params:
|
||||||
|
if param.is_meta:
|
||||||
|
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
|
||||||
|
|
||||||
|
original(self, state_dict, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
for name, _ in params:
|
||||||
|
key = prefix + name
|
||||||
|
if key in sd:
|
||||||
|
del sd[key]
|
||||||
|
|
||||||
|
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||||
|
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||||
|
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.restore()
|
||||||
|
@ -460,7 +460,6 @@ def get_empty_cond(sd_model):
|
|||||||
return sd_model.cond_stage_model([""])
|
return sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
sd_model = None
|
sd_model = None
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
except Exception:
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
pass
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "creating model quickly", full_traceback=True)
|
||||||
|
|
||||||
if sd_model is None:
|
if sd_model is None:
|
||||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
|
||||||
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
sd_model.used_config = checkpoint_config
|
sd_model.used_config = checkpoint_config
|
||||||
|
|
||||||
timer.record("create model")
|
timer.record("create model")
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
||||||
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||||
|
4
webui.py
4
webui.py
@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
if modules.sd_hijack.current_optimizer is None:
|
if modules.sd_hijack.current_optimizer is None:
|
||||||
modules.sd_hijack.apply_optimizations()
|
modules.sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
Thread(target=load_model).start()
|
devices.first_time_calculation()
|
||||||
|
|
||||||
Thread(target=devices.first_time_calculation).start()
|
Thread(target=load_model).start()
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
startup_timer.record("reload hypernetworks")
|
startup_timer.record("reload hypernetworks")
|
||||||
|
Loading…
Reference in New Issue
Block a user