fix: 多显卡训练时logger未赋值引用 (#1722)

modified:   infer/modules/train/train.py
This commit is contained in:
Chengjia Jiang 2024-01-16 19:30:10 +08:00 committed by GitHub
parent f6fa0c9cd9
commit 49434901d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 2 deletions

5
.gitignore vendored
View File

@ -21,3 +21,8 @@ rmvpe.pt
# To set a Python version for the project # To set a Python version for the project
.tool-versions .tool-versions
/runtime
/assets/weights/*
ffmpeg.*
ffprobe.*

View File

@ -104,10 +104,11 @@ def main():
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555)) os.environ["MASTER_PORT"] = str(randint(20000, 55555))
children = [] children = []
logger = utils.get_logger(hps.model_dir)
for i in range(n_gpus): for i in range(n_gpus):
subproc = mp.Process( subproc = mp.Process(
target=run, target=run,
args=(i, n_gpus, hps), args=(i, n_gpus, hps, logger),
) )
children.append(subproc) children.append(subproc)
subproc.start() subproc.start()
@ -120,10 +121,11 @@ def run(
rank, rank,
n_gpus, n_gpus,
hps, hps,
logger: logging.Logger
): ):
global global_step global global_step
if rank == 0: if rank == 0:
logger = utils.get_logger(hps.model_dir) # logger = utils.get_logger(hps.model_dir)
logger.info(hps) logger.info(hps)
# utils.check_git_hash(hps.model_dir) # utils.check_git_hash(hps.model_dir)
writer = SummaryWriter(log_dir=hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir)