mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
add --medvram-sdxl
This commit is contained in:
parent
bb7dd7b646
commit
016554e437
@ -35,6 +35,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
|
|||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
|
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||||
|
@ -186,7 +186,6 @@ class InterrogateModels:
|
|||||||
res = ""
|
res = ""
|
||||||
shared.state.begin(job="interrogate")
|
shared.state.begin(job="interrogate")
|
||||||
try:
|
try:
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from modules import devices
|
from modules import devices, shared
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
@ -14,6 +14,20 @@ def send_everything_to_cpu():
|
|||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_needed(sd_model):
|
||||||
|
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
||||||
|
|
||||||
|
|
||||||
|
def apply(sd_model):
|
||||||
|
enable = is_needed(sd_model)
|
||||||
|
shared.parallel_processing_allowed = not enable
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
||||||
|
else:
|
||||||
|
sd_model.lowvram = False
|
||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
if getattr(sd_model, 'lowvram', False):
|
if getattr(sd_model, 'lowvram', False):
|
||||||
return
|
return
|
||||||
@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
|
|
||||||
|
|
||||||
def is_enabled(sd_model):
|
def is_enabled(sd_model):
|
||||||
return getattr(sd_model, 'lowvram', False)
|
return sd_model.lowvram
|
||||||
|
@ -517,7 +517,7 @@ def get_empty_cond(sd_model):
|
|||||||
|
|
||||||
|
|
||||||
def send_model_to_cpu(m):
|
def send_model_to_cpu(m):
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if m.lowvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
m.to(devices.cpu)
|
m.to(devices.cpu)
|
||||||
@ -525,17 +525,17 @@ def send_model_to_cpu(m):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def model_target_device():
|
def model_target_device(m):
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if lowvram.is_needed(m):
|
||||||
return devices.cpu
|
return devices.cpu
|
||||||
else:
|
else:
|
||||||
return devices.device
|
return devices.device
|
||||||
|
|
||||||
|
|
||||||
def send_model_to_device(m):
|
def send_model_to_device(m):
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
lowvram.apply(m)
|
||||||
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
|
||||||
else:
|
if not m.lowvram:
|
||||||
m.to(shared.device)
|
m.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
'': torch.float16,
|
'': torch.float16,
|
||||||
}
|
}
|
||||||
|
|
||||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
timer.record("load weights from state dict")
|
timer.record("load weights from state dict")
|
||||||
|
|
||||||
@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
timer.record("script callbacks")
|
timer.record("script callbacks")
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ def apply_unet(option=None):
|
|||||||
if current_unet_option is None:
|
if current_unet_option is None:
|
||||||
current_unet = None
|
current_unet = None
|
||||||
|
|
||||||
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
if not shared.sd_model.lowvram:
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -263,7 +263,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
if loaded_vae_file == vae_file:
|
if loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if sd_model.lowvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
sd_model.to(devices.cpu)
|
sd_model.to(devices.cpu)
|
||||||
@ -275,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
print("VAE weights loaded.")
|
print("VAE weights loaded.")
|
||||||
|
@ -11,7 +11,7 @@ cmd_opts = shared_cmd_options.cmd_opts
|
|||||||
parser = shared_cmd_options.parser
|
parser = shared_cmd_options.parser
|
||||||
|
|
||||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
parallel_processing_allowed = True
|
||||||
styles_filename = cmd_opts.styles_file
|
styles_filename = cmd_opts.styles_file
|
||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
Loading…
Reference in New Issue
Block a user