diff --git a/train_nsf_sim_cache_sid_load_pretrain.py b/train_nsf_sim_cache_sid_load_pretrain.py index 3a94e82..d2d9c35 100644 --- a/train_nsf_sim_cache_sid_load_pretrain.py +++ b/train_nsf_sim_cache_sid_load_pretrain.py @@ -501,13 +501,12 @@ def train_and_evaluate( if rank == 0: logger.info("====> Epoch: {}".format(epoch)) - if(epoch>=hps.total_epoch): - if rank == 0: - logger.info("Training is done. The program is closed.") - from process_ckpt import savee#def savee(ckpt,sr,if_f0,name,epoch): - if hasattr(net_g, 'module'):ckpt = net_g.module.state_dict() - else:ckpt = net_g.state_dict() - print("saving final ckpt:",savee(ckpt,hps.sample_rate,hps.if_f0,hps.name,epoch)) + if(epoch>=hps.total_epoch and rank == 0): + logger.info("Training is done. The program is closed.") + from process_ckpt import savee#def savee(ckpt,sr,if_f0,name,epoch): + if hasattr(net_g, 'module'):ckpt = net_g.module.state_dict() + else:ckpt = net_g.state_dict() + logger.info("saving final ckpt:%s"%(savee(ckpt,hps.sample_rate,hps.if_f0,hps.name,epoch))) os._exit(2333333)