From 0b0bd911d9536281f509818eb3069c80c489c195 Mon Sep 17 00:00:00 2001 From: Xerxes-2 Date: Wed, 17 May 2023 01:54:35 +1000 Subject: [PATCH] Add timestamp and elapsed time for epoch (#273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add timestamp and epoch elapsed time * don't need a class * Revert "add timestamp and epoch elapsed time" This reverts commit 93b8d4a7afd9d525069d9065bab5aebc97063a1d. * adjust class def * delete duplicate import --------- Co-authored-by: Ftps <63702646+Tps-F@users.noreply.github.com> Co-authored-by: 源文雨 <41315874+fumiama@users.noreply.github.com> --- train_nsf_sim_cache_sid_load_pretrain.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/train_nsf_sim_cache_sid_load_pretrain.py b/train_nsf_sim_cache_sid_load_pretrain.py index 481f754..f8f40ee 100644 --- a/train_nsf_sim_cache_sid_load_pretrain.py +++ b/train_nsf_sim_cache_sid_load_pretrain.py @@ -4,6 +4,7 @@ now_dir = os.getcwd() sys.path.append(os.path.join(now_dir)) sys.path.append(os.path.join(now_dir, "train")) import utils +import datetime hps = utils.get_hparams() os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",") @@ -50,6 +51,18 @@ from process_ckpt import savee global_step = 0 +class EpochRecorder: + def __init__(self): + self.last_time = ttime() + + + def record(self): + now_time = ttime() + elapsed_time = now_time - self.last_time + self.last_time = now_time + elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time)) + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + return f"[{current_time}] | ({elapsed_time_str})" def main(): # n_gpus = torch.cuda.device_count() @@ -323,6 +336,7 @@ def train_and_evaluate( data_iterator = enumerate(train_loader) # Run steps + epoch_recorder = EpochRecorder() for batch_idx, info in data_iterator: # Data ## Unpack @@ -542,7 +556,7 @@ def train_and_evaluate( ) if rank == 0: - logger.info("====> Epoch: {}".format(epoch)) + logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record())) if epoch >= hps.total_epoch and rank == 0: logger.info("Training is done. The program is closed.")