Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git (synced 2025-01-01 20:45:04 +08:00)

Commit 7e4992eb22 ("Add files via upload"), parent 1e2648804c
@@ -23,16 +23,14 @@ try:
 
     if torch.xpu.is_available():
         from infer.modules.ipex import ipex_init
-
-        ipex_init()
-
-        from torch.xpu.amp import autocast
         from infer.modules.ipex.gradscaler import gradscaler_init
+        from torch.xpu.amp import autocast
 
         GradScaler = gradscaler_init()
+        ipex_init()
     else:
         from torch.cuda.amp import GradScaler, autocast
-except Exception:  # pylint: disable=broad-exception-caught
+except Exception:
     from torch.cuda.amp import GradScaler, autocast
 
 torch.backends.cudnn.deterministic = False
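The hunk above tidies the Intel-GPU (IPEX) setup: the imports are grouped first, GradScaler is produced by gradscaler_init(), ipex_init() now runs only after those imports succeed, and the pylint pragma on the except line is dropped. Whichever branch runs, the rest of the script sees the same two names, GradScaler and autocast. A hypothetical usage sketch, not part of this commit, with a toy model, optimizer, and data as stand-ins and a CUDA device assumed:

import torch
from torch.cuda.amp import GradScaler, autocast  # stand-in for whichever branch ran above

# Toy model/optimizer/data, purely illustrative.
model = torch.nn.Linear(4, 1).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(8, 4, device="cuda")
y = torch.randn(8, 1, device="cuda")

optim.zero_grad()
with autocast():  # forward pass runs in reduced precision where safe
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
scaler.step(optim)             # unscales gradients, skips the step on inf/nan
scaler.update()                # adapts the scale factor for the next iteration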
@@ -106,11 +104,10 @@ def main():
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
-    logger = utils.get_logger(hps.model_dir)
     for i in range(n_gpus):
         subproc = mp.Process(
             target=run,
-            args=(i, n_gpus, hps, logger),
+            args=(i, n_gpus, hps),
         )
         children.append(subproc)
         subproc.start()
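With logger = utils.get_logger(hps.model_dir) removed from main(), the spawned children no longer receive a logger object through mp.Process; each worker now gets only its rank, the GPU count, and the hyperparameters, and builds anything process-specific itself. A minimal, self-contained sketch of that spawn pattern, assuming mp is torch.multiprocessing as in the script, with worker and the config dict as made-up stand-ins:

import os
from random import randint

import torch.multiprocessing as mp


def worker(rank, n_gpus, config):
    # Per-process objects (loggers, writers, CUDA state) are created here,
    # inside the child, instead of being pickled through mp.Process args.
    print(f"rank {rank}/{n_gpus} started with {config}")


if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    n_gpus = 2
    children = []
    for i in range(n_gpus):
        subproc = mp.Process(target=worker, args=(i, n_gpus, {"lr": 1e-4}))
        children.append(subproc)
        subproc.start()
    for subproc in children:
        subproc.join()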
@@ -119,10 +116,10 @@ def main():
         children[i].join()
 
 
-def run(rank, n_gpus, hps, logger: logging.Logger):
+def run(rank, n_gpus, hps,):
     global global_step
     if rank == 0:
-        # logger = utils.get_logger(hps.model_dir)
+        logger = utils.get_logger(hps.model_dir)
         logger.info(hps)
         # utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)
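Correspondingly, run() no longer accepts a logger parameter; the rank-0 process creates its own logger and TensorBoard writer. A hedged sketch of that rank-guarded setup, where get_logger and model_dir are simplified stand-ins for utils.get_logger and hps.model_dir:

import logging
import os

from torch.utils.tensorboard import SummaryWriter


def get_logger(model_dir: str) -> logging.Logger:
    # Simplified stand-in for utils.get_logger: one file handler per run directory.
    os.makedirs(model_dir, exist_ok=True)
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.FileHandler(os.path.join(model_dir, "train.log")))
    return logger


def run(rank: int, n_gpus: int, model_dir: str) -> None:
    if rank == 0:
        # Only rank 0 owns the logger and the TensorBoard writer;
        # the other ranks never need either object.
        logger = get_logger(model_dir)
        logger.info("training on %d GPU(s), logs in %s", n_gpus, model_dir)
        writer = SummaryWriter(log_dir=model_dir)
        writer.add_scalar("meta/n_gpus", n_gpus, 0)


if __name__ == "__main__":
    run(rank=0, n_gpus=1, model_dir="logs/example")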