From 725238762165b4b142f6f3a700ccad7e260f0b54 Mon Sep 17 00:00:00 2001 From: VSlobolinskyi Date: Sat, 22 Mar 2025 00:41:08 +0200 Subject: [PATCH] Add cuda streams --- modules/merged_ui/main.py | 4 +- modules/merged_ui/utils.py | 392 +++++++++++++++++++++++-------------- 2 files changed, 245 insertions(+), 151 deletions(-) diff --git a/modules/merged_ui/main.py b/modules/merged_ui/main.py index 7901add..7b1f1c5 100644 --- a/modules/merged_ui/main.py +++ b/modules/merged_ui/main.py @@ -1,7 +1,7 @@ import gradio as gr # Import modules from your packages -from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc +from merged_ui.utils import generate_and_process_with_rvc_parallel, modified_get_vc from rvc_ui.initialization import config from rvc_ui.main import names, index_paths @@ -135,7 +135,7 @@ def build_merged_ui(): # Connect generate function to button with streaming enabled generate_with_rvc_button.click( - generate_and_process_with_rvc, + generate_and_process_with_rvc_parallel, inputs=[ tts_text_input, prompt_text_input, diff --git a/modules/merged_ui/utils.py b/modules/merged_ui/utils.py index 9c696f2..2771e6e 100644 --- a/modules/merged_ui/utils.py +++ b/modules/merged_ui/utils.py @@ -9,7 +9,9 @@ import numpy as np import soundfile as sf from pydub import AudioSegment import torch -from pydub import AudioSegment +import threading +from queue import Queue, Empty +from contextlib import nullcontext # Import modules from your packages from spark.cli.SparkTTS import SparkTTS @@ -48,10 +50,11 @@ def run_tts( gender=None, pitch=None, speed=None, - save_dir="TEMP/spark", # Updated default save directory - save_filename=None, # New parameter to specify filename + save_dir="TEMP/spark", + save_filename=None, + cuda_stream=None, ): - """Perform TTS inference and save the generated audio.""" + """Perform TTS inference using a specific CUDA stream.""" model = initialize_model(model_dir, device=device) logging.info(f"Saving audio to: {save_dir}") @@ -61,30 +64,75 @@ def run_tts( # Ensure the save directory exists os.makedirs(save_dir, exist_ok=True) - # Determine the save path based on save_filename if provided; otherwise, use a timestamp + # Determine the save path if save_filename: save_path = os.path.join(save_dir, save_filename) else: - timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") save_path = os.path.join(save_dir, f"{timestamp}.wav") - logging.info("Starting inference...") + logging.info("Starting TTS inference...") - # Perform inference and save the output audio - with torch.no_grad(): - wav = model.inference( - text, - prompt_speech, - prompt_text, - gender, - pitch, - speed, - ) - sf.write(save_path, wav, samplerate=16000) - - logging.info(f"Audio saved at: {save_path}") + # Perform inference using the specified CUDA stream + with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext(): + with torch.no_grad(): + wav = model.inference( + text, + prompt_speech, + prompt_text, + gender, + pitch, + speed, + ) + + # Save the audio (CPU operation) + sf.write(save_path, wav, samplerate=16000) + + logging.info(f"TTS audio saved at: {save_path}") return save_path + +def process_with_rvc( + spk_item, input_path, vc_transform, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect, + output_path, cuda_stream=None +): + """Process audio through RVC with a specific CUDA stream.""" + logging.info(f"Starting RVC inference for {input_path}...") + + # Set the CUDA stream if provided + with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext(): + # Call RVC processing function + f0_file = None # We're not using an F0 curve file + output_info, output_audio = vc.vc_single( + spk_item, input_path, vc_transform, f0_file, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect + ) + + # Save RVC output (CPU operation) + rvc_saved = False + try: + if isinstance(output_audio, str) and os.path.exists(output_audio): + # Case 1: output_audio is a file path string + shutil.copy2(output_audio, output_path) + rvc_saved = True + elif isinstance(output_audio, tuple) and len(output_audio) >= 2: + # Case 2: output_audio might be (sample_rate, audio_data) + sf.write(output_path, output_audio[1], output_audio[0]) + rvc_saved = True + elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name): + # Case 3: output_audio might be a file-like object + shutil.copy2(output_audio.name, output_path) + rvc_saved = True + except Exception as e: + output_info += f"\nError saving RVC output: {str(e)}" + + logging.info(f"RVC inference completed for {input_path}") + return rvc_saved, output_info + + def split_into_sentences(text): """ Split text into sentences using regular expressions. @@ -101,74 +149,188 @@ def split_into_sentences(text): sentences = [s.strip() for s in sentences if s.strip()] return sentences -def process_single_sentence( - sentence_index, sentence, prompt_speech, prompt_text_clean, + +def generate_and_process_with_rvc_parallel( + text, prompt_text, prompt_wav_upload, prompt_wav_record, spk_item, vc_transform, f0method, file_index1, file_index2, index_rate, filter_radius, - resample_sr, rms_mix_rate, protect, - base_fragment_num + resample_sr, rms_mix_rate, protect ): """ - Process a single sentence through the TTS and RVC pipeline. + Handle combined TTS and RVC processing using CUDA streams for parallel operation. + Uses a producer-consumer pattern where TTS produces audio files for RVC to consume. """ - fragment_num = base_fragment_num + sentence_index - - # Generate TTS audio for this sentence, saving directly to the correct location - tts_path = run_tts( - sentence, - prompt_text=prompt_text_clean, - prompt_speech=prompt_speech, - save_dir="./TEMP/spark", - save_filename=f"fragment_{fragment_num}.wav" - ) - - # Make sure we have a TTS file to process - if not tts_path or not os.path.exists(tts_path): - return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}" - - # Use the tts_path as the Spark output (no need to copy) - spark_output_path = tts_path - - # Call RVC processing function - f0_file = None # We're not using an F0 curve file in this pipeline - output_info, output_audio = vc.vc_single( - spk_item, tts_path, vc_transform, f0_file, f0method, - file_index1, file_index2, index_rate, filter_radius, - resample_sr, rms_mix_rate, protect - ) - - # Save RVC output to TEMP/rvc directory - rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav" - rvc_saved = False - - try: - if isinstance(output_audio, str) and os.path.exists(output_audio): - # Case 1: output_audio is a file path string - shutil.copy2(output_audio, rvc_output_path) - rvc_saved = True - elif isinstance(output_audio, tuple) and len(output_audio) >= 2: - # Case 2: output_audio might be (sample_rate, audio_data) - try: - sf.write(rvc_output_path, output_audio[1], output_audio[0]) - rvc_saved = True - except Exception as inner_e: - output_info += f"\nFailed to save RVC tuple format: {str(inner_e)}" - elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name): - # Case 3: output_audio might be a file-like object - shutil.copy2(output_audio.name, rvc_output_path) - rvc_saved = True - except Exception as e: - output_info += f"\nError saving RVC output: {str(e)}" - - # Prepare info message - info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n" - info_message += f" - Spark output: {spark_output_path}\n" - if rvc_saved: - info_message += f" - RVC output: {rvc_output_path}" + # Ensure TEMP directories exist + os.makedirs("./TEMP/spark", exist_ok=True) + os.makedirs("./TEMP/rvc", exist_ok=True) + + # Split text into sentences + sentences = split_into_sentences(text) + if not sentences: + yield "No valid text to process.", None + return + + # Get next base fragment number + base_fragment_num = 1 + while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or + os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav") + for i in range(len(sentences))): + base_fragment_num += 1 + + # Process reference speech + prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record + prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text + + # Create CUDA streams if CUDA is available + use_cuda = torch.cuda.is_available() + if use_cuda: + spark_stream = torch.cuda.Stream() + rvc_stream = torch.cuda.Stream() + logging.info("Using separate CUDA streams for Spark TTS and RVC") else: - info_message += f" - Could not save RVC output to {rvc_output_path}" + spark_stream = None + rvc_stream = None + logging.info("CUDA not available, parallel processing will be limited") + + # Create queues for communication between TTS and RVC + tts_to_rvc_queue = Queue() + rvc_results_queue = Queue() + + # Flag to signal completion + processing_complete = threading.Event() + + info_messages = [f"Processing {len(sentences)} sentences using parallel CUDA streams..."] + + # Yield initial message with no audio yet + yield "\n".join(info_messages), None + + # TTS worker function + def tts_worker(): + for i, sentence in enumerate(sentences): + fragment_num = base_fragment_num + i + tts_filename = f"fragment_{fragment_num}.wav" + + try: + # Use the TTS CUDA stream + path = run_tts( + sentence, + prompt_text=prompt_text_clean, + prompt_speech=prompt_speech, + save_dir="./TEMP/spark", + save_filename=tts_filename, + cuda_stream=spark_stream + ) + # Put the path and sentence info to the queue for RVC processing + tts_to_rvc_queue.put((i, fragment_num, sentence, path)) + except Exception as e: + logging.error(f"TTS processing error for sentence {i}: {str(e)}") + tts_to_rvc_queue.put((i, fragment_num, sentence, None, str(e))) + + # Signal TTS completion + tts_to_rvc_queue.put(None) + + # RVC worker function + def rvc_worker(): + while True: + # Get item from the queue + item = tts_to_rvc_queue.get() + + # Check for the sentinel value (None) that signals completion + if item is None: + break + + # Unpack the item + if len(item) == 5: # Error case + i, fragment_num, sentence, _, error = item + rvc_results_queue.put((i, None, None, False, f"TTS error for sentence {i+1}: {error}")) + continue + + i, fragment_num, sentence, tts_path = item + + if not tts_path or not os.path.exists(tts_path): + rvc_results_queue.put((i, None, None, False, f"No TTS output for sentence {i+1}")) + continue + + # Prepare RVC path + rvc_path = os.path.join("./TEMP/rvc", f"fragment_{fragment_num}.wav") + + try: + # Process with RVC + rvc_success, rvc_info = process_with_rvc( + spk_item, tts_path, vc_transform, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect, + rvc_path, cuda_stream=rvc_stream + ) + + # Create info message + info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n" + info_message += f" - Spark output: {tts_path}\n" + if rvc_success: + info_message += f" - RVC output: {rvc_path}" + else: + info_message += f" - Could not save RVC output to {rvc_path}" + + # Put the results to the queue + rvc_results_queue.put((i, tts_path, rvc_path if rvc_success else None, rvc_success, info_message)) + except Exception as e: + logging.error(f"RVC processing error for sentence {i}: {str(e)}") + info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n" + info_message += f" - Spark output: {tts_path}\n" + info_message += f" - RVC processing error: {str(e)}" + rvc_results_queue.put((i, tts_path, None, False, info_message)) + + # Signal RVC completion + processing_complete.set() + + # Start the worker threads + tts_thread = threading.Thread(target=tts_worker) + rvc_thread = threading.Thread(target=rvc_worker) + + tts_thread.start() + rvc_thread.start() + + # Process results as they become available + completed_sentences = {} + next_to_yield = 0 + + while not processing_complete.is_set() or not rvc_results_queue.empty(): + try: + # Try to get an item from the results queue with a timeout + try: + i, tts_path, rvc_path, success, info = rvc_results_queue.get(timeout=0.1) + completed_sentences[i] = (tts_path, rvc_path, success, info) + except Empty: + # No results available yet, continue the loop + continue + + # Check if we can yield the next sentence + while next_to_yield in completed_sentences: + _, rvc_path, _, info = completed_sentences[next_to_yield] + info_messages.append(info) + + # Yield the current state + yield "\n".join(info_messages), rvc_path + + # Move to the next sentence + next_to_yield += 1 + except Exception as e: + logging.error(f"Error in main processing loop: {str(e)}") + info_messages.append(f"Error in processing: {str(e)}") + yield "\n".join(info_messages), None + break + + # Join the threads + tts_thread.join() + rvc_thread.join() + + # Yield any remaining sentences in order + remaining_indices = sorted([i for i in completed_sentences if i >= next_to_yield]) + for i in remaining_indices: + _, rvc_path, _, info = completed_sentences[i] + info_messages.append(info) + yield "\n".join(info_messages), rvc_path - return spark_output_path, rvc_output_path, rvc_saved, info_message def concatenate_audio_files(file_paths, output_path, sample_rate=44100): """ @@ -213,74 +375,6 @@ def concatenate_audio_files(file_paths, output_path, sample_rate=44100): print(f"Fallback concatenation failed: {str(e2)}") return False -def generate_and_process_with_rvc( - text, prompt_text, prompt_wav_upload, prompt_wav_record, - spk_item, vc_transform, f0method, - file_index1, file_index2, index_rate, filter_radius, - resample_sr, rms_mix_rate, protect -): - """ - Handle combined TTS and RVC processing for multiple sentences and yield outputs as they are processed. - The output is just the latest processed audio. - Before yielding a new audio fragment, the function waits for the previous one to finish playing, - based on its duration. - """ - # Ensure TEMP directories exist - os.makedirs("./TEMP/spark", exist_ok=True) - os.makedirs("./TEMP/rvc", exist_ok=True) - - # Split text into sentences - sentences = split_into_sentences(text) - if not sentences: - yield "No valid text to process.", None - return - - # Get next base fragment number - base_fragment_num = 1 - while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or - os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav") - for i in range(len(sentences))): - base_fragment_num += 1 - - # Process reference speech - prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record - prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text - - info_messages = [f"Processing {len(sentences)} sentences..."] - - # Yield initial message with no audio yet - yield "\n".join(info_messages), None - - # Set up a timer to simulate playback duration - next_available_time = time.time() - - for i, sentence in enumerate(sentences): - spark_path, rvc_path, success, info = process_single_sentence( - i, sentence, prompt_speech, prompt_text_clean, - spk_item, vc_transform, f0method, - file_index1, file_index2, index_rate, filter_radius, - resample_sr, rms_mix_rate, protect, - base_fragment_num - ) - - info_messages.append(info) - # Only update output if processing was successful and we have an audio file - if success and rvc_path: - try: - audio_seg = AudioSegment.from_file(rvc_path) - duration = audio_seg.duration_seconds - except Exception as e: - duration = 0 - - current_time = time.time() - if current_time < next_available_time: - time.sleep(next_available_time - current_time) - - yield "\n".join(info_messages), rvc_path - - next_available_time = time.time() + duration - - yield "\n".join(info_messages), rvc_path def modified_get_vc(sid0_value, protect0_value, file_index2_component): """