remove parsing command line from devices.py

This commit is contained in:
AUTOMATIC 2022-10-22 14:04:14 +03:00
parent e80bdcab91
commit 50b5504401
2 changed files with 9 additions and 14 deletions

View File

@@ -15,14 +15,10 @@ def extract_device_id(args, name):
def get_optimal_device(): def get_optimal_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
# CUDA device selection support: from modules import shared
if "shared" not in sys.modules:
commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. device_id = shared.cmd_opts.device_id
sys.argv += shlex.split(commandline_args)
device_id = extract_device_id(sys.argv, '--device-id')
else:
device_id = shared.cmd_opts.device_id
if device_id is not None: if device_id is not None:
cuda_device = f"cuda:{device_id}" cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device) return torch.device(cuda_device)
@@ -49,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16

View File

@@ -1,9 +1,8 @@
import torch import torch
from modules.devices import get_optimal_device from modules import devices
module_in_gpu = None module_in_gpu = None
cpu = torch.device("cpu") cpu = torch.device("cpu")
device = gpu = get_optimal_device()
def send_everything_to_cpu(): def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None: if module_in_gpu is not None:
module_in_gpu.to(cpu) module_in_gpu.to(cpu)
module.to(gpu) module.to(devices.device)
module_in_gpu = module module_in_gpu = module
# see below for register_forward_pre_hook; # see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device) sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models # register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time # so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device) sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model # install hooks for bits of third model