From 87e06f30ecb1d88569404a2a3e0f7157f777659d Mon Sep 17 00:00:00 2001 From: VSlobolinskyi Date: Fri, 21 Mar 2025 21:38:46 +0200 Subject: [PATCH] Decouple merged utils from spark_ui --- modules/merged_ui/utils.py | 116 +++++++++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 29 deletions(-) diff --git a/modules/merged_ui/utils.py b/modules/merged_ui/utils.py index 7e95ee4..f4336f8 100644 --- a/modules/merged_ui/utils.py +++ b/modules/merged_ui/utils.py @@ -1,20 +1,87 @@ +from datetime import datetime +import logging import os +import platform import shutil import re import numpy as np -from time import sleep import soundfile as sf from pydub import AudioSegment +import torch # Import modules from your packages +from spark.cli.SparkTTS import SparkTTS from rvc_ui.initialization import vc -from spark_ui.main import initialize_model, run_tts -from spark.sparktts.utils.token_parser import LEVELS_MAP_UI # Initialize the Spark TTS model (moved outside function to avoid reinitializing) model_dir = "spark/pretrained_models/Spark-TTS-0.5B" device = 0 -spark_model = initialize_model(model_dir, device=device) + +def initialize_model(model_dir, device): + """Load the model once at the beginning.""" + logging.info(f"Loading model from: {model_dir}") + + # Determine appropriate device based on platform and availability + if platform.system() == "Darwin": + # macOS with MPS support (Apple Silicon) + device = torch.device(f"mps:{device}") + logging.info(f"Using MPS device: {device}") + elif torch.cuda.is_available(): + # System with CUDA support + device = torch.device(f"cuda:{device}") + logging.info(f"Using CUDA device: {device}") + else: + # Fall back to CPU + device = torch.device("cpu") + logging.info("GPU acceleration not available, using CPU") + + model = SparkTTS(model_dir, device) + return model + + +def run_tts( + text, + prompt_text=None, + prompt_speech=None, + gender=None, + pitch=None, + speed=None, + save_dir="TEMP/spark", # Updated default save directory + save_filename=None, # New parameter to specify filename +): + """Perform TTS inference and save the generated audio.""" + model = initialize_model(model_dir, device=device) + logging.info(f"Saving audio to: {save_dir}") + + if prompt_text is not None: + prompt_text = None if len(prompt_text) <= 1 else prompt_text + + # Ensure the save directory exists + os.makedirs(save_dir, exist_ok=True) + + # Determine the save path based on save_filename if provided; otherwise, use a timestamp + if save_filename: + save_path = os.path.join(save_dir, save_filename) + else: + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + save_path = os.path.join(save_dir, f"{timestamp}.wav") + + logging.info("Starting inference...") + + # Perform inference and save the output audio + with torch.no_grad(): + wav = model.inference( + text, + prompt_speech, + prompt_text, + gender, + pitch, + speed, + ) + sf.write(save_path, wav, samplerate=16000) + + logging.info(f"Audio saved at: {save_path}") + return save_path def split_into_sentences(text): """ @@ -40,34 +107,26 @@ def process_single_sentence( base_fragment_num ): """ - Process a single sentence through the TTS and RVC pipeline - - Args: - sentence_index (int): Index of the sentence in the original text - sentence (str): The sentence text to process - ... (other parameters are the same as generate_and_process_with_rvc) - - Returns: - tuple: (spark_output_path, rvc_output_path, success, info_message) + Process a single sentence through the TTS and RVC pipeline. """ fragment_num = base_fragment_num + sentence_index - - # Generate TTS audio for this sentence + + # Generate TTS audio for this sentence, saving directly to the correct location tts_path = run_tts( sentence, - spark_model, prompt_text=prompt_text_clean, - prompt_speech=prompt_speech + prompt_speech=prompt_speech, + save_dir="./TEMP/spark", + save_filename=f"fragment_{fragment_num}.wav" ) - + # Make sure we have a TTS file to process if not tts_path or not os.path.exists(tts_path): return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}" - - # Save Spark output to TEMP/spark - spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav" - shutil.copy2(tts_path, spark_output_path) - + + # Use the tts_path as the Spark output (no need to copy) + spark_output_path = tts_path + # Call RVC processing function f0_file = None # We're not using an F0 curve file in this pipeline output_info, output_audio = vc.vc_single( @@ -75,12 +134,11 @@ def process_single_sentence( file_index1, file_index2, index_rate, filter_radius, resample_sr, rms_mix_rate, protect ) - + # Save RVC output to TEMP/rvc directory rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav" rvc_saved = False - - # Try different ways to save the RVC output based on common formats + try: if isinstance(output_audio, str) and os.path.exists(output_audio): # Case 1: output_audio is a file path string @@ -99,7 +157,7 @@ def process_single_sentence( rvc_saved = True except Exception as e: output_info += f"\nError saving RVC output: {str(e)}" - + # Prepare info message info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n" info_message += f" - Spark output: {spark_output_path}\n" @@ -107,7 +165,7 @@ def process_single_sentence( info_message += f" - RVC output: {rvc_output_path}" else: info_message += f" - Could not save RVC output to {rvc_output_path}" - + return spark_output_path, rvc_output_path, rvc_saved, info_message def concatenate_audio_files(file_paths, output_path, sample_rate=44100): @@ -225,4 +283,4 @@ def modified_get_vc(sid0_value, protect0_value, file_index2_component): if isinstance(outputs, tuple) and len(outputs) >= 3: return outputs[0], outputs[1], outputs[3] - return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else "" + return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else "" \ No newline at end of file