Add CUDA streams

VSlobolinskyi 2025-03-22 00:41:08 +02:00
parent 1db54518ff
commit 7252387621
2 changed files with 245 additions and 151 deletions

View File

@@ -1,7 +1,7 @@
 import gradio as gr
 
 # Import modules from your packages
-from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
+from merged_ui.utils import generate_and_process_with_rvc_parallel, modified_get_vc
 from rvc_ui.initialization import config
 from rvc_ui.main import names, index_paths
 
@@ -135,7 +135,7 @@ def build_merged_ui():
         # Connect generate function to button with streaming enabled
         generate_with_rvc_button.click(
-            generate_and_process_with_rvc,
+            generate_and_process_with_rvc_parallel,
             inputs=[
                 tts_text_input,
                 prompt_text_input,

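Note: the handler swapped in here, generate_and_process_with_rvc_parallel, is a Python generator, and Gradio streams each yielded value to the outputs of a click event. The snippet below is a minimal, standalone sketch of that streaming behaviour (illustrative only, not code from this repo; stream_fragments is a made-up handler) to show why a generator can be wired to generate_with_rvc_button.click unchanged.

import gradio as gr
import time

def stream_fragments(text):
    # Generator handler: every yield pushes an updated value to the bound output.
    messages = []
    for i, chunk in enumerate(text.split(".")):
        time.sleep(0.5)  # stand-in for the TTS + RVC work done per sentence
        messages.append(f"Processed fragment {i + 1}: {chunk.strip()}")
        yield "\n".join(messages)

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Text")
    out = gr.Textbox(label="Progress")
    btn = gr.Button("Generate")
    btn.click(stream_fragments, inputs=inp, outputs=out)

demo.launch()
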
View File

@@ -9,7 +9,9 @@ import numpy as np
 import soundfile as sf
 from pydub import AudioSegment
 import torch
-from pydub import AudioSegment
+import threading
+from queue import Queue, Empty
+from contextlib import nullcontext
 
 # Import modules from your packages
 from spark.cli.SparkTTS import SparkTTS
 
@@ -48,10 +50,11 @@ def run_tts(
     gender=None,
     pitch=None,
     speed=None,
-    save_dir="TEMP/spark",  # Updated default save directory
-    save_filename=None,  # New parameter to specify filename
+    save_dir="TEMP/spark",
+    save_filename=None,
+    cuda_stream=None,
 ):
-    """Perform TTS inference and save the generated audio."""
+    """Perform TTS inference using a specific CUDA stream."""
     model = initialize_model(model_dir, device=device)
     logging.info(f"Saving audio to: {save_dir}")
 
@@ -61,30 +64,75 @@ def run_tts(
     # Ensure the save directory exists
     os.makedirs(save_dir, exist_ok=True)
 
-    # Determine the save path based on save_filename if provided; otherwise, use a timestamp
+    # Determine the save path
     if save_filename:
         save_path = os.path.join(save_dir, save_filename)
     else:
-        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
         save_path = os.path.join(save_dir, f"{timestamp}.wav")
 
-    logging.info("Starting inference...")
+    logging.info("Starting TTS inference...")
 
-    # Perform inference and save the output audio
-    with torch.no_grad():
-        wav = model.inference(
-            text,
-            prompt_speech,
-            prompt_text,
-            gender,
-            pitch,
-            speed,
-        )
-        sf.write(save_path, wav, samplerate=16000)
-        logging.info(f"Audio saved at: {save_path}")
+    # Perform inference using the specified CUDA stream
+    with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext():
+        with torch.no_grad():
+            wav = model.inference(
+                text,
+                prompt_speech,
+                prompt_text,
+                gender,
+                pitch,
+                speed,
+            )
+
+    # Save the audio (CPU operation)
+    sf.write(save_path, wav, samplerate=16000)
+    logging.info(f"TTS audio saved at: {save_path}")
 
     return save_path
 
+
+def process_with_rvc(
+    spk_item, input_path, vc_transform, f0method,
+    file_index1, file_index2, index_rate, filter_radius,
+    resample_sr, rms_mix_rate, protect,
+    output_path, cuda_stream=None
+):
+    """Process audio through RVC with a specific CUDA stream."""
+    logging.info(f"Starting RVC inference for {input_path}...")
+
+    # Set the CUDA stream if provided
+    with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext():
+        # Call RVC processing function
+        f0_file = None  # We're not using an F0 curve file
+        output_info, output_audio = vc.vc_single(
+            spk_item, input_path, vc_transform, f0_file, f0method,
+            file_index1, file_index2, index_rate, filter_radius,
+            resample_sr, rms_mix_rate, protect
+        )
+
+    # Save RVC output (CPU operation)
+    rvc_saved = False
+    try:
+        if isinstance(output_audio, str) and os.path.exists(output_audio):
+            # Case 1: output_audio is a file path string
+            shutil.copy2(output_audio, output_path)
+            rvc_saved = True
+        elif isinstance(output_audio, tuple) and len(output_audio) >= 2:
+            # Case 2: output_audio might be (sample_rate, audio_data)
+            sf.write(output_path, output_audio[1], output_audio[0])
+            rvc_saved = True
+        elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name):
+            # Case 3: output_audio might be a file-like object
+            shutil.copy2(output_audio.name, output_path)
+            rvc_saved = True
+    except Exception as e:
+        output_info += f"\nError saving RVC output: {str(e)}"
+
+    logging.info(f"RVC inference completed for {input_path}")
+    return rvc_saved, output_info
+
+
 def split_into_sentences(text):
     """
     Split text into sentences using regular expressions.
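
Note: both run_tts and process_with_rvc only enqueue their GPU work on whatever stream they are handed; the streams themselves are created by the caller further down. A minimal, self-contained sketch of the underlying PyTorch pattern (illustrative only, not repo code), assuming a CUDA device is present:

import torch

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(s1):
    out1 = a @ a  # queued on s1
with torch.cuda.stream(s2):
    out2 = b @ b  # queued on s2; may overlap with the work on s1

torch.cuda.synchronize()  # wait for both streams before reading out1 / out2
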
@@ -101,74 +149,188 @@ def split_into_sentences(text):
     sentences = [s.strip() for s in sentences if s.strip()]
     return sentences
 
-def process_single_sentence(
-    sentence_index, sentence, prompt_speech, prompt_text_clean,
-    spk_item, vc_transform, f0method,
-    file_index1, file_index2, index_rate, filter_radius,
-    resample_sr, rms_mix_rate, protect,
-    base_fragment_num
-):
-    """
-    Process a single sentence through the TTS and RVC pipeline.
-    """
-    fragment_num = base_fragment_num + sentence_index
-
-    # Generate TTS audio for this sentence, saving directly to the correct location
-    tts_path = run_tts(
-        sentence,
-        prompt_text=prompt_text_clean,
-        prompt_speech=prompt_speech,
-        save_dir="./TEMP/spark",
-        save_filename=f"fragment_{fragment_num}.wav"
-    )
-
-    # Make sure we have a TTS file to process
-    if not tts_path or not os.path.exists(tts_path):
-        return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
-
-    # Use the tts_path as the Spark output (no need to copy)
-    spark_output_path = tts_path
-
-    # Call RVC processing function
-    f0_file = None  # We're not using an F0 curve file in this pipeline
-    output_info, output_audio = vc.vc_single(
-        spk_item, tts_path, vc_transform, f0_file, f0method,
-        file_index1, file_index2, index_rate, filter_radius,
-        resample_sr, rms_mix_rate, protect
-    )
-
-    # Save RVC output to TEMP/rvc directory
-    rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
-    rvc_saved = False
-    try:
-        if isinstance(output_audio, str) and os.path.exists(output_audio):
-            # Case 1: output_audio is a file path string
-            shutil.copy2(output_audio, rvc_output_path)
-            rvc_saved = True
-        elif isinstance(output_audio, tuple) and len(output_audio) >= 2:
-            # Case 2: output_audio might be (sample_rate, audio_data)
-            try:
-                sf.write(rvc_output_path, output_audio[1], output_audio[0])
-                rvc_saved = True
-            except Exception as inner_e:
-                output_info += f"\nFailed to save RVC tuple format: {str(inner_e)}"
-        elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name):
-            # Case 3: output_audio might be a file-like object
-            shutil.copy2(output_audio.name, rvc_output_path)
-            rvc_saved = True
-    except Exception as e:
-        output_info += f"\nError saving RVC output: {str(e)}"
-
-    # Prepare info message
-    info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
-    info_message += f" - Spark output: {spark_output_path}\n"
-    if rvc_saved:
-        info_message += f" - RVC output: {rvc_output_path}"
-    else:
-        info_message += f" - Could not save RVC output to {rvc_output_path}"
-
-    return spark_output_path, rvc_output_path, rvc_saved, info_message
+def generate_and_process_with_rvc_parallel(
+    text, prompt_text, prompt_wav_upload, prompt_wav_record,
+    spk_item, vc_transform, f0method,
+    file_index1, file_index2, index_rate, filter_radius,
+    resample_sr, rms_mix_rate, protect
+):
+    """
+    Handle combined TTS and RVC processing using CUDA streams for parallel operation.
+    Uses a producer-consumer pattern where TTS produces audio files for RVC to consume.
+    """
+    # Ensure TEMP directories exist
+    os.makedirs("./TEMP/spark", exist_ok=True)
+    os.makedirs("./TEMP/rvc", exist_ok=True)
+
+    # Split text into sentences
+    sentences = split_into_sentences(text)
+    if not sentences:
+        yield "No valid text to process.", None
+        return
+
+    # Get next base fragment number
+    base_fragment_num = 1
+    while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or
+              os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav")
+              for i in range(len(sentences))):
+        base_fragment_num += 1
+
+    # Process reference speech
+    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
+
+    # Create CUDA streams if CUDA is available
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        spark_stream = torch.cuda.Stream()
+        rvc_stream = torch.cuda.Stream()
+        logging.info("Using separate CUDA streams for Spark TTS and RVC")
+    else:
+        spark_stream = None
+        rvc_stream = None
+        logging.info("CUDA not available, parallel processing will be limited")
+
+    # Create queues for communication between TTS and RVC
+    tts_to_rvc_queue = Queue()
+    rvc_results_queue = Queue()
+
+    # Flag to signal completion
+    processing_complete = threading.Event()
+
+    info_messages = [f"Processing {len(sentences)} sentences using parallel CUDA streams..."]
+
+    # Yield initial message with no audio yet
+    yield "\n".join(info_messages), None
+
+    # TTS worker function
+    def tts_worker():
+        for i, sentence in enumerate(sentences):
+            fragment_num = base_fragment_num + i
+            tts_filename = f"fragment_{fragment_num}.wav"
+            try:
+                # Use the TTS CUDA stream
+                path = run_tts(
+                    sentence,
+                    prompt_text=prompt_text_clean,
+                    prompt_speech=prompt_speech,
+                    save_dir="./TEMP/spark",
+                    save_filename=tts_filename,
+                    cuda_stream=spark_stream
+                )
+                # Put the path and sentence info to the queue for RVC processing
+                tts_to_rvc_queue.put((i, fragment_num, sentence, path))
+            except Exception as e:
+                logging.error(f"TTS processing error for sentence {i}: {str(e)}")
+                tts_to_rvc_queue.put((i, fragment_num, sentence, None, str(e)))
+
+        # Signal TTS completion
+        tts_to_rvc_queue.put(None)
+
+    # RVC worker function
+    def rvc_worker():
+        while True:
+            # Get item from the queue
+            item = tts_to_rvc_queue.get()
+
+            # Check for the sentinel value (None) that signals completion
+            if item is None:
+                break
+
+            # Unpack the item
+            if len(item) == 5:  # Error case
+                i, fragment_num, sentence, _, error = item
+                rvc_results_queue.put((i, None, None, False, f"TTS error for sentence {i+1}: {error}"))
+                continue
+
+            i, fragment_num, sentence, tts_path = item
+
+            if not tts_path or not os.path.exists(tts_path):
+                rvc_results_queue.put((i, None, None, False, f"No TTS output for sentence {i+1}"))
+                continue
+
+            # Prepare RVC path
+            rvc_path = os.path.join("./TEMP/rvc", f"fragment_{fragment_num}.wav")
+
+            try:
+                # Process with RVC
+                rvc_success, rvc_info = process_with_rvc(
+                    spk_item, tts_path, vc_transform, f0method,
+                    file_index1, file_index2, index_rate, filter_radius,
+                    resample_sr, rms_mix_rate, protect,
+                    rvc_path, cuda_stream=rvc_stream
+                )
+
+                # Create info message
+                info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
+                info_message += f" - Spark output: {tts_path}\n"
+                if rvc_success:
+                    info_message += f" - RVC output: {rvc_path}"
+                else:
+                    info_message += f" - Could not save RVC output to {rvc_path}"
+
+                # Put the results to the queue
+                rvc_results_queue.put((i, tts_path, rvc_path if rvc_success else None, rvc_success, info_message))
+            except Exception as e:
+                logging.error(f"RVC processing error for sentence {i}: {str(e)}")
+                info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
+                info_message += f" - Spark output: {tts_path}\n"
+                info_message += f" - RVC processing error: {str(e)}"
+                rvc_results_queue.put((i, tts_path, None, False, info_message))
+
+        # Signal RVC completion
+        processing_complete.set()
+
+    # Start the worker threads
+    tts_thread = threading.Thread(target=tts_worker)
+    rvc_thread = threading.Thread(target=rvc_worker)
+    tts_thread.start()
+    rvc_thread.start()
+
+    # Process results as they become available
+    completed_sentences = {}
+    next_to_yield = 0
+
+    while not processing_complete.is_set() or not rvc_results_queue.empty():
+        try:
+            # Try to get an item from the results queue with a timeout
+            try:
+                i, tts_path, rvc_path, success, info = rvc_results_queue.get(timeout=0.1)
+                completed_sentences[i] = (tts_path, rvc_path, success, info)
+            except Empty:
+                # No results available yet, continue the loop
+                continue
+
+            # Check if we can yield the next sentence
+            while next_to_yield in completed_sentences:
+                _, rvc_path, _, info = completed_sentences[next_to_yield]
+                info_messages.append(info)
+
+                # Yield the current state
+                yield "\n".join(info_messages), rvc_path
+
+                # Move to the next sentence
+                next_to_yield += 1
+        except Exception as e:
+            logging.error(f"Error in main processing loop: {str(e)}")
+            info_messages.append(f"Error in processing: {str(e)}")
+            yield "\n".join(info_messages), None
+            break
+
+    # Join the threads
+    tts_thread.join()
+    rvc_thread.join()
+
+    # Yield any remaining sentences in order
+    remaining_indices = sorted([i for i in completed_sentences if i >= next_to_yield])
+    for i in remaining_indices:
+        _, rvc_path, _, info = completed_sentences[i]
+        info_messages.append(info)
+        yield "\n".join(info_messages), rvc_path
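
Note: tts_worker and rvc_worker follow the classic producer-consumer shape: a Queue carries work items, a None sentinel marks the end of input, and an Event tells the main loop when the consumer is done. A stripped-down sketch of that pattern (illustrative names, not repo code):

import threading
from queue import Queue, Empty

def producer(work_q, items):
    for item in items:
        work_q.put(item)      # hand each item to the consumer
    work_q.put(None)          # sentinel: no more work

def consumer(work_q, result_q, done):
    while True:
        item = work_q.get()
        if item is None:      # sentinel received, stop consuming
            break
        result_q.put(item.upper())
    done.set()                # signal the main loop that processing finished

work_q, result_q, done = Queue(), Queue(), threading.Event()
threading.Thread(target=producer, args=(work_q, ["a", "b", "c"])).start()
threading.Thread(target=consumer, args=(work_q, result_q, done)).start()

while not done.is_set() or not result_q.empty():
    try:
        print(result_q.get(timeout=0.1))
    except Empty:
        continue
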
 
 def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
     """
@@ -213,74 +375,6 @@ def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
         print(f"Fallback concatenation failed: {str(e2)}")
         return False
 
-def generate_and_process_with_rvc(
-    text, prompt_text, prompt_wav_upload, prompt_wav_record,
-    spk_item, vc_transform, f0method,
-    file_index1, file_index2, index_rate, filter_radius,
-    resample_sr, rms_mix_rate, protect
-):
-    """
-    Handle combined TTS and RVC processing for multiple sentences and yield outputs as they are processed.
-    The output is just the latest processed audio.
-    Before yielding a new audio fragment, the function waits for the previous one to finish playing,
-    based on its duration.
-    """
-    # Ensure TEMP directories exist
-    os.makedirs("./TEMP/spark", exist_ok=True)
-    os.makedirs("./TEMP/rvc", exist_ok=True)
-
-    # Split text into sentences
-    sentences = split_into_sentences(text)
-    if not sentences:
-        yield "No valid text to process.", None
-        return
-
-    # Get next base fragment number
-    base_fragment_num = 1
-    while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or
-              os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav")
-              for i in range(len(sentences))):
-        base_fragment_num += 1
-
-    # Process reference speech
-    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
-    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
-
-    info_messages = [f"Processing {len(sentences)} sentences..."]
-    # Yield initial message with no audio yet
-    yield "\n".join(info_messages), None
-
-    # Set up a timer to simulate playback duration
-    next_available_time = time.time()
-
-    for i, sentence in enumerate(sentences):
-        spark_path, rvc_path, success, info = process_single_sentence(
-            i, sentence, prompt_speech, prompt_text_clean,
-            spk_item, vc_transform, f0method,
-            file_index1, file_index2, index_rate, filter_radius,
-            resample_sr, rms_mix_rate, protect,
-            base_fragment_num
-        )
-        info_messages.append(info)
-
-        # Only update output if processing was successful and we have an audio file
-        if success and rvc_path:
-            try:
-                audio_seg = AudioSegment.from_file(rvc_path)
-                duration = audio_seg.duration_seconds
-            except Exception as e:
-                duration = 0
-            current_time = time.time()
-            if current_time < next_available_time:
-                time.sleep(next_available_time - current_time)
-            yield "\n".join(info_messages), rvc_path
-            next_available_time = time.time() + duration
-
-    yield "\n".join(info_messages), rvc_path
 
 def modified_get_vc(sid0_value, protect0_value, file_index2_component):
     """