Update train.py

This commit is contained in:
RVC-Boss 2024-01-26 16:03:00 +08:00 committed by GitHub
parent b304564c9e
commit 9602ea649c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -104,10 +104,11 @@ def main():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
children = []
logger = utils.get_logger(hps.model_dir)
for i in range(n_gpus):
subproc = mp.Process(
target=run,
args=(i, n_gpus, hps),
args=(i, n_gpus, hps, logger),
)
children.append(subproc)
subproc.start()
@ -116,14 +117,10 @@ def main():
children[i].join()
def run(
rank,
n_gpus,
hps,
):
def run(rank, n_gpus, hps, logger: logging.Logger):
global global_step
if rank == 0:
logger = utils.get_logger(hps.model_dir)
# logger = utils.get_logger(hps.model_dir)
logger.info(hps)
# utils.check_git_hash(hps.model_dir)
writer = SummaryWriter(log_dir=hps.model_dir)
@ -229,13 +226,13 @@ def run(
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainG))
if hasattr(net_g, "module"):
print(
logger.info(
net_g.module.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
) ##测试不加载优化器
else:
print(
logger.info(
net_g.load_state_dict(
torch.load(hps.pretrainG, map_location="cpu")["model"]
)
@ -244,13 +241,13 @@ def run(
if rank == 0:
logger.info("loaded pretrained %s" % (hps.pretrainD))
if hasattr(net_d, "module"):
print(
logger.info(
net_d.module.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)
)
else:
print(
logger.info(
net_d.load_state_dict(
torch.load(hps.pretrainD, map_location="cpu")["model"]
)