# Retrieval-based-Voice-Conve.../infer/lib/train/graph_generation.py

import logging
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np

# Suppress TensorBoard event processing logs
logging.getLogger('tensorboard').setLevel(logging.ERROR)

# Import TensorBoard modules after setting the logging level
from tensorboard.backend.event_processing import event_accumulator


def generate_loss_graphs(log_dir):
    """
    Generates and saves plots for the given TensorBoard logs.

    Args:
        log_dir (str): Directory with TensorBoard logs.
    """
    scalar_data = extract_scalar_data(log_dir)
    plot_scalar_data(scalar_data, log_dir)


def smooth(scalars: List[float], weight: float) -> List[float]:
    """
    Smooths the given list of scalars using exponential smoothing.

    Args:
        scalars (List[float]): List of scalar values to be smoothed.
        weight (float): Weight for smoothing, between 0 and 1.

    Returns:
        List[float]: Smoothed scalar values.
    """
    last = scalars[0]
    smoothed = []
    for point in scalars:
        # Exponential moving average: blend the previous smoothed value with the new point
        smoothed_val = last * weight + (1 - weight) * point
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed
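
# Worked example (illustrative, not part of the original module): with weight=0.75,
#     smooth([1.0, 2.0, 3.0], 0.75) -> [1.0, 1.25, 1.6875]
# because each output is 0.75 * previous_smoothed + 0.25 * current_point, seeded with
# the first value, which is why the smoothed curve lags behind sharp changes in the raw loss.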


def extract_scalar_data(log_dir):
    """
    Extracts specific scalar data from TensorBoard logs.

    Args:
        log_dir (str): Directory with TensorBoard logs.

    Returns:
        dict: A dictionary mapping each scalar tag to a dict of step -> averaged value.
    """
    ea = event_accumulator.EventAccumulator(log_dir)
    ea.Reload()
    scalar_data = {}
    desired_tags = ["loss/d/total", "loss/g/total", "loss/g/fm", "loss/g/mel", "loss/g/kl"]
    for tag in desired_tags:
        if tag in ea.Tags()['scalars']:
            scalar_events = ea.Scalars(tag)
            scalar_data[tag] = {}
            previous_value = 0.0  # Fallback used if the very first events are NaN
            for event in scalar_events:
                value = event.value
                # Replace NaN values with the previous valid value (or the 0.0 fallback)
                if np.isnan(value):
                    value = previous_value
                if event.step not in scalar_data[tag]:
                    scalar_data[tag][event.step] = [value]
                else:
                    scalar_data[tag][event.step].append(value)
                # Remember the last (always non-NaN) value for the next iteration
                previous_value = value
            # Average the values for each step; restarting training can produce multiple events for the same step
            scalar_data[tag] = {
                step: sum(values) / len(values) for step, values in scalar_data[tag].items()
            }
        else:
            print(f"Tag: {tag} not found in the TensorBoard logs.")
    return scalar_data
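
# For orientation (values below are hypothetical, not taken from real logs), the
# returned structure looks like:
#     {"loss/g/total": {200: 28.5, 400: 26.1}, "loss/d/total": {200: 3.9, 400: 3.6}}
# i.e. one step -> averaged-value mapping per tag that was present in the event files.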


def sanitize_filename(filename):
    """
    Sanitizes a filename by replacing invalid characters with underscores.

    Args:
        filename (str): The original filename or tag to sanitize.

    Returns:
        str: Sanitized filename.
    """
    # Replace forward and backward slashes with underscores
    return filename.replace('/', '_').replace('\\', '_')
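
# Example: sanitize_filename("loss/g/total") returns "loss_g_total", which is safe to
# use as a file name on both POSIX and Windows systems.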


def plot_scalar_data(scalar_data, log_dir, output_dir="loss_graphs", smooth_weight=0.75):
    """
    Generates and saves plots for the given scalar data, with optional smoothing.

    Args:
        scalar_data (dict): A dictionary mapping each scalar tag to a dict of step -> value.
        log_dir (str): Base directory where the generated JPEG files will be saved.
        output_dir (str): Subdirectory under `log_dir` where the JPEG files will be saved.
        smooth_weight (float): Weight for smoothing, between 0 and 1.
    """
    loss_graph_dir = os.path.join(log_dir, output_dir)
    if not os.path.exists(loss_graph_dir):
        os.makedirs(loss_graph_dir)
    for tag, events in scalar_data.items():
        # Sanitize the tag for use in the filename
        sanitized_tag = sanitize_filename(tag)
        file_path = os.path.join(loss_graph_dir, f'{sanitized_tag}.jpeg')
        # Extract steps and values from the averaged scalar data
        steps = list(events.keys())
        values = list(events.values())
        if not steps:
            # Nothing recorded for this tag; skip it to avoid smoothing an empty list
            continue
        # Print the last recorded step and value for this tag
        print(f'Last entry - Tag: {tag}, Step: {steps[-1]}, Value: {values[-1]}')
        # Apply smoothing
        smoothed_values = smooth(values, smooth_weight)
        plt.figure(figsize=(20, 12))
        plt.plot(steps, values, label=f'{tag} (original)')
        plt.plot(steps, smoothed_values, label=f'{tag} (smoothed)', linestyle='--')
        plt.xlabel('Steps')
        plt.ylabel(tag)
        plt.title(f'{tag} over time')
        # Logarithmic y-axis improves readability; an alternative would be to drop the
        # first few data points from the start of training instead
        plt.yscale('log')
        plt.grid(True)
        plt.legend()
        plt.savefig(file_path)
        plt.close()
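

# Minimal usage sketch (added for illustration, not part of the original module). The
# log directory below is a placeholder: pass the experiment's TensorBoard log folder,
# e.g. the directory that contains the events.out.tfevents.* files.
if __name__ == "__main__":
    import sys

    log_dir = sys.argv[1] if len(sys.argv) > 1 else "logs/my_experiment"  # hypothetical default path
    generate_loss_graphs(log_dir)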