Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git (synced 2025-04-05 04:08:58 +08:00)
Add cuda streams

Commit: 7252387621
Parent: 1db54518ff
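In short: rather than running Spark TTS and RVC strictly one after the other, the new code gives each model its own CUDA stream and hands finished TTS fragments to an RVC worker through a queue, so the two GPU workloads can overlap. The first two hunks rewire the Gradio button to the new generator; the remaining hunks rework the functions that merged_ui.utils exports. As orientation, a minimal, self-contained sketch of the same pattern follows; synthesize and convert are placeholder callables standing in for the real Spark TTS and RVC calls, not the repo's API.

import threading
from contextlib import nullcontext
from queue import Queue

import torch


def run_on_stream(fn, stream):
    # Enqueue fn's GPU work on `stream`; fall back to a no-op context on CPU-only machines.
    ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
    with ctx, torch.no_grad():
        return fn()


def pipeline(items, synthesize, convert):
    # Producer-consumer: a background thread runs TTS on one stream while the
    # caller's thread runs voice conversion on another.
    cuda = torch.cuda.is_available()
    tts_stream = torch.cuda.Stream() if cuda else None
    rvc_stream = torch.cuda.Stream() if cuda else None
    handoff = Queue()

    def producer():
        for item in items:
            handoff.put(run_on_stream(lambda: synthesize(item), tts_stream))
        handoff.put(None)  # sentinel: no more work

    worker = threading.Thread(target=producer)
    worker.start()
    results = []
    while (wav := handoff.get()) is not None:
        results.append(run_on_stream(lambda: convert(wav), rvc_stream))
    worker.join()
    return results

The real implementation in the diff below adds per-sentence error handling, file-based handoff under ./TEMP, and in-order streaming of results to the UI on top of this skeleton.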
@@ -1,7 +1,7 @@
 import gradio as gr

 # Import modules from your packages
-from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
+from merged_ui.utils import generate_and_process_with_rvc_parallel, modified_get_vc
 from rvc_ui.initialization import config
 from rvc_ui.main import names, index_paths

@@ -135,7 +135,7 @@ def build_merged_ui():

         # Connect generate function to button with streaming enabled
         generate_with_rvc_button.click(
-            generate_and_process_with_rvc,
+            generate_and_process_with_rvc_parallel,
             inputs=[
                 tts_text_input,
                 prompt_text_input,
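Note that generate_and_process_with_rvc_parallel is a generator, and Gradio treats a generator event handler as a streaming one: every yielded (info, audio) pair updates the outputs in place, which is what lets fragments play as they finish. A tiny, generic illustration of that behaviour (not the project's actual components):

import gradio as gr


def stream_updates(text):
    # Generator handler: Gradio pushes each yielded value to the output as it arrives.
    words = text.split()
    for i in range(1, len(words) + 1):
        yield f"processed {i} of {len(words)} words"


with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Progress")
    gr.Button("Run").click(stream_updates, inputs=inp, outputs=out)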
@@ -9,7 +9,9 @@ import numpy as np
 import soundfile as sf
 from pydub import AudioSegment
 import torch
-from pydub import AudioSegment
+import threading
+from queue import Queue, Empty
+from contextlib import nullcontext

 # Import modules from your packages
 from spark.cli.SparkTTS import SparkTTS
@@ -48,10 +50,11 @@ def run_tts(
     gender=None,
     pitch=None,
     speed=None,
-    save_dir="TEMP/spark", # Updated default save directory
-    save_filename=None, # New parameter to specify filename
+    save_dir="TEMP/spark",
+    save_filename=None,
+    cuda_stream=None,
 ):
-    """Perform TTS inference and save the generated audio."""
+    """Perform TTS inference using a specific CUDA stream."""
     model = initialize_model(model_dir, device=device)
     logging.info(f"Saving audio to: {save_dir}")

@@ -61,30 +64,75 @@ def run_tts(
     # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

-    # Determine the save path based on save_filename if provided; otherwise, use a timestamp
+    # Determine the save path
     if save_filename:
         save_path = os.path.join(save_dir, save_filename)
     else:
-        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
         save_path = os.path.join(save_dir, f"{timestamp}.wav")

-    logging.info("Starting inference...")
+    logging.info("Starting TTS inference...")

-    # Perform inference and save the output audio
-    with torch.no_grad():
-        wav = model.inference(
-            text,
-            prompt_speech,
-            prompt_text,
-            gender,
-            pitch,
-            speed,
-        )
-        sf.write(save_path, wav, samplerate=16000)
+    # Perform inference using the specified CUDA stream
+    with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext():
+        with torch.no_grad():
+            wav = model.inference(
+                text,
+                prompt_speech,
+                prompt_text,
+                gender,
+                pitch,
+                speed,
+            )

-    logging.info(f"Audio saved at: {save_path}")
+    # Save the audio (CPU operation)
+    sf.write(save_path, wav, samplerate=16000)
+
+    logging.info(f"TTS audio saved at: {save_path}")
     return save_path

+
+def process_with_rvc(
+    spk_item, input_path, vc_transform, f0method,
+    file_index1, file_index2, index_rate, filter_radius,
+    resample_sr, rms_mix_rate, protect,
+    output_path, cuda_stream=None
+):
+    """Process audio through RVC with a specific CUDA stream."""
+    logging.info(f"Starting RVC inference for {input_path}...")
+
+    # Set the CUDA stream if provided
+    with torch.cuda.stream(cuda_stream) if cuda_stream and torch.cuda.is_available() else nullcontext():
+        # Call RVC processing function
+        f0_file = None # We're not using an F0 curve file
+        output_info, output_audio = vc.vc_single(
+            spk_item, input_path, vc_transform, f0_file, f0method,
+            file_index1, file_index2, index_rate, filter_radius,
+            resample_sr, rms_mix_rate, protect
+        )
+
+    # Save RVC output (CPU operation)
+    rvc_saved = False
+    try:
+        if isinstance(output_audio, str) and os.path.exists(output_audio):
+            # Case 1: output_audio is a file path string
+            shutil.copy2(output_audio, output_path)
+            rvc_saved = True
+        elif isinstance(output_audio, tuple) and len(output_audio) >= 2:
+            # Case 2: output_audio might be (sample_rate, audio_data)
+            sf.write(output_path, output_audio[1], output_audio[0])
+            rvc_saved = True
+        elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name):
+            # Case 3: output_audio might be a file-like object
+            shutil.copy2(output_audio.name, output_path)
+            rvc_saved = True
+    except Exception as e:
+        output_info += f"\nError saving RVC output: {str(e)}"
+
+    logging.info(f"RVC inference completed for {input_path}")
+    return rvc_saved, output_info
+
+
 def split_into_sentences(text):
     """
     Split text into sentences using regular expressions.
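A caveat on the stream-scoped blocks above: torch.cuda.stream() only chooses which stream kernels are enqueued on; the work is still asynchronous with respect to the host. Writing the result with soundfile right after the block is safe here presumably because model.inference returns a NumPy array, and the device-to-host copy it performs already synchronizes on that stream. If a result stayed on the GPU, the host would need an explicit wait before using it, roughly as in this sketch (fn is a hypothetical callable returning a CUDA tensor):

import torch


def run_on_stream(fn, stream=None):
    # Launch fn's GPU work on `stream`, then block the host until that stream drains.
    if stream is not None and torch.cuda.is_available():
        with torch.cuda.stream(stream), torch.no_grad():
            out = fn()
        stream.synchronize()  # wait for every kernel queued on this stream
    else:
        with torch.no_grad():
            out = fn()
    return out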
@@ -101,74 +149,188 @@ def split_into_sentences(text):
     sentences = [s.strip() for s in sentences if s.strip()]
     return sentences

-def process_single_sentence(
-    sentence_index, sentence, prompt_speech, prompt_text_clean,
+
+def generate_and_process_with_rvc_parallel(
+    text, prompt_text, prompt_wav_upload, prompt_wav_record,
     spk_item, vc_transform, f0method,
     file_index1, file_index2, index_rate, filter_radius,
-    resample_sr, rms_mix_rate, protect,
-    base_fragment_num
+    resample_sr, rms_mix_rate, protect
 ):
     """
-    Process a single sentence through the TTS and RVC pipeline.
+    Handle combined TTS and RVC processing using CUDA streams for parallel operation.
+    Uses a producer-consumer pattern where TTS produces audio files for RVC to consume.
     """
-    fragment_num = base_fragment_num + sentence_index
-
-    # Generate TTS audio for this sentence, saving directly to the correct location
-    tts_path = run_tts(
-        sentence,
-        prompt_text=prompt_text_clean,
-        prompt_speech=prompt_speech,
-        save_dir="./TEMP/spark",
-        save_filename=f"fragment_{fragment_num}.wav"
-    )
-
-    # Make sure we have a TTS file to process
-    if not tts_path or not os.path.exists(tts_path):
-        return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
-
-    # Use the tts_path as the Spark output (no need to copy)
-    spark_output_path = tts_path
-
-    # Call RVC processing function
-    f0_file = None # We're not using an F0 curve file in this pipeline
-    output_info, output_audio = vc.vc_single(
-        spk_item, tts_path, vc_transform, f0_file, f0method,
-        file_index1, file_index2, index_rate, filter_radius,
-        resample_sr, rms_mix_rate, protect
-    )
-
-    # Save RVC output to TEMP/rvc directory
-    rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
-    rvc_saved = False
-
-    try:
-        if isinstance(output_audio, str) and os.path.exists(output_audio):
-            # Case 1: output_audio is a file path string
-            shutil.copy2(output_audio, rvc_output_path)
-            rvc_saved = True
-        elif isinstance(output_audio, tuple) and len(output_audio) >= 2:
-            # Case 2: output_audio might be (sample_rate, audio_data)
-            try:
-                sf.write(rvc_output_path, output_audio[1], output_audio[0])
-                rvc_saved = True
-            except Exception as inner_e:
-                output_info += f"\nFailed to save RVC tuple format: {str(inner_e)}"
-        elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name):
-            # Case 3: output_audio might be a file-like object
-            shutil.copy2(output_audio.name, rvc_output_path)
-            rvc_saved = True
-    except Exception as e:
-        output_info += f"\nError saving RVC output: {str(e)}"
-
-    # Prepare info message
-    info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
-    info_message += f" - Spark output: {spark_output_path}\n"
-    if rvc_saved:
-        info_message += f" - RVC output: {rvc_output_path}"
+    # Ensure TEMP directories exist
+    os.makedirs("./TEMP/spark", exist_ok=True)
+    os.makedirs("./TEMP/rvc", exist_ok=True)
+
+    # Split text into sentences
+    sentences = split_into_sentences(text)
+    if not sentences:
+        yield "No valid text to process.", None
+        return
+
+    # Get next base fragment number
+    base_fragment_num = 1
+    while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or
+              os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav")
+              for i in range(len(sentences))):
+        base_fragment_num += 1
+
+    # Process reference speech
+    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
+
+    # Create CUDA streams if CUDA is available
+    use_cuda = torch.cuda.is_available()
+    if use_cuda:
+        spark_stream = torch.cuda.Stream()
+        rvc_stream = torch.cuda.Stream()
+        logging.info("Using separate CUDA streams for Spark TTS and RVC")
     else:
-        info_message += f" - Could not save RVC output to {rvc_output_path}"
-
-    return spark_output_path, rvc_output_path, rvc_saved, info_message
+        spark_stream = None
+        rvc_stream = None
+        logging.info("CUDA not available, parallel processing will be limited")
+
+    # Create queues for communication between TTS and RVC
+    tts_to_rvc_queue = Queue()
+    rvc_results_queue = Queue()
+
+    # Flag to signal completion
+    processing_complete = threading.Event()
+
+    info_messages = [f"Processing {len(sentences)} sentences using parallel CUDA streams..."]
+
+    # Yield initial message with no audio yet
+    yield "\n".join(info_messages), None
+
+    # TTS worker function
+    def tts_worker():
+        for i, sentence in enumerate(sentences):
+            fragment_num = base_fragment_num + i
+            tts_filename = f"fragment_{fragment_num}.wav"
+
+            try:
+                # Use the TTS CUDA stream
+                path = run_tts(
+                    sentence,
+                    prompt_text=prompt_text_clean,
+                    prompt_speech=prompt_speech,
+                    save_dir="./TEMP/spark",
+                    save_filename=tts_filename,
+                    cuda_stream=spark_stream
+                )
+                # Put the path and sentence info to the queue for RVC processing
+                tts_to_rvc_queue.put((i, fragment_num, sentence, path))
+            except Exception as e:
+                logging.error(f"TTS processing error for sentence {i}: {str(e)}")
+                tts_to_rvc_queue.put((i, fragment_num, sentence, None, str(e)))
+
+        # Signal TTS completion
+        tts_to_rvc_queue.put(None)
+
+    # RVC worker function
+    def rvc_worker():
+        while True:
+            # Get item from the queue
+            item = tts_to_rvc_queue.get()
+
+            # Check for the sentinel value (None) that signals completion
+            if item is None:
+                break
+
+            # Unpack the item
+            if len(item) == 5: # Error case
+                i, fragment_num, sentence, _, error = item
+                rvc_results_queue.put((i, None, None, False, f"TTS error for sentence {i+1}: {error}"))
+                continue
+
+            i, fragment_num, sentence, tts_path = item
+
+            if not tts_path or not os.path.exists(tts_path):
+                rvc_results_queue.put((i, None, None, False, f"No TTS output for sentence {i+1}"))
+                continue
+
+            # Prepare RVC path
+            rvc_path = os.path.join("./TEMP/rvc", f"fragment_{fragment_num}.wav")
+
+            try:
+                # Process with RVC
+                rvc_success, rvc_info = process_with_rvc(
+                    spk_item, tts_path, vc_transform, f0method,
+                    file_index1, file_index2, index_rate, filter_radius,
+                    resample_sr, rms_mix_rate, protect,
+                    rvc_path, cuda_stream=rvc_stream
+                )
+
+                # Create info message
+                info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
+                info_message += f" - Spark output: {tts_path}\n"
+                if rvc_success:
+                    info_message += f" - RVC output: {rvc_path}"
+                else:
+                    info_message += f" - Could not save RVC output to {rvc_path}"
+
+                # Put the results to the queue
+                rvc_results_queue.put((i, tts_path, rvc_path if rvc_success else None, rvc_success, info_message))
+            except Exception as e:
+                logging.error(f"RVC processing error for sentence {i}: {str(e)}")
+                info_message = f"Sentence {i+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
+                info_message += f" - Spark output: {tts_path}\n"
+                info_message += f" - RVC processing error: {str(e)}"
+                rvc_results_queue.put((i, tts_path, None, False, info_message))
+
+        # Signal RVC completion
+        processing_complete.set()
+
+    # Start the worker threads
+    tts_thread = threading.Thread(target=tts_worker)
+    rvc_thread = threading.Thread(target=rvc_worker)
+
+    tts_thread.start()
+    rvc_thread.start()
+
+    # Process results as they become available
+    completed_sentences = {}
+    next_to_yield = 0
+
+    while not processing_complete.is_set() or not rvc_results_queue.empty():
+        try:
+            # Try to get an item from the results queue with a timeout
+            try:
+                i, tts_path, rvc_path, success, info = rvc_results_queue.get(timeout=0.1)
+                completed_sentences[i] = (tts_path, rvc_path, success, info)
+            except Empty:
+                # No results available yet, continue the loop
+                continue
+
+            # Check if we can yield the next sentence
+            while next_to_yield in completed_sentences:
+                _, rvc_path, _, info = completed_sentences[next_to_yield]
+                info_messages.append(info)
+
+                # Yield the current state
+                yield "\n".join(info_messages), rvc_path
+
+                # Move to the next sentence
+                next_to_yield += 1
+        except Exception as e:
+            logging.error(f"Error in main processing loop: {str(e)}")
+            info_messages.append(f"Error in processing: {str(e)}")
+            yield "\n".join(info_messages), None
+            break
+
+    # Join the threads
+    tts_thread.join()
+    rvc_thread.join()
+
+    # Yield any remaining sentences in order
+    remaining_indices = sorted([i for i in completed_sentences if i >= next_to_yield])
+    for i in remaining_indices:
+        _, rvc_path, _, info = completed_sentences[i]
+        info_messages.append(info)
+        yield "\n".join(info_messages), rvc_path
+

 def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
     """
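The consumer loop at the end of the new function is an in-order release of out-of-order completions: results arrive keyed by sentence index, are parked in a dict, and are only yielded once every earlier index has gone out, so the UI always receives fragments in reading order. Stripped of the audio and logging details, the pattern reduces to roughly this (generic (index, value) pairs, not the repo's tuples):

from queue import Empty, Queue


def yield_in_order(results: Queue, total: int):
    # Reorder (index, value) pairs that arrive in arbitrary order; emit them as 0, 1, ..., total-1.
    pending = {}
    next_index = 0
    while next_index < total:
        try:
            i, value = results.get(timeout=0.1)
        except Empty:
            continue  # nothing finished yet; keep polling
        pending[i] = value
        while next_index in pending:
            yield pending.pop(next_index)
            next_index += 1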
@@ -213,74 +375,6 @@ def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
         print(f"Fallback concatenation failed: {str(e2)}")
         return False

-def generate_and_process_with_rvc(
-    text, prompt_text, prompt_wav_upload, prompt_wav_record,
-    spk_item, vc_transform, f0method,
-    file_index1, file_index2, index_rate, filter_radius,
-    resample_sr, rms_mix_rate, protect
-):
-    """
-    Handle combined TTS and RVC processing for multiple sentences and yield outputs as they are processed.
-    The output is just the latest processed audio.
-    Before yielding a new audio fragment, the function waits for the previous one to finish playing,
-    based on its duration.
-    """
-    # Ensure TEMP directories exist
-    os.makedirs("./TEMP/spark", exist_ok=True)
-    os.makedirs("./TEMP/rvc", exist_ok=True)
-
-    # Split text into sentences
-    sentences = split_into_sentences(text)
-    if not sentences:
-        yield "No valid text to process.", None
-        return
-
-    # Get next base fragment number
-    base_fragment_num = 1
-    while any(os.path.exists(f"./TEMP/spark/fragment_{base_fragment_num + i}.wav") or
-              os.path.exists(f"./TEMP/rvc/fragment_{base_fragment_num + i}.wav")
-              for i in range(len(sentences))):
-        base_fragment_num += 1
-
-    # Process reference speech
-    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
-    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
-
-    info_messages = [f"Processing {len(sentences)} sentences..."]
-
-    # Yield initial message with no audio yet
-    yield "\n".join(info_messages), None
-
-    # Set up a timer to simulate playback duration
-    next_available_time = time.time()
-
-    for i, sentence in enumerate(sentences):
-        spark_path, rvc_path, success, info = process_single_sentence(
-            i, sentence, prompt_speech, prompt_text_clean,
-            spk_item, vc_transform, f0method,
-            file_index1, file_index2, index_rate, filter_radius,
-            resample_sr, rms_mix_rate, protect,
-            base_fragment_num
-        )
-
-        info_messages.append(info)
-        # Only update output if processing was successful and we have an audio file
-        if success and rvc_path:
-            try:
-                audio_seg = AudioSegment.from_file(rvc_path)
-                duration = audio_seg.duration_seconds
-            except Exception as e:
-                duration = 0
-
-            current_time = time.time()
-            if current_time < next_available_time:
-                time.sleep(next_available_time - current_time)
-
-            yield "\n".join(info_messages), rvc_path
-
-            next_available_time = time.time() + duration
-
-        yield "\n".join(info_messages), rvc_path
-
 def modified_get_vc(sid0_value, protect0_value, file_index2_component):
     """