mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-01-31 10:42:52 +08:00
**Subject:** *Add loss graph functionality for Gradio interface*
**Body:** * Implements helper functions to retrieve and display loss graphs for different projects. * Integrates with Gradio to provide interactive tabs with loss graph images. * Handles project selection and updates loss graphs accordingly. * Includes helper functions for generating and saving loss graphs from TensorBoard logs
This commit is contained in:
parent
85829da3a0
commit
1b7add90ea
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Pitch settings",
|
"音调设置": "Pitch settings",
|
||||||
"音频设备": "Audio device",
|
"音频设备": "Audio device",
|
||||||
"音高算法": "pitch detection algorithm",
|
"音高算法": "pitch detection algorithm",
|
||||||
"额外推理时长": "Extra inference time"
|
"额外推理时长": "Extra inference time",
|
||||||
|
"损失图": "Loss Graph",
|
||||||
|
"选择语音": "Select voice",
|
||||||
|
"更新损失图": "Update Loss Graph",
|
||||||
|
"更新语音列表": "Update Voice List",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Training Progress Overview: Lower values indicate better model performance. For detailed insights, explore TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Ajuste de tono",
|
"音调设置": "Ajuste de tono",
|
||||||
"音频设备": "Dispositivo de audio",
|
"音频设备": "Dispositivo de audio",
|
||||||
"音高算法": "Algoritmo de tono",
|
"音高算法": "Algoritmo de tono",
|
||||||
"额外推理时长": "Tiempo de inferencia adicional"
|
"额外推理时长": "Tiempo de inferencia adicional",
|
||||||
|
"损失图": "Gráfico de pérdida",
|
||||||
|
"选择语音": "Seleccione el audio",
|
||||||
|
"更新损失图": "Actualizar gráfico de pérdida",
|
||||||
|
"更新语音列表": "Actualizar lista de audio",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Resumen del progreso de entrenamiento: cuanto menor sea el valor, mejor será el rendimiento del modelo. Para obtener una visión detallada, explore TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Réglages de la hauteur",
|
"音调设置": "Réglages de la hauteur",
|
||||||
"音频设备": "Périphérique audio",
|
"音频设备": "Périphérique audio",
|
||||||
"音高算法": "algorithme de détection de la hauteur",
|
"音高算法": "algorithme de détection de la hauteur",
|
||||||
"额外推理时长": "Temps d'inférence supplémentaire"
|
"额外推理时长": "Temps d'inférence supplémentaire",
|
||||||
|
"损失图": "Graphique de perte",
|
||||||
|
"选择语音": "Sélectionner la voix",
|
||||||
|
"更新损失图": "Actualiser le graphique de perte",
|
||||||
|
"更新语音列表": "Actualiser la liste des voix",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Aperçu de la progression de l'entraînement : les valeurs inférieures indiquent des performances de modèle plus élevées. Pour une compréhension détaillée, explorez TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Impostazioni del tono",
|
"音调设置": "Impostazioni del tono",
|
||||||
"音频设备": "Dispositivo audio",
|
"音频设备": "Dispositivo audio",
|
||||||
"音高算法": "音高算法",
|
"音高算法": "音高算法",
|
||||||
"额外推理时长": "Tempo di inferenza extra"
|
"额外推理时长": "Tempo di inferenza extra",
|
||||||
|
"损失图": "Grafico delle perdite",
|
||||||
|
"选择语音": "Seleziona la voce",
|
||||||
|
"更新损失图": "Aggiornare il grafico delle perdite",
|
||||||
|
"更新语音列表": "Aggiornare l'elenco delle voci",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Riepilogo del progresso dell'addestramento: i valori più bassi indicano prestazioni di modello migliori. Per una comprensione più approfondita, esplora TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "音程設定",
|
"音调设置": "音程設定",
|
||||||
"音频设备": "オーディオデバイス",
|
"音频设备": "オーディオデバイス",
|
||||||
"音高算法": "ピッチアルゴリズム",
|
"音高算法": "ピッチアルゴリズム",
|
||||||
"额外推理时长": "追加推論時間"
|
"额外推理时长": "追加推論時間",
|
||||||
|
"损失图": "損失グラフ",
|
||||||
|
"选择语音": "音声を選択",
|
||||||
|
"更新损失图": "損失グラフを更新する",
|
||||||
|
"更新语音列表": "音声リストを更新する",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "トレーニング進行状況の概要:値が低いほど、モデルの性能が良い。詳細な見解を得るには、TensorBoardを探索してください。"
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "음조 설정",
|
"音调设置": "음조 설정",
|
||||||
"音频设备": "音频设备",
|
"音频设备": "音频设备",
|
||||||
"音高算法": "음높이 알고리즘",
|
"音高算法": "음높이 알고리즘",
|
||||||
"额外推理时长": "추가 추론 시간"
|
"额外推理时长": "추가 추론 시간",
|
||||||
|
"损失图": "손실 그래프",
|
||||||
|
"选择语音": "음성 선택",
|
||||||
|
"更新损失图": "손실 그래프 업데이트",
|
||||||
|
"更新语音列表": "음성 목록 업데이트",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "훈련 진행 개요: 값이 낮을수록 모델 성능이 좋습니다. 자세한 내용은 TensorBoard를 탐색하세요."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Configurações de tom",
|
"音调设置": "Configurações de tom",
|
||||||
"音频设备": "音频设备",
|
"音频设备": "音频设备",
|
||||||
"音高算法": "Algoritmo de detecção de pitch",
|
"音高算法": "Algoritmo de detecção de pitch",
|
||||||
"额外推理时长": "Tempo extra de inferência"
|
"额外推理时长": "Tempo extra de inferência",
|
||||||
|
"损失图": "Gráfico de perda",
|
||||||
|
"选择语音": "Selecione o áudio",
|
||||||
|
"更新损失图": "Atualizar gráfico de perda",
|
||||||
|
"更新语音列表": "Atualizar lista de áudio",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Visão geral do progresso do treinamento: quanto menor o valor, melhor o desempenho do modelo. Para obter insights detalhados, explore o TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Настройка высоты звука",
|
"音调设置": "Настройка высоты звука",
|
||||||
"音频设备": "Аудиоустройство",
|
"音频设备": "Аудиоустройство",
|
||||||
"音高算法": "Алгоритм оценки высоты звука",
|
"音高算法": "Алгоритм оценки высоты звука",
|
||||||
"额外推理时长": "Доп. время переработки"
|
"额外推理时长": "Доп. время переработки",
|
||||||
|
"损失图": "График потерь",
|
||||||
|
"选择语音": "Выберите голос",
|
||||||
|
"更新损失图": "Обновить график потерь",
|
||||||
|
"更新语音列表": "Обновить список голосов",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Обзор обучения: чем ниже значение, тем лучше качество модели. Подробнее см. TensorBoard."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "Pitch ayarları",
|
"音调设置": "Pitch ayarları",
|
||||||
"音频设备": "Ses cihazı",
|
"音频设备": "Ses cihazı",
|
||||||
"音高算法": "音高算法",
|
"音高算法": "音高算法",
|
||||||
"额外推理时长": "Ekstra çıkartma süresi"
|
"额外推理时长": "Ekstra çıkartma süresi",
|
||||||
|
"损失图": "Kayıp grafiği",
|
||||||
|
"选择语音": "Konuşma seç",
|
||||||
|
"更新损失图": "Kayıp grafiğini güncelle",
|
||||||
|
"更新语音列表": "Konuşma listesini güncelle",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "Eğitim ilerlemesi özeti: Düşük değerler, model performansının daha iyi olduğunu gösterir. Daha fazla ayrıntı için TensorBoard'u keşfedin."
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "音调设置",
|
"音调设置": "音调设置",
|
||||||
"音频设备": "音频设备",
|
"音频设备": "音频设备",
|
||||||
"音高算法": "音高算法",
|
"音高算法": "音高算法",
|
||||||
"额外推理时长": "额外推理时长"
|
"额外推理时长": "额外推理时长",
|
||||||
|
"损失图": "损失图",
|
||||||
|
"选择语音": "选择语音",
|
||||||
|
"更新损失图": "更新损失图",
|
||||||
|
"更新语音列表": "更新语音列表",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。"
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "音調設定",
|
"音调设置": "音調設定",
|
||||||
"音频设备": "音訊設備",
|
"音频设备": "音訊設備",
|
||||||
"音高算法": "音高演算法",
|
"音高算法": "音高演算法",
|
||||||
"额外推理时长": "額外推理時長"
|
"额外推理时长": "額外推理時長",
|
||||||
|
"损失图": "損失圖",
|
||||||
|
"选择语音": "選擇語音",
|
||||||
|
"更新损失图": "更新損失圖",
|
||||||
|
"更新语音列表": "更新語音列表",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "訓練進度概覽:值越低,模型性能越好。如需詳細見解,請探索 TensorBoard。"
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "音調設定",
|
"音调设置": "音調設定",
|
||||||
"音频设备": "音訊設備",
|
"音频设备": "音訊設備",
|
||||||
"音高算法": "音高演算法",
|
"音高算法": "音高演算法",
|
||||||
"额外推理时长": "額外推理時長"
|
"额外推理时长": "額外推理時長",
|
||||||
|
"损失图": "損失圖",
|
||||||
|
"选择语音": "選擇語音",
|
||||||
|
"更新损失图": "更新損失圖",
|
||||||
|
"更新语音列表": "更新語音列表",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "訓練進度概覽:值越低,模型性能越好。如需詳細見解,請探索 TensorBoard。"
|
||||||
}
|
}
|
||||||
|
@ -133,5 +133,10 @@
|
|||||||
"音调设置": "音調設定",
|
"音调设置": "音調設定",
|
||||||
"音频设备": "音訊設備",
|
"音频设备": "音訊設備",
|
||||||
"音高算法": "音高演算法",
|
"音高算法": "音高演算法",
|
||||||
"额外推理时长": "額外推理時長"
|
"额外推理时长": "額外推理時長",
|
||||||
|
"损失图": "損失圖",
|
||||||
|
"选择语音": "选择语音",
|
||||||
|
"更新损失图": "更新損失圖",
|
||||||
|
"更新语音列表": "更新语音列表",
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。": "訓練進度概覽:值越低,模型性能越好。如需詳細見解,請探索 TensorBoard。"
|
||||||
}
|
}
|
||||||
|
136
infer-web.py
136
infer-web.py
@ -1,3 +1,4 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@ -806,6 +807,100 @@ def change_f0_method(f0method8):
|
|||||||
return {"visible": visible, "__type__": "update"}
|
return {"visible": visible, "__type__": "update"}
|
||||||
|
|
||||||
|
|
||||||
|
# start tab loss graph helper functions
|
||||||
|
desired_tags = ["loss_d_total", "loss_g_total", "loss_g_fm", "loss_g_mel", "loss_g_kl"]
|
||||||
|
|
||||||
|
def get_projects():
|
||||||
|
"""
|
||||||
|
Get the list of projects.
|
||||||
|
"""
|
||||||
|
return [name for name in os.listdir(index_root) if os.path.isdir(os.path.join(index_root, name)) and name != 'mute' and os.path.isdir(os.path.join(index_root, name, 'loss_graphs'))]
|
||||||
|
|
||||||
|
def get_loss_graph_images(selection):
|
||||||
|
"""
|
||||||
|
Gets loss graph images for a given project, assuming filenames match desired order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selection (str): Project name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary of image paths keyed by desired_tags.
|
||||||
|
"""
|
||||||
|
loss_graphs_path = os.path.join(index_root, selection, 'loss_graphs')
|
||||||
|
if not os.path.exists(loss_graphs_path):
|
||||||
|
print(f"Directory not found: {loss_graphs_path}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
graphs = {}
|
||||||
|
for tag in desired_tags:
|
||||||
|
image_path = os.path.join(loss_graphs_path, f"{tag}.jpeg")
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
graphs[tag] = image_path
|
||||||
|
|
||||||
|
return graphs
|
||||||
|
|
||||||
|
def get_loss_graph_tabs(project):
|
||||||
|
"""
|
||||||
|
Create Gradio Tabs and Image fields for the loss graphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project (str): Project name.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
gr.Tabs, list: A tuple containing the Gradio Tabs component and a list of image fields.
|
||||||
|
"""
|
||||||
|
loss_graph_tabs = gr.Tabs()
|
||||||
|
loss_graph_images = get_loss_graph_images(project)
|
||||||
|
loss_graph_image_fields = {}
|
||||||
|
|
||||||
|
with loss_graph_tabs:
|
||||||
|
for tag, image_path in loss_graph_images.items():
|
||||||
|
with gr.TabItem(tag):
|
||||||
|
image_field = gr.Image(value=image_path, width="100%")
|
||||||
|
loss_graph_image_fields[tag] = image_field
|
||||||
|
return loss_graph_tabs, list(loss_graph_image_fields.values())
|
||||||
|
|
||||||
|
def update_loss_graph_images(selection):
|
||||||
|
"""
|
||||||
|
Update the loss graph images for a given project.
|
||||||
|
"""
|
||||||
|
loss_graph_images = get_loss_graph_images(selection)
|
||||||
|
updated_values = []
|
||||||
|
|
||||||
|
for i, tag in enumerate(desired_tags):
|
||||||
|
if i < len(image_fields):
|
||||||
|
if tag in loss_graph_images:
|
||||||
|
image_path = loss_graph_images[tag]
|
||||||
|
if os.path.exists(image_path) and os.path.isfile(image_path):
|
||||||
|
updated_values.append(image_path)
|
||||||
|
else:
|
||||||
|
print(f"Warning: Image file does not exist or is not a file: {image_path}")
|
||||||
|
updated_values.append(None)
|
||||||
|
else:
|
||||||
|
print(f"Warning: No image found for tag: {tag}")
|
||||||
|
updated_values.append(None)
|
||||||
|
|
||||||
|
return updated_values
|
||||||
|
|
||||||
|
def update_projects():
|
||||||
|
"""
|
||||||
|
Update the list of projects.
|
||||||
|
"""
|
||||||
|
projects = get_projects()
|
||||||
|
return {"choices": sorted(projects), "__type__": "update"}
|
||||||
|
|
||||||
|
projects = get_projects()
|
||||||
|
|
||||||
|
# Check if there are any projects before accessing
|
||||||
|
if projects:
|
||||||
|
default_project = projects[0]
|
||||||
|
default_loss_graph_images = get_loss_graph_images(projects[0])
|
||||||
|
else:
|
||||||
|
print("No projects found.")
|
||||||
|
default_project = None
|
||||||
|
default_loss_graph_images = []
|
||||||
|
|
||||||
|
# gradio app
|
||||||
with gr.Blocks(title="RVC WebUI") as app:
|
with gr.Blocks(title="RVC WebUI") as app:
|
||||||
gr.Markdown("## RVC WebUI")
|
gr.Markdown("## RVC WebUI")
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
@ -1420,7 +1515,48 @@ with gr.Blocks(title="RVC WebUI") as app:
|
|||||||
info3,
|
info3,
|
||||||
api_name="train_start_all",
|
api_name="train_start_all",
|
||||||
)
|
)
|
||||||
|
with gr.TabItem(i18n("损失图")):
|
||||||
|
gr.Markdown(
|
||||||
|
value=i18n(
|
||||||
|
"训练进度概览:值越低,模型性能越好。如需详细见解,请探索 TensorBoard。"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
voice_list_dropdown = gr.Dropdown(
|
||||||
|
label=i18n("选择语音"),
|
||||||
|
choices=sorted(projects),
|
||||||
|
interactive=True,
|
||||||
|
value=default_project
|
||||||
|
)
|
||||||
|
with gr.Column():
|
||||||
|
update_voice_list_button = gr.Button(
|
||||||
|
i18n("更新语音列表"),
|
||||||
|
variant="primary"
|
||||||
|
)
|
||||||
|
update_loss_graph_button = gr.Button(
|
||||||
|
i18n("更新损失图"),
|
||||||
|
variant="primary"
|
||||||
|
)
|
||||||
|
update_voice_list_button.click(
|
||||||
|
fn=update_projects,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[voice_list_dropdown],
|
||||||
|
api_name="infer_refresh"
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
tabs, image_fields = get_loss_graph_tabs(default_project)
|
||||||
|
|
||||||
|
voice_list_dropdown.change(
|
||||||
|
fn=update_loss_graph_images,
|
||||||
|
inputs=voice_list_dropdown,
|
||||||
|
outputs=image_fields
|
||||||
|
)
|
||||||
|
|
||||||
|
update_loss_graph_button.click(
|
||||||
|
fn=update_loss_graph_images,
|
||||||
|
inputs=voice_list_dropdown,
|
||||||
|
outputs=image_fields
|
||||||
|
)
|
||||||
with gr.TabItem(i18n("ckpt处理")):
|
with gr.TabItem(i18n("ckpt处理")):
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.Markdown(value=i18n("模型融合, 可用于测试音色融合"))
|
gr.Markdown(value=i18n("模型融合, 可用于测试音色融合"))
|
||||||
|
128
infer/lib/train/graph_generation.py
Normal file
128
infer/lib/train/graph_generation.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Suppress TensorBoard event processing logs
|
||||||
|
logging.getLogger('tensorboard').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# Import TensorBoard modules after setting logging level
|
||||||
|
from tensorboard.backend.event_processing import event_accumulator
|
||||||
|
|
||||||
|
def generate_loss_graphs(log_dir):
|
||||||
|
"""
|
||||||
|
Generates and saves plots for the given TensorBoard logs.
|
||||||
|
Args:
|
||||||
|
log_dir (str): Directory with TensorBoard logs.
|
||||||
|
"""
|
||||||
|
scalar_data = extract_scalar_data(log_dir)
|
||||||
|
plot_scalar_data(scalar_data, log_dir)
|
||||||
|
|
||||||
|
def smooth(scalars: List[float], weight: float) -> List[float]:
|
||||||
|
"""
|
||||||
|
Smooths the given list of scalars using exponential smoothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scalars (List[float]): List of scalar values to be smoothed.
|
||||||
|
weight (float): Weight for smoothing, between 0 and 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: Smoothed scalar values.
|
||||||
|
"""
|
||||||
|
last = scalars[0]
|
||||||
|
smoothed = []
|
||||||
|
for point in scalars:
|
||||||
|
smoothed_val = last * weight + (1 - weight) * point
|
||||||
|
smoothed.append(smoothed_val)
|
||||||
|
last = smoothed_val
|
||||||
|
return smoothed
|
||||||
|
|
||||||
|
def extract_scalar_data(log_dir):
|
||||||
|
"""
|
||||||
|
Extracts specific scalar data from TensorBoard logs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_dir (str): Directory with TensorBoard logs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary where keys are scalar names and values are lists of scalar events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ea = event_accumulator.EventAccumulator(log_dir)
|
||||||
|
ea.Reload()
|
||||||
|
|
||||||
|
scalar_data = {}
|
||||||
|
desired_tags = ["loss/d/total", "loss/g/total", "loss/g/fm", "loss/g/mel", "loss/g/kl"]
|
||||||
|
|
||||||
|
for tag in desired_tags:
|
||||||
|
if tag in ea.Tags()['scalars']:
|
||||||
|
scalar_events = ea.Scalars(tag)
|
||||||
|
scalar_data[tag] = {}
|
||||||
|
for event in scalar_events:
|
||||||
|
if event.step not in scalar_data[tag]:
|
||||||
|
scalar_data[tag][event.step] = [event.value]
|
||||||
|
else:
|
||||||
|
scalar_data[tag][event.step].append(event.value)
|
||||||
|
# Calculate the average for each step. Restarting training can cause multiple events for the same step.
|
||||||
|
scalar_data[tag] = {step: sum(values) / len(values) for step, values in scalar_data[tag].items()}
|
||||||
|
else:
|
||||||
|
print(f"Tag: {tag} not found in the TensorBoard logs.")
|
||||||
|
|
||||||
|
return scalar_data
|
||||||
|
|
||||||
|
def sanitize_filename(filename):
|
||||||
|
"""
|
||||||
|
Sanitize the filename by replacing invalid characters with underscores.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename (str): The original filename or tag to sanitize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Sanitized filename.
|
||||||
|
"""
|
||||||
|
# Replace slashes and other invalid characters with underscores
|
||||||
|
return filename.replace('/', '_').replace('\\', '_')
|
||||||
|
|
||||||
|
def plot_scalar_data(scalar_data, log_dir, output_dir="loss_graphs", smooth_weight=0.75):
|
||||||
|
"""
|
||||||
|
Generates and saves plots for the given scalar data, with optional smoothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scalar_data (dict): A dictionary where keys are scalar names and values are lists of scalar events.
|
||||||
|
log_dir (str): Base directory where the generated JPEG files will be saved.
|
||||||
|
output_dir (str): Subdirectory under `log_dir` where the JPEG files will be saved.
|
||||||
|
smooth_weight (float): Weight for smoothing, between 0 and 1.
|
||||||
|
"""
|
||||||
|
loss_graph_dir = os.path.join(log_dir, output_dir)
|
||||||
|
if not os.path.exists(loss_graph_dir):
|
||||||
|
os.makedirs(loss_graph_dir)
|
||||||
|
|
||||||
|
for tag, events in scalar_data.items():
|
||||||
|
# Sanitize the tag for use in the filename
|
||||||
|
sanitized_tag = sanitize_filename(tag)
|
||||||
|
file_path = os.path.join(loss_graph_dir, f'{sanitized_tag}.jpeg')
|
||||||
|
|
||||||
|
# Extract steps and values from the scalar events
|
||||||
|
steps = list(events.keys())
|
||||||
|
values = list(events.values())
|
||||||
|
|
||||||
|
# Print the last tag, step, and value
|
||||||
|
if steps and values: # Ensure that the list is not empty
|
||||||
|
last_step = steps[-1]
|
||||||
|
last_value = values[-1]
|
||||||
|
print(f'Last entry - Tag: {tag}, Step: {last_step}, Value: {last_value}')
|
||||||
|
|
||||||
|
# Apply smoothing
|
||||||
|
smoothed_values = smooth(values, smooth_weight)
|
||||||
|
|
||||||
|
plt.figure(figsize=(20, 12))
|
||||||
|
plt.plot(steps, values, label=f'{tag} (original)')
|
||||||
|
plt.plot(steps, smoothed_values, label=f'{tag} (smoothed)', linestyle='--')
|
||||||
|
plt.xlabel('Steps')
|
||||||
|
plt.ylabel(tag)
|
||||||
|
plt.title(f'{tag} over time')
|
||||||
|
plt.yscale('log') # Set y-axis to logarithmic scale for better visualization, reminder another approach could be identify a cutoff point and disregard a few datapoints at the beginning of the training
|
||||||
|
plt.grid(True)
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(file_path)
|
||||||
|
plt.close()
|
@ -9,7 +9,7 @@ sys.path.append(os.path.join(now_dir))
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from infer.lib.train import utils
|
from infer.lib.train import utils, graph_generation
|
||||||
|
|
||||||
hps = utils.get_hparams()
|
hps = utils.get_hparams()
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
|
os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
|
||||||
@ -558,6 +558,7 @@ def train_and_evaluate(
|
|||||||
images=image_dict,
|
images=image_dict,
|
||||||
scalars=scalar_dict,
|
scalars=scalar_dict,
|
||||||
)
|
)
|
||||||
|
graph_generation.generate_loss_graphs(hps.model_dir)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
# /Run steps
|
# /Run steps
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user