Update previous value for the next iteration only if the value is not NaN

This commit is contained in:
Wernervanrun 2024-08-09 13:31:51 +02:00
parent 254faa1b91
commit 62c6142250

View File

@ -72,7 +72,9 @@ def extract_scalar_data(log_dir):
else: else:
scalar_data[tag][event.step].append(value) scalar_data[tag][event.step].append(value)
previous_value = value # Update previous value for the next iteration # Update previous value for the next iteration only if the value is not NaN
if not np.isnan(value):
previous_value = value
# Calculate the average for each step. Restarting training can cause multiple events for the same step. # Calculate the average for each step. Restarting training can cause multiple events for the same step.
scalar_data[tag] = {step: sum(values) / len(values) for step, values in scalar_data[tag].items()} scalar_data[tag] = {step: sum(values) / len(values) for step, values in scalar_data[tag].items()}