import logging
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np

# Suppress TensorBoard event processing logs. The logger name must be exactly
# 'tensorboard' for the level change to apply to that logger.
logging.getLogger('tensorboard').setLevel(logging.ERROR)

# Import TensorBoard modules after setting logging level
from tensorboard.backend.event_processing import event_accumulator  # noqa: E402
def generate_loss_graphs(log_dir):
    """
    Build and save loss plots from the TensorBoard logs in a directory.

    The scalar data is read with ``extract_scalar_data`` and passed straight
    to ``plot_scalar_data``, which receives the same directory as its base
    output location.

    Args:
        log_dir (str): Directory with TensorBoard logs.
    """
    plot_scalar_data(extract_scalar_data(log_dir), log_dir)
def smooth(scalars: List[float], weight: float) -> List[float]:
    """
    Smooths the given list of scalars using exponential smoothing.

    Each output is ``previous_output * weight + (1 - weight) * point``. The
    running value is seeded with the first scalar, so the first output is
    (up to floating-point rounding) the first input.

    Args:
        scalars (List[float]): List of scalar values to be smoothed.
        weight (float): Weight for smoothing, between 0 and 1. Higher values
            give more weight to the history, i.e. a smoother curve.

    Returns:
        List[float]: Smoothed scalar values, one per input value. An empty
        input yields an empty list.
    """
    # len() rather than truthiness so array-like inputs are also accepted.
    if len(scalars) == 0:
        return []

    last = scalars[0]
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)
    return smoothed
def extract_scalar_data(log_dir):
    """
    Extracts specific scalar data from TensorBoard logs.

    Only a fixed set of loss tags is read. Tags that are not present in the
    logs are reported on stdout and left out of the result.

    Args:
        log_dir (str): Directory with TensorBoard logs.

    Returns:
        dict: A dictionary mapping each found tag to a ``{step: value}``
        dictionary, where ``value`` is the mean of all values recorded for
        that step.
    """
    ea = event_accumulator.EventAccumulator(log_dir)
    ea.Reload()

    desired_tags = ["loss/d/total", "loss/g/total", "loss/g/fm", "loss/g/mel", "loss/g/kl"]
    # Query the available scalar tags once instead of once per desired tag.
    available_tags = set(ea.Tags()['scalars'])

    scalar_data = {}
    for tag in desired_tags:
        if tag not in available_tags:
            print(f"Tag: {tag} not found in the TensorBoard logs.")
            continue

        values_by_step = {}
        previous_value = 0.0  # Fallback used when the very first value is NaN
        for event in ea.Scalars(tag):
            value = event.value
            # Replace NaN with the last usable value (or the 0.0 fallback).
            if np.isnan(value):
                value = previous_value
            values_by_step.setdefault(event.step, []).append(value)
            # value is never NaN at this point, so it is always a valid
            # fallback for the next event.
            previous_value = value

        # Calculate the average for each step. Restarting training can cause
        # multiple events for the same step.
        scalar_data[tag] = {
            step: sum(values) / len(values)
            for step, values in values_by_step.items()
        }

    return scalar_data
def sanitize_filename(filename):
    """
    Sanitize the filename by replacing path separators with underscores.

    Args:
        filename (str): The original filename or tag to sanitize.

    Returns:
        str: The input with every forward slash and backslash replaced by an
        underscore; all other characters are left unchanged.
    """
    # Single pass over the string; both separator styles map to '_'.
    return filename.translate(str.maketrans({'/': '_', '\\': '_'}))
def plot_scalar_data(scalar_data, log_dir, output_dir="loss_graphs", smooth_weight=0.75):
    """
    Generates and saves plots for the given scalar data, with smoothing.

    One JPEG is written per tag, showing the original series and an
    exponentially smoothed copy on a logarithmic y-axis. Tags without any
    data points are skipped.

    Args:
        scalar_data (dict): A dictionary mapping each scalar tag to a
            ``{step: value}`` dictionary.
        log_dir (str): Base directory where the generated JPEG files will be saved.
        output_dir (str): Subdirectory under `log_dir` where the JPEG files will be saved.
        smooth_weight (float): Weight for smoothing, between 0 and 1.
    """
    loss_graph_dir = os.path.join(log_dir, output_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(loss_graph_dir, exist_ok=True)

    for tag, events in scalar_data.items():
        if not events:
            # Nothing to plot (and nothing to seed the smoothing with).
            print(f'No data for tag: {tag}, skipping plot.')
            continue

        # Sanitize the tag for use in the filename
        sanitized_tag = sanitize_filename(tag)
        file_path = os.path.join(loss_graph_dir, f'{sanitized_tag}.jpeg')

        # Plot in step order so the line and the smoothing follow training
        # progress regardless of the order the steps were inserted in.
        steps = sorted(events)
        values = [events[step] for step in steps]

        # Print the last tag, step, and value
        print(f'Last entry - Tag: {tag}, Step: {steps[-1]}, Value: {values[-1]}')

        # Apply smoothing
        smoothed_values = smooth(values, smooth_weight)

        plt.figure(figsize=(20, 12))
        try:
            plt.plot(steps, values, label=f'{tag} (original)')
            plt.plot(steps, smoothed_values, label=f'{tag} (smoothed)', linestyle='--')
            plt.xlabel('Steps')
            plt.ylabel(tag)
            plt.title(f'{tag} over time')
            # Logarithmic y-axis for better visualization. Reminder: another
            # approach could be to identify a cutoff point and disregard a
            # few datapoints at the beginning of the training.
            plt.yscale('log')
            plt.grid(True)
            plt.legend()
            plt.savefig(file_path)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close()