Merge pull request #1410 from RVC-Project/dev

chore(sync): merge dev into main
This commit is contained in:
RVC-Boss 2023-10-10 17:23:18 +08:00 committed by GitHub
commit 5e22271924
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,16 +23,14 @@ try:
if torch.xpu.is_available(): if torch.xpu.is_available():
from infer.modules.ipex import ipex_init from infer.modules.ipex import ipex_init
ipex_init()
from torch.xpu.amp import autocast
from infer.modules.ipex.gradscaler import gradscaler_init from infer.modules.ipex.gradscaler import gradscaler_init
from torch.xpu.amp import autocast
GradScaler = gradscaler_init() GradScaler = gradscaler_init()
ipex_init()
else: else:
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
except Exception: # pylint: disable=broad-exception-caught except Exception:
from torch.cuda.amp import GradScaler, autocast from torch.cuda.amp import GradScaler, autocast
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
@ -106,11 +104,10 @@ def main():
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555)) os.environ["MASTER_PORT"] = str(randint(20000, 55555))
children = [] children = []
logger = utils.get_logger(hps.model_dir)
for i in range(n_gpus): for i in range(n_gpus):
subproc = mp.Process( subproc = mp.Process(
target=run, target=run,
args=(i, n_gpus, hps, logger), args=(i, n_gpus, hps),
) )
children.append(subproc) children.append(subproc)
subproc.start() subproc.start()
@ -119,10 +116,14 @@ def main():
children[i].join() children[i].join()
def run(rank, n_gpus, hps, logger: logging.Logger): def run(
rank,
n_gpus,
hps,
):
global global_step global global_step
if rank == 0: if rank == 0:
# logger = utils.get_logger(hps.model_dir) logger = utils.get_logger(hps.model_dir)
logger.info(hps) logger.info(hps)
# utils.check_git_hash(hps.model_dir) # utils.check_git_hash(hps.model_dir)
writer = SummaryWriter(log_dir=hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir)