From 284264fdbef0ef2bbcfe8b988a42f0b77a55bce1 Mon Sep 17 00:00:00 2001
From: VSlobolinskyi
Date: Fri, 21 Mar 2025 19:17:49 +0200
Subject: [PATCH] Stream functionality preparation

---
 modules/merged_ui/main.py            |  13 +--
 modules/merged_ui/streaming_utils.py | 101 ++++++++++++++++++++++++
 modules/merged_ui/utils.py           | 117 ++++++++++++++++++++-------
 modules/spark_ui/main.py             |  66 ----------------
 4 files changed, 198 insertions(+), 99 deletions(-)
 create mode 100644 modules/merged_ui/streaming_utils.py

diff --git a/modules/merged_ui/main.py b/modules/merged_ui/main.py
index ade00a7..6272479 100644
--- a/modules/merged_ui/main.py
+++ b/modules/merged_ui/main.py
@@ -1,7 +1,8 @@
 import gradio as gr
 
 # Import modules from your packages
-from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
+from merged_ui.utils import modified_get_vc
+from merged_ui.streaming_utils import generate_and_process_with_rvc_streaming
 from rvc_ui.initialization import config
 from rvc_ui.main import names, index_paths
 
@@ -17,7 +18,7 @@ def build_merged_ui():
     with gr.Tabs():
         with gr.TabItem("TTS-to-RVC Pipeline"):
             gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
-            gr.Markdown("*Note: For multi-sentence text, each sentence will be processed separately and then combined.*")
+            gr.Markdown("*Note: Sentences will begin playing as soon as the first one is ready.*")
 
             # TTS Generation Section
             with gr.Row():
@@ -130,12 +131,14 @@ def build_merged_ui():
                 generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
 
             with gr.Row():
+                # Status box for progress information (updated while processing runs)
                 vc_output1 = gr.Textbox(label="Output information", lines=10)
-                vc_output2 = gr.Audio(label="Final concatenated audio")
+                # Streaming audio output that supports incremental updates
+                vc_output2 = gr.Audio(label="Audio (updates as sentences are processed)")
 
-            # Connect generate function to button
+            # Connect the streaming generator function to the button
             generate_with_rvc_button.click(
-                generate_and_process_with_rvc,
+                generate_and_process_with_rvc_streaming,
                 inputs=[
                     tts_text_input,
                     prompt_text_input,

diff --git a/modules/merged_ui/streaming_utils.py b/modules/merged_ui/streaming_utils.py
new file mode 100644
index 0000000..76deb78
--- /dev/null
+++ b/modules/merged_ui/streaming_utils.py
@@ -0,0 +1,101 @@
+import os
+import time
+
+# Import and reuse existing functions
+from merged_ui.utils import (
+    split_into_sentences,
+    process_single_sentence,
+    concatenate_audio_files,
+)
+
+def generate_and_process_with_rvc_streaming(
+    text, prompt_text, prompt_wav_upload, prompt_wav_record,
+    spk_item, vc_transform, f0method,
+    file_index1, file_index2, index_rate, filter_radius,
+    resample_sr, rms_mix_rate, protect
+):
+    """
+    Stream-process TTS and RVC, yielding audio updates as sentences are processed.
+    This is a generator function that yields (status_text, audio_path) tuples.
+    """
+    # Ensure TEMP directories exist
+    os.makedirs("./TEMP/spark", exist_ok=True)
+    os.makedirs("./TEMP/rvc", exist_ok=True)
+    os.makedirs("./TEMP/stream", exist_ok=True)
+
+    # Split text into sentences
+    sentences = split_into_sentences(text)
+    if not sentences:
+        yield "No valid text to process.", None
+        return
+
+    # Use a timestamp as a unique session ID for this run
+    session_id = str(int(time.time()))
+
+    # Process reference speech
+    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
+
+    # Initialize status
+    status_messages = [f"Starting to process {len(sentences)} sentences..."]
+    # Create a list to track processed fragments
+    processed_fragments = []
+
+    # Create a temporary directory for streaming
+    stream_dir = f"./TEMP/stream/{session_id}"
+    os.makedirs(stream_dir, exist_ok=True)
+
+    # Yield initial status
+    yield "\n".join(status_messages), None
+
+    # Process each sentence and update the stream
+    for i, sentence in enumerate(sentences):
+        # Add current sentence to status
+        current_msg = f"Processing sentence {i+1}/{len(sentences)}: {sentence[:30]}..."
+        status_messages.append(current_msg)
+        yield "\n".join(status_messages), None
+
+        # Process this sentence
+        spark_path, rvc_path, success, info = process_single_sentence(
+            i, sentence, prompt_speech, prompt_text_clean,
+            spk_item, vc_transform, f0method,
+            file_index1, file_index2, index_rate, filter_radius,
+            resample_sr, rms_mix_rate, protect,
+            int(session_id)  # Use the session ID as the base fragment number
+        )
+
+        # Update status with the processing result
+        status_messages[-1] = info
+
+        if success and rvc_path and os.path.exists(rvc_path):
+            processed_fragments.append(rvc_path)
+
+            # Concatenate everything processed so far into a streaming update file
+            stream_path = os.path.join(stream_dir, f"stream_update_{i+1}.wav")
+            concatenate_success = concatenate_audio_files(processed_fragments, stream_path)
+
+            if concatenate_success:
+                # Yield the updated status and the current stream path
+                yield "\n".join(status_messages), stream_path
+            else:
+                # If concatenation failed, fall back to the most recent fragment
+                yield "\n".join(status_messages), rvc_path
+        else:
+            # If processing failed, update the status but keep the last good audio
+            yield "\n".join(status_messages), None if not processed_fragments else processed_fragments[-1]
+
+    # Final streaming update with completion message
+    if processed_fragments:
+        # Create the final output file
+        final_output_path = os.path.join(stream_dir, "final_output.wav")
+        concatenate_success = concatenate_audio_files(processed_fragments, final_output_path)
+
+        if concatenate_success:
+            status_messages.append(f"\nAll {len(sentences)} sentences processed successfully!")
+            yield "\n".join(status_messages), final_output_path
+        else:
+            status_messages.append("\nWarning: Failed to create final concatenated file.")
+            yield "\n".join(status_messages), processed_fragments[-1]
+    else:
+        status_messages.append("\nNo sentences were successfully processed.")
+        yield "\n".join(status_messages), None
\ No newline at end of file

diff --git a/modules/merged_ui/utils.py b/modules/merged_ui/utils.py
index 7e95ee4..eacda36 100644
--- a/modules/merged_ui/utils.py
+++ b/modules/merged_ui/utils.py
@@ -1,20 +1,90 @@
+from datetime import datetime
+import logging
 import os
+import platform
 import shutil
 import re
 import numpy as np
-from time import sleep
 import soundfile as sf
 from pydub import AudioSegment
+import torch
 
 # Import modules from your packages
+from spark.cli.SparkTTS import SparkTTS
 from rvc_ui.initialization import vc
-from spark_ui.main import initialize_model, run_tts
-from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
 
 # Initialize the Spark TTS model (moved outside function to avoid reinitializing)
 model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
 device = 0
-spark_model = initialize_model(model_dir, device=device)
+
+def initialize_model(model_dir, device):
+    """Load the model once at the beginning."""
+    logging.info(f"Loading model from: {model_dir}")
+
+    # Determine the appropriate device based on platform and availability
+    if platform.system() == "Darwin":
+        # macOS with MPS support (Apple Silicon)
+        device = torch.device(f"mps:{device}")
+        logging.info(f"Using MPS device: {device}")
+    elif torch.cuda.is_available():
+        # System with CUDA support
+        device = torch.device(f"cuda:{device}")
+        logging.info(f"Using CUDA device: {device}")
+    else:
+        # Fall back to CPU
+        device = torch.device("cpu")
+        logging.info("GPU acceleration not available, using CPU")
+
+    model = SparkTTS(model_dir, device)
+    return model
+
+
+# Load the model once at module level so every TTS call reuses the same instance
+spark_model = initialize_model(model_dir, device=device)
+
+
+def run_tts(
+    text,
+    prompt_text=None,
+    prompt_speech=None,
+    gender=None,
+    pitch=None,
+    speed=None,
+    save_dir="TEMP/spark",  # Updated default save directory
+    save_filename=None,  # Optional explicit output filename
+):
+    """Perform TTS inference and save the generated audio."""
+    logging.info(f"Saving audio to: {save_dir}")
+
+    if prompt_text is not None:
+        prompt_text = None if len(prompt_text) <= 1 else prompt_text
+
+    # Ensure the save directory exists
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Use save_filename if provided; otherwise generate a timestamped name
+    if save_filename:
+        save_path = os.path.join(save_dir, save_filename)
+    else:
+        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        save_path = os.path.join(save_dir, f"{timestamp}.wav")
+
+    logging.info("Starting inference...")
+
+    # Perform inference and save the output audio
+    with torch.no_grad():
+        wav = spark_model.inference(
+            text,
+            prompt_speech,
+            prompt_text,
+            gender,
+            pitch,
+            speed,
+        )
+        sf.write(save_path, wav, samplerate=16000)
+
+    logging.info(f"Audio saved at: {save_path}")
+    return save_path
 
 def split_into_sentences(text):
     """
@@ -40,34 +110,26 @@ def process_single_sentence(
     base_fragment_num
 ):
     """
-    Process a single sentence through the TTS and RVC pipeline
-
-    Args:
-        sentence_index (int): Index of the sentence in the original text
-        sentence (str): The sentence text to process
-        ... (other parameters are the same as generate_and_process_with_rvc)
-
-    Returns:
-        tuple: (spark_output_path, rvc_output_path, success, info_message)
+    Process a single sentence through the TTS and RVC pipeline.
""" fragment_num = base_fragment_num + sentence_index - - # Generate TTS audio for this sentence + + # Generate TTS audio for this sentence, saving directly to the correct location tts_path = run_tts( sentence, - spark_model, prompt_text=prompt_text_clean, - prompt_speech=prompt_speech + prompt_speech=prompt_speech, + save_dir="./TEMP/spark", + save_filename=f"fragment_{fragment_num}.wav" ) - + # Make sure we have a TTS file to process if not tts_path or not os.path.exists(tts_path): return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}" - - # Save Spark output to TEMP/spark - spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav" - shutil.copy2(tts_path, spark_output_path) - + + # Use the tts_path as the Spark output (no need to copy) + spark_output_path = tts_path + # Call RVC processing function f0_file = None # We're not using an F0 curve file in this pipeline output_info, output_audio = vc.vc_single( @@ -75,12 +134,11 @@ def process_single_sentence( file_index1, file_index2, index_rate, filter_radius, resample_sr, rms_mix_rate, protect ) - + # Save RVC output to TEMP/rvc directory rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav" rvc_saved = False - - # Try different ways to save the RVC output based on common formats + try: if isinstance(output_audio, str) and os.path.exists(output_audio): # Case 1: output_audio is a file path string @@ -99,7 +157,7 @@ def process_single_sentence( rvc_saved = True except Exception as e: output_info += f"\nError saving RVC output: {str(e)}" - + # Prepare info message info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n" info_message += f" - Spark output: {spark_output_path}\n" @@ -107,7 +165,7 @@ def process_single_sentence( info_message += f" - RVC output: {rvc_output_path}" else: info_message += f" - Could not save RVC output to {rvc_output_path}" - + return spark_output_path, rvc_output_path, rvc_saved, info_message def concatenate_audio_files(file_paths, output_path, sample_rate=44100): diff --git a/modules/spark_ui/main.py b/modules/spark_ui/main.py index faa9529..b949011 100644 --- a/modules/spark_ui/main.py +++ b/modules/spark_ui/main.py @@ -11,72 +11,6 @@ from datetime import datetime from spark.cli.SparkTTS import SparkTTS from spark.sparktts.utils.token_parser import LEVELS_MAP_UI - -def initialize_model(model_dir, device): - """Load the model once at the beginning.""" - logging.info(f"Loading model from: {model_dir}") - - # Determine appropriate device based on platform and availability - if platform.system() == "Darwin": - # macOS with MPS support (Apple Silicon) - device = torch.device(f"mps:{device}") - logging.info(f"Using MPS device: {device}") - elif torch.cuda.is_available(): - # System with CUDA support - device = torch.device(f"cuda:{device}") - logging.info(f"Using CUDA device: {device}") - else: - # Fall back to CPU - device = torch.device("cpu") - logging.info("GPU acceleration not available, using CPU") - - model = SparkTTS(model_dir, device) - return model - - -def run_tts( - text, - model, - prompt_text=None, - prompt_speech=None, - gender=None, - pitch=None, - speed=None, - save_dir="spark/example/results", -): - """Perform TTS inference and save the generated audio.""" - logging.info(f"Saving audio to: {save_dir}") - - if prompt_text is not None: - prompt_text = None if len(prompt_text) <= 1 else prompt_text - - # Ensure the save directory exists - os.makedirs(save_dir, exist_ok=True) - - # Generate unique filename using timestamp 
-    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
-    save_path = os.path.join(save_dir, f"{timestamp}.wav")
-
-    logging.info("Starting inference...")
-
-    # Perform inference and save the output audio
-    with torch.no_grad():
-        wav = model.inference(
-            text,
-            prompt_speech,
-            prompt_text,
-            gender,
-            pitch,
-            speed,
-        )
-
-        sf.write(save_path, wav, samplerate=16000)
-
-    logging.info(f"Audio saved at: {save_path}")
-
-    return save_path
-
-
 def build_spark_ui():
     model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
     device = 0
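---

A note on the mechanism this patch relies on: Gradio treats a generator event handler as a streaming handler and pushes each yielded value set to the output components in order, which is why `generate_and_process_with_rvc_streaming` can update `vc_output1` and `vc_output2` repeatedly from a single button click. Below is a minimal, self-contained sketch of that pattern; `fake_pipeline` and its `time.sleep` stand-in are hypothetical illustration code, not part of this repository.

```python
import time
import gradio as gr

def fake_pipeline(text):
    """Stand-in for the TTS+RVC pipeline: yields (status_text, audio_path) tuples."""
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    status = []
    for i, sentence in enumerate(sentences, start=1):
        status.append(f"Processing sentence {i}/{len(sentences)}: {sentence[:30]}...")
        yield "\n".join(status), None   # update the status box, leave audio unchanged
        time.sleep(1)                   # simulate TTS + RVC work for one sentence
        status[-1] += " done"
        # The real handler yields the path of the growing concatenated WAV here
        yield "\n".join(status), None

with gr.Blocks() as demo:
    text = gr.Textbox(label="Text to synthesize")
    btn = gr.Button("Generate")
    info = gr.Textbox(label="Output information", lines=10)
    audio = gr.Audio(label="Audio (updates as sentences are processed)")
    # A generator handler streams every yield to the outputs, one update at a time
    btn.click(fake_pipeline, inputs=[text], outputs=[info, audio])

demo.launch()
```

One consequence of this design worth noting: each update re-concatenates every fragment and hands the player a new, longer file, so the audio component reloads the growing WAV from the start on each yield rather than appending to a live stream. True chunked playback would need raw audio chunks and, in recent Gradio versions, something like `gr.Audio(streaming=True)`, which is presumably why the commit is titled "preparation".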