Decouple merged utils from spark_ui

This commit is contained in:
VSlobolinskyi 2025-03-21 21:38:46 +02:00
parent 931870e527
commit 87e06f30ec

View File

@ -1,20 +1,87 @@
from datetime import datetime
import logging
import os
import platform
import shutil
import re
import numpy as np
from time import sleep
import soundfile as sf
from pydub import AudioSegment
import torch
# Import modules from your packages
from spark.cli.SparkTTS import SparkTTS
from rvc_ui.initialization import vc
from spark_ui.main import initialize_model, run_tts
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
device = 0
spark_model = initialize_model(model_dir, device=device)
def initialize_model(model_dir, device):
"""Load the model once at the beginning."""
logging.info(f"Loading model from: {model_dir}")
# Determine appropriate device based on platform and availability
if platform.system() == "Darwin":
# macOS with MPS support (Apple Silicon)
device = torch.device(f"mps:{device}")
logging.info(f"Using MPS device: {device}")
elif torch.cuda.is_available():
# System with CUDA support
device = torch.device(f"cuda:{device}")
logging.info(f"Using CUDA device: {device}")
else:
# Fall back to CPU
device = torch.device("cpu")
logging.info("GPU acceleration not available, using CPU")
model = SparkTTS(model_dir, device)
return model
def run_tts(
text,
prompt_text=None,
prompt_speech=None,
gender=None,
pitch=None,
speed=None,
save_dir="TEMP/spark", # Updated default save directory
save_filename=None, # New parameter to specify filename
):
"""Perform TTS inference and save the generated audio."""
model = initialize_model(model_dir, device=device)
logging.info(f"Saving audio to: {save_dir}")
if prompt_text is not None:
prompt_text = None if len(prompt_text) <= 1 else prompt_text
# Ensure the save directory exists
os.makedirs(save_dir, exist_ok=True)
# Determine the save path based on save_filename if provided; otherwise, use a timestamp
if save_filename:
save_path = os.path.join(save_dir, save_filename)
else:
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
save_path = os.path.join(save_dir, f"{timestamp}.wav")
logging.info("Starting inference...")
# Perform inference and save the output audio
with torch.no_grad():
wav = model.inference(
text,
prompt_speech,
prompt_text,
gender,
pitch,
speed,
)
sf.write(save_path, wav, samplerate=16000)
logging.info(f"Audio saved at: {save_path}")
return save_path
def split_into_sentences(text):
"""
@ -40,34 +107,26 @@ def process_single_sentence(
base_fragment_num
):
"""
Process a single sentence through the TTS and RVC pipeline
Args:
sentence_index (int): Index of the sentence in the original text
sentence (str): The sentence text to process
... (other parameters are the same as generate_and_process_with_rvc)
Returns:
tuple: (spark_output_path, rvc_output_path, success, info_message)
Process a single sentence through the TTS and RVC pipeline.
"""
fragment_num = base_fragment_num + sentence_index
# Generate TTS audio for this sentence
# Generate TTS audio for this sentence, saving directly to the correct location
tts_path = run_tts(
sentence,
spark_model,
prompt_text=prompt_text_clean,
prompt_speech=prompt_speech
prompt_speech=prompt_speech,
save_dir="./TEMP/spark",
save_filename=f"fragment_{fragment_num}.wav"
)
# Make sure we have a TTS file to process
if not tts_path or not os.path.exists(tts_path):
return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
# Save Spark output to TEMP/spark
spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
shutil.copy2(tts_path, spark_output_path)
# Use the tts_path as the Spark output (no need to copy)
spark_output_path = tts_path
# Call RVC processing function
f0_file = None # We're not using an F0 curve file in this pipeline
output_info, output_audio = vc.vc_single(
@ -75,12 +134,11 @@ def process_single_sentence(
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
)
# Save RVC output to TEMP/rvc directory
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
rvc_saved = False
# Try different ways to save the RVC output based on common formats
try:
if isinstance(output_audio, str) and os.path.exists(output_audio):
# Case 1: output_audio is a file path string
@ -99,7 +157,7 @@ def process_single_sentence(
rvc_saved = True
except Exception as e:
output_info += f"\nError saving RVC output: {str(e)}"
# Prepare info message
info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
info_message += f" - Spark output: {spark_output_path}\n"
@ -107,7 +165,7 @@ def process_single_sentence(
info_message += f" - RVC output: {rvc_output_path}"
else:
info_message += f" - Could not save RVC output to {rvc_output_path}"
return spark_output_path, rvc_output_path, rvc_saved, info_message
def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
@ -225,4 +283,4 @@ def modified_get_vc(sid0_value, protect0_value, file_index2_component):
if isinstance(outputs, tuple) and len(outputs) >= 3:
return outputs[0], outputs[1], outputs[3]
return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else ""
return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else ""