mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-05 04:08:58 +08:00
Stream functionality preparation
This commit is contained in:
parent
a18a2c6b94
commit
284264fdbe
@ -1,7 +1,8 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
# Import modules from your packages
|
# Import modules from your packages
|
||||||
from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
|
from merged_ui.utils import modified_get_vc
|
||||||
|
from merged_ui.streaming_utils import generate_and_process_with_rvc_streaming
|
||||||
from rvc_ui.initialization import config
|
from rvc_ui.initialization import config
|
||||||
from rvc_ui.main import names, index_paths
|
from rvc_ui.main import names, index_paths
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ def build_merged_ui():
|
|||||||
with gr.Tabs():
|
with gr.Tabs():
|
||||||
with gr.TabItem("TTS-to-RVC Pipeline"):
|
with gr.TabItem("TTS-to-RVC Pipeline"):
|
||||||
gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
|
gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
|
||||||
gr.Markdown("*Note: For multi-sentence text, each sentence will be processed separately and then combined.*")
|
gr.Markdown("*Note: Sentences will begin playing as soon as the first one is ready*")
|
||||||
|
|
||||||
# TTS Generation Section
|
# TTS Generation Section
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -130,12 +131,14 @@ def build_merged_ui():
|
|||||||
generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
|
generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
# Status for updating information (will be updated during processing)
|
||||||
vc_output1 = gr.Textbox(label="Output information", lines=10)
|
vc_output1 = gr.Textbox(label="Output information", lines=10)
|
||||||
vc_output2 = gr.Audio(label="Final concatenated audio")
|
# Streaming audio output to support incremental updates
|
||||||
|
vc_output2 = gr.Audio(label="Audio (updates as sentences are processed)")
|
||||||
|
|
||||||
# Connect generate function to button
|
# Connect generate function to button with streaming output
|
||||||
generate_with_rvc_button.click(
|
generate_with_rvc_button.click(
|
||||||
generate_and_process_with_rvc,
|
generate_and_process_with_rvc_streaming,
|
||||||
inputs=[
|
inputs=[
|
||||||
tts_text_input,
|
tts_text_input,
|
||||||
prompt_text_input,
|
prompt_text_input,
|
||||||
|
111
modules/merged_ui/streaming_utils.py
Normal file
111
modules/merged_ui/streaming_utils.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Import and reuse existing functions
|
||||||
|
from merged_ui.utils import (
|
||||||
|
split_into_sentences,
|
||||||
|
process_single_sentence,
|
||||||
|
concatenate_audio_files,
|
||||||
|
split_into_sentences,
|
||||||
|
process_single_sentence,
|
||||||
|
initialize_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
|
||||||
|
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
||||||
|
device = 0
|
||||||
|
spark_model = initialize_model(model_dir, device=device)
|
||||||
|
|
||||||
|
def generate_and_process_with_rvc_streaming(
    text, prompt_text, prompt_wav_upload, prompt_wav_record,
    spk_item, vc_transform, f0method,
    file_index1, file_index2, index_rate, filter_radius,
    resample_sr, rms_mix_rate, protect
):
    """
    Stream-process TTS and RVC, yielding audio updates as sentences complete.

    Generator yielding (status_text, audio_path) tuples: status_text is the
    accumulated log shown in the UI; audio_path is the most recent playable
    concatenation of all successfully processed sentence fragments, or None
    before the first fragment is ready.
    """
    # Ensure TEMP directories exist
    os.makedirs("./TEMP/spark", exist_ok=True)
    os.makedirs("./TEMP/rvc", exist_ok=True)
    os.makedirs("./TEMP/stream", exist_ok=True)

    # Split text into sentences
    sentences = split_into_sentences(text)
    if not sentences:
        yield "No valid text to process.", None
        return

    # Timestamp-based session ID keeps this run's files unique
    session_id = str(int(time.time()))

    # Process reference speech
    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text

    total = len(sentences)
    status_messages = [f"Starting to process {total} sentences..."]
    processed_fragments = []  # RVC output paths of successfully processed sentences
    latest_stream = None      # most recent concatenated audio handed to the UI

    # Per-session directory for incremental stream files
    stream_dir = f"./TEMP/stream/{session_id}"
    os.makedirs(stream_dir, exist_ok=True)

    # Yield initial status before any audio exists
    yield "\n".join(status_messages), None

    # Process each sentence and update the stream
    for i, sentence in enumerate(sentences):
        status_messages.append(
            f"Processing sentence {i+1}/{total}: {sentence[:30]}..."
        )
        yield "\n".join(status_messages), latest_stream

        # Run TTS + RVC for this sentence; the session ID doubles as the
        # base fragment number so filenames never collide across runs.
        spark_path, rvc_path, success, info = process_single_sentence(
            i, sentence, prompt_speech, prompt_text_clean,
            spk_item, vc_transform, f0method,
            file_index1, file_index2, index_rate, filter_radius,
            resample_sr, rms_mix_rate, protect,
            int(session_id)
        )

        # Replace the "processing..." line with the actual result
        status_messages[-1] = info

        if success and rvc_path and os.path.exists(rvc_path):
            processed_fragments.append(rvc_path)

            # Concatenate everything processed so far into a fresh stream file
            stream_path = os.path.join(stream_dir, f"stream_update_{i+1}.wav")
            if concatenate_audio_files(processed_fragments, stream_path):
                latest_stream = stream_path
            else:
                # Concatenation failed: fall back to the newest fragment
                latest_stream = rvc_path
            yield "\n".join(status_messages), latest_stream
        else:
            # Processing failed: keep the player on the last good concatenated
            # stream instead of resetting it to a single old fragment (the
            # previous code yielded processed_fragments[-1] here, which
            # replaced the combined audio with one lone sentence).
            yield "\n".join(status_messages), latest_stream

    # Final streaming update with completion message
    if processed_fragments:
        final_output_path = os.path.join(stream_dir, "final_output.wav")
        if concatenate_audio_files(processed_fragments, final_output_path):
            # Report accurate counts instead of unconditionally claiming
            # full success when some sentences failed.
            if len(processed_fragments) == total:
                status_messages.append(f"\nAll {total} sentences processed successfully!")
            else:
                status_messages.append(
                    f"\nFinished: {len(processed_fragments)}/{total} sentences processed."
                )
            yield "\n".join(status_messages), final_output_path
        else:
            status_messages.append("\nWarning: Failed to create final concatenated file.")
            yield "\n".join(status_messages), processed_fragments[-1]
    else:
        status_messages.append("\nNo sentences were successfully processed.")
        yield "\n".join(status_messages), None
|
@ -1,20 +1,87 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import shutil
|
import shutil
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from time import sleep
|
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
|
import torch
|
||||||
|
|
||||||
# Import modules from your packages
|
# Import modules from your packages
|
||||||
|
from spark.cli.SparkTTS import SparkTTS
|
||||||
from rvc_ui.initialization import vc
|
from rvc_ui.initialization import vc
|
||||||
from spark_ui.main import initialize_model, run_tts
|
|
||||||
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
|
||||||
|
|
||||||
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
|
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
|
||||||
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
||||||
device = 0
|
device = 0
|
||||||
spark_model = initialize_model(model_dir, device=device)
|
|
||||||
|
def initialize_model(model_dir, device):
    """Load the Spark TTS model once and return it.

    Args:
        model_dir (str): Path to the pretrained model directory.
        device (int): Device index to use for the MPS/CUDA device string.

    Returns:
        SparkTTS: The model loaded onto the selected device.
    """
    logging.info(f"Loading model from: {model_dir}")

    # Determine appropriate device based on platform and availability.
    # Check mps.is_available() explicitly: being on macOS alone is not
    # enough (Intel Macs / older torch builds have no MPS backend and
    # torch.device("mps:0") would fail at runtime).
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        # macOS with MPS support (Apple Silicon)
        device = torch.device(f"mps:{device}")
        logging.info(f"Using MPS device: {device}")
    elif torch.cuda.is_available():
        # System with CUDA support
        device = torch.device(f"cuda:{device}")
        logging.info(f"Using CUDA device: {device}")
    else:
        # Fall back to CPU (also covers Macs without MPS)
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    model = SparkTTS(model_dir, device)
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def run_tts(
    text,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="TEMP/spark",  # default save directory
    save_filename=None,     # optional explicit output filename
):
    """Perform TTS inference and save the generated audio.

    Args:
        text (str): Text to synthesize.
        prompt_text (str | None): Transcript of the reference speech; values
            shorter than two characters are treated as absent.
        prompt_speech: Reference speech audio forwarded to the model.
        gender, pitch, speed: Optional voice controls forwarded to the model.
        save_dir (str): Directory the output WAV is written into.
        save_filename (str | None): Filename to use; defaults to a timestamp.

    Returns:
        str: Path of the saved WAV file.
    """
    # Reuse the module-level spark_model loaded once at import time.
    # The previous code called initialize_model() here, reloading the
    # entire TTS model on every call and defeating the shared instance
    # that was "moved outside function to avoid reinitializing".
    model = spark_model
    logging.info(f"Saving audio to: {save_dir}")

    if prompt_text is not None:
        prompt_text = None if len(prompt_text) <= 1 else prompt_text

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Use save_filename when provided; otherwise fall back to a timestamp name
    if save_filename:
        save_path = os.path.join(save_dir, save_filename)
    else:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference and save the output audio
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )
        # NOTE(review): sample rate is hard-coded; assumes the model emits
        # 16 kHz audio — confirm against SparkTTS.
        sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")
    return save_path
|
||||||
|
|
||||||
def split_into_sentences(text):
|
def split_into_sentences(text):
|
||||||
"""
|
"""
|
||||||
@ -40,34 +107,26 @@ def process_single_sentence(
|
|||||||
base_fragment_num
|
base_fragment_num
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Process a single sentence through the TTS and RVC pipeline
|
Process a single sentence through the TTS and RVC pipeline.
|
||||||
|
|
||||||
Args:
|
|
||||||
sentence_index (int): Index of the sentence in the original text
|
|
||||||
sentence (str): The sentence text to process
|
|
||||||
... (other parameters are the same as generate_and_process_with_rvc)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (spark_output_path, rvc_output_path, success, info_message)
|
|
||||||
"""
|
"""
|
||||||
fragment_num = base_fragment_num + sentence_index
|
fragment_num = base_fragment_num + sentence_index
|
||||||
|
|
||||||
# Generate TTS audio for this sentence
|
# Generate TTS audio for this sentence, saving directly to the correct location
|
||||||
tts_path = run_tts(
|
tts_path = run_tts(
|
||||||
sentence,
|
sentence,
|
||||||
spark_model,
|
|
||||||
prompt_text=prompt_text_clean,
|
prompt_text=prompt_text_clean,
|
||||||
prompt_speech=prompt_speech
|
prompt_speech=prompt_speech,
|
||||||
|
save_dir="./TEMP/spark",
|
||||||
|
save_filename=f"fragment_{fragment_num}.wav"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure we have a TTS file to process
|
# Make sure we have a TTS file to process
|
||||||
if not tts_path or not os.path.exists(tts_path):
|
if not tts_path or not os.path.exists(tts_path):
|
||||||
return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
|
return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
|
||||||
|
|
||||||
# Save Spark output to TEMP/spark
|
# Use the tts_path as the Spark output (no need to copy)
|
||||||
spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
|
spark_output_path = tts_path
|
||||||
shutil.copy2(tts_path, spark_output_path)
|
|
||||||
|
|
||||||
# Call RVC processing function
|
# Call RVC processing function
|
||||||
f0_file = None # We're not using an F0 curve file in this pipeline
|
f0_file = None # We're not using an F0 curve file in this pipeline
|
||||||
output_info, output_audio = vc.vc_single(
|
output_info, output_audio = vc.vc_single(
|
||||||
@ -75,12 +134,11 @@ def process_single_sentence(
|
|||||||
file_index1, file_index2, index_rate, filter_radius,
|
file_index1, file_index2, index_rate, filter_radius,
|
||||||
resample_sr, rms_mix_rate, protect
|
resample_sr, rms_mix_rate, protect
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save RVC output to TEMP/rvc directory
|
# Save RVC output to TEMP/rvc directory
|
||||||
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
|
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
|
||||||
rvc_saved = False
|
rvc_saved = False
|
||||||
|
|
||||||
# Try different ways to save the RVC output based on common formats
|
|
||||||
try:
|
try:
|
||||||
if isinstance(output_audio, str) and os.path.exists(output_audio):
|
if isinstance(output_audio, str) and os.path.exists(output_audio):
|
||||||
# Case 1: output_audio is a file path string
|
# Case 1: output_audio is a file path string
|
||||||
@ -99,7 +157,7 @@ def process_single_sentence(
|
|||||||
rvc_saved = True
|
rvc_saved = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
output_info += f"\nError saving RVC output: {str(e)}"
|
output_info += f"\nError saving RVC output: {str(e)}"
|
||||||
|
|
||||||
# Prepare info message
|
# Prepare info message
|
||||||
info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
|
info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
|
||||||
info_message += f" - Spark output: {spark_output_path}\n"
|
info_message += f" - Spark output: {spark_output_path}\n"
|
||||||
@ -107,7 +165,7 @@ def process_single_sentence(
|
|||||||
info_message += f" - RVC output: {rvc_output_path}"
|
info_message += f" - RVC output: {rvc_output_path}"
|
||||||
else:
|
else:
|
||||||
info_message += f" - Could not save RVC output to {rvc_output_path}"
|
info_message += f" - Could not save RVC output to {rvc_output_path}"
|
||||||
|
|
||||||
return spark_output_path, rvc_output_path, rvc_saved, info_message
|
return spark_output_path, rvc_output_path, rvc_saved, info_message
|
||||||
|
|
||||||
def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
|
def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
|
||||||
|
@ -11,72 +11,6 @@ from datetime import datetime
|
|||||||
from spark.cli.SparkTTS import SparkTTS
|
from spark.cli.SparkTTS import SparkTTS
|
||||||
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
||||||
|
|
||||||
|
|
||||||
def initialize_model(model_dir, device):
|
|
||||||
"""Load the model once at the beginning."""
|
|
||||||
logging.info(f"Loading model from: {model_dir}")
|
|
||||||
|
|
||||||
# Determine appropriate device based on platform and availability
|
|
||||||
if platform.system() == "Darwin":
|
|
||||||
# macOS with MPS support (Apple Silicon)
|
|
||||||
device = torch.device(f"mps:{device}")
|
|
||||||
logging.info(f"Using MPS device: {device}")
|
|
||||||
elif torch.cuda.is_available():
|
|
||||||
# System with CUDA support
|
|
||||||
device = torch.device(f"cuda:{device}")
|
|
||||||
logging.info(f"Using CUDA device: {device}")
|
|
||||||
else:
|
|
||||||
# Fall back to CPU
|
|
||||||
device = torch.device("cpu")
|
|
||||||
logging.info("GPU acceleration not available, using CPU")
|
|
||||||
|
|
||||||
model = SparkTTS(model_dir, device)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def run_tts(
|
|
||||||
text,
|
|
||||||
model,
|
|
||||||
prompt_text=None,
|
|
||||||
prompt_speech=None,
|
|
||||||
gender=None,
|
|
||||||
pitch=None,
|
|
||||||
speed=None,
|
|
||||||
save_dir="spark/example/results",
|
|
||||||
):
|
|
||||||
"""Perform TTS inference and save the generated audio."""
|
|
||||||
logging.info(f"Saving audio to: {save_dir}")
|
|
||||||
|
|
||||||
if prompt_text is not None:
|
|
||||||
prompt_text = None if len(prompt_text) <= 1 else prompt_text
|
|
||||||
|
|
||||||
# Ensure the save directory exists
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Generate unique filename using timestamp
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
save_path = os.path.join(save_dir, f"{timestamp}.wav")
|
|
||||||
|
|
||||||
logging.info("Starting inference...")
|
|
||||||
|
|
||||||
# Perform inference and save the output audio
|
|
||||||
with torch.no_grad():
|
|
||||||
wav = model.inference(
|
|
||||||
text,
|
|
||||||
prompt_speech,
|
|
||||||
prompt_text,
|
|
||||||
gender,
|
|
||||||
pitch,
|
|
||||||
speed,
|
|
||||||
)
|
|
||||||
|
|
||||||
sf.write(save_path, wav, samplerate=16000)
|
|
||||||
|
|
||||||
logging.info(f"Audio saved at: {save_path}")
|
|
||||||
|
|
||||||
return save_path
|
|
||||||
|
|
||||||
|
|
||||||
def build_spark_ui():
|
def build_spark_ui():
|
||||||
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
||||||
device = 0
|
device = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user