mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-05 04:08:58 +08:00

Stream functionality preparation

This commit is contained in:
parent a18a2c6b94
commit 284264fdbe
@@ -1,7 +1,8 @@
 import gradio as gr
 
 # Import modules from your packages
-from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
+from merged_ui.utils import modified_get_vc
+from merged_ui.streaming_utils import generate_and_process_with_rvc_streaming
 from rvc_ui.initialization import config
 from rvc_ui.main import names, index_paths
 
@@ -17,7 +18,7 @@ def build_merged_ui():
     with gr.Tabs():
         with gr.TabItem("TTS-to-RVC Pipeline"):
             gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
-            gr.Markdown("*Note: For multi-sentence text, each sentence will be processed separately and then combined.*")
+            gr.Markdown("*Note: Sentences will begin playing as soon as the first one is ready*")
 
             # TTS Generation Section
             with gr.Row():
@@ -130,12 +131,14 @@ def build_merged_ui():
             generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
 
             with gr.Row():
+                # Status for updating information (will be updated during processing)
                 vc_output1 = gr.Textbox(label="Output information", lines=10)
-                vc_output2 = gr.Audio(label="Final concatenated audio")
+                # Streaming audio output to support incremental updates
+                vc_output2 = gr.Audio(label="Audio (updates as sentences are processed)")
 
-            # Connect generate function to button
+            # Connect generate function to button with streaming output
             generate_with_rvc_button.click(
-                generate_and_process_with_rvc,
+                generate_and_process_with_rvc_streaming,
                 inputs=[
                     tts_text_input,
                     prompt_text_input,
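Wiring the button to a generator function is what makes the streaming UI work: Gradio treats a generator event handler as a stream and re-renders the bound outputs on every yield. A minimal, self-contained sketch of the pattern (toy handler and component names, no TTS or RVC involved):

import time
import gradio as gr

def fake_stream(text):
    """Toy generator: each yield pushes a fresh (status, audio) pair to the UI."""
    lines = []
    for i, sentence in enumerate(s for s in text.split(".") if s.strip()):
        lines.append(f"Processing sentence {i + 1}: {sentence.strip()[:30]}...")
        yield "\n".join(lines), None  # status update, no audio yet
        time.sleep(0.5)               # stand-in for the TTS + RVC work
    yield "\n".join(lines) + "\nDone!", None

with gr.Blocks() as demo:
    text = gr.Textbox(label="Text")
    btn = gr.Button("Go")
    status = gr.Textbox(label="Output information", lines=10)
    audio = gr.Audio(label="Audio (updates as sentences are processed)")
    btn.click(fake_stream, inputs=[text], outputs=[status, audio])

demo.launch()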
111 modules/merged_ui/streaming_utils.py Normal file
@@ -0,0 +1,111 @@
+import os
+import time
+
+# Import and reuse existing functions
+from merged_ui.utils import (
+    split_into_sentences,
+    process_single_sentence,
+    concatenate_audio_files,
+    initialize_model
+)
+
+# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
+model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
+device = 0
+spark_model = initialize_model(model_dir, device=device)
+
+
+def generate_and_process_with_rvc_streaming(
+    text, prompt_text, prompt_wav_upload, prompt_wav_record,
+    spk_item, vc_transform, f0method,
+    file_index1, file_index2, index_rate, filter_radius,
+    resample_sr, rms_mix_rate, protect
+):
+    """
+    Stream-process TTS and RVC, yielding audio updates as sentences are processed.
+    This is a generator function that yields (status_text, audio_path) tuples.
+    """
+    # Ensure TEMP directories exist
+    os.makedirs("./TEMP/spark", exist_ok=True)
+    os.makedirs("./TEMP/rvc", exist_ok=True)
+    os.makedirs("./TEMP/stream", exist_ok=True)
+
+    # Split text into sentences
+    sentences = split_into_sentences(text)
+    if not sentences:
+        yield "No valid text to process.", None
+        return
+
+    # Timestamp doubles as a unique session ID for this run
+    session_id = str(int(time.time()))
+
+    # Process reference speech
+    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+    prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
+
+    # Initialize status and the list of processed fragments
+    status_messages = [f"Starting to process {len(sentences)} sentences..."]
+    processed_fragments = []
+
+    # Create a temporary directory for streaming updates
+    stream_dir = f"./TEMP/stream/{session_id}"
+    os.makedirs(stream_dir, exist_ok=True)
+
+    # Yield initial status
+    yield "\n".join(status_messages), None
+
+    # Process each sentence and update the stream
+    for i, sentence in enumerate(sentences):
+        # Add the current sentence to the status
+        current_msg = f"Processing sentence {i+1}/{len(sentences)}: {sentence[:30]}..."
+        status_messages.append(current_msg)
+        yield "\n".join(status_messages), None
+
+        # Process this sentence
+        spark_path, rvc_path, success, info = process_single_sentence(
+            i, sentence, prompt_speech, prompt_text_clean,
+            spk_item, vc_transform, f0method,
+            file_index1, file_index2, index_rate, filter_radius,
+            resample_sr, rms_mix_rate, protect,
+            int(session_id)  # Use the session ID as the base fragment number
+        )
+
+        # Update status with the processing result
+        status_messages[-1] = info
+
+        if success and rvc_path and os.path.exists(rvc_path):
+            processed_fragments.append(rvc_path)
+
+            # Create a streaming update file by concatenating all fragments processed so far
+            stream_path = os.path.join(stream_dir, f"stream_update_{i+1}.wav")
+            concatenate_success = concatenate_audio_files(processed_fragments, stream_path)
+
+            if concatenate_success:
+                # Yield the updated status and the current stream path
+                yield "\n".join(status_messages), stream_path
+            else:
+                # If concatenation failed, just yield the most recent fragment
+                yield "\n".join(status_messages), rvc_path
+        else:
+            # If processing failed, update the status but not the audio
+            yield "\n".join(status_messages), None if not processed_fragments else processed_fragments[-1]
+
+    # Final streaming update with a completion message
+    if processed_fragments:
+        # Create the final output file
+        final_output_path = f"./TEMP/stream/{session_id}/final_output.wav"
+        concatenate_success = concatenate_audio_files(processed_fragments, final_output_path)
+
+        if concatenate_success:
+            status_messages.append(f"\nAll {len(sentences)} sentences processed successfully!")
+            yield "\n".join(status_messages), final_output_path
+        else:
+            status_messages.append("\nWarning: Failed to create final concatenated file.")
+            yield "\n".join(status_messages), processed_fragments[-1]
+    else:
+        status_messages.append("\nNo sentences were successfully processed.")
+        yield "\n".join(status_messages), None
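The generator leans on concatenate_audio_files, which is imported from merged_ui.utils and not shown in this diff. Since utils.py imports pydub's AudioSegment, a plausible sketch of what such a helper does (an assumption, not the committed code) is:

from pydub import AudioSegment

def concatenate_audio_files(paths, output_path):
    """Hypothetical sketch: join WAV fragments end to end, True on success."""
    try:
        combined = AudioSegment.empty()
        for path in paths:
            combined += AudioSegment.from_wav(path)
        combined.export(output_path, format="wav")
        return True
    except Exception:
        return False

Re-concatenating from scratch on every sentence keeps the logic simple at the cost of quadratic audio copying over a run; for short texts that is negligible.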
@@ -1,20 +1,87 @@
 from datetime import datetime
 import logging
 import os
 import platform
 import shutil
 import re
 import numpy as np
 from time import sleep
 import soundfile as sf
 from pydub import AudioSegment
 import torch
 
 # Import modules from your packages
 from spark.cli.SparkTTS import SparkTTS
 from rvc_ui.initialization import vc
-from spark_ui.main import initialize_model, run_tts
 from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
 
+def initialize_model(model_dir, device):
+    """Load the model once at the beginning."""
+    logging.info(f"Loading model from: {model_dir}")
+
+    # Determine appropriate device based on platform and availability
+    if platform.system() == "Darwin":
+        # macOS with MPS support (Apple Silicon)
+        device = torch.device(f"mps:{device}")
+        logging.info(f"Using MPS device: {device}")
+    elif torch.cuda.is_available():
+        # System with CUDA support
+        device = torch.device(f"cuda:{device}")
+        logging.info(f"Using CUDA device: {device}")
+    else:
+        # Fall back to CPU
+        device = torch.device("cpu")
+        logging.info("GPU acceleration not available, using CPU")
+
+    model = SparkTTS(model_dir, device)
+    return model
+
+
+# Initialize the Spark TTS model once at import time (after the function is
+# defined, so the module-level call does not raise a NameError)
+model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
+device = 0
+spark_model = initialize_model(model_dir, device=device)
+
+
+def run_tts(
+    text,
+    prompt_text=None,
+    prompt_speech=None,
+    gender=None,
+    pitch=None,
+    speed=None,
+    save_dir="TEMP/spark",  # Updated default save directory
+    save_filename=None,  # New parameter to specify filename
+):
+    """Perform TTS inference and save the generated audio."""
+    # Reuse the module-level model instead of reloading it on every call
+    model = spark_model
+    logging.info(f"Saving audio to: {save_dir}")
+
+    if prompt_text is not None:
+        prompt_text = None if len(prompt_text) <= 1 else prompt_text
+
+    # Ensure the save directory exists
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Use save_filename if provided; otherwise fall back to a timestamp
+    if save_filename:
+        save_path = os.path.join(save_dir, save_filename)
+    else:
+        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        save_path = os.path.join(save_dir, f"{timestamp}.wav")
+
+    logging.info("Starting inference...")
+
+    # Perform inference and save the output audio
+    with torch.no_grad():
+        wav = model.inference(
+            text,
+            prompt_speech,
+            prompt_text,
+            gender,
+            pitch,
+            speed,
+        )
+        sf.write(save_path, wav, samplerate=16000)
+
+    logging.info(f"Audio saved at: {save_path}")
+    return save_path
+
 
 def split_into_sentences(text):
     """
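The new save_filename parameter is what lets the pipeline address fragments deterministically instead of hunting for timestamped files. A hedged usage sketch (the paths and reference clip are invented):

# Returns the path it wrote, e.g. ./TEMP/spark/fragment_42.wav
path = run_tts(
    "Hello world.",
    prompt_speech="ref.wav",           # reference speaker audio (assumed to exist)
    save_dir="./TEMP/spark",
    save_filename="fragment_42.wav",   # deterministic name used by the pipeline
)
# With save_filename omitted, the file is named from a timestamp instead,
# e.g. ./TEMP/spark/20250405120000.wav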
@@ -40,33 +107,25 @@ def process_single_sentence(
     base_fragment_num
 ):
     """
-    Process a single sentence through the TTS and RVC pipeline
-
-    Args:
-        sentence_index (int): Index of the sentence in the original text
-        sentence (str): The sentence text to process
-        ... (other parameters are the same as generate_and_process_with_rvc)
-
-    Returns:
-        tuple: (spark_output_path, rvc_output_path, success, info_message)
+    Process a single sentence through the TTS and RVC pipeline.
     """
     fragment_num = base_fragment_num + sentence_index
 
-    # Generate TTS audio for this sentence
+    # Generate TTS audio for this sentence, saving directly to the correct location
     tts_path = run_tts(
         sentence,
-        spark_model,
         prompt_text=prompt_text_clean,
-        prompt_speech=prompt_speech
+        prompt_speech=prompt_speech,
+        save_dir="./TEMP/spark",
+        save_filename=f"fragment_{fragment_num}.wav"
     )
 
     # Make sure we have a TTS file to process
     if not tts_path or not os.path.exists(tts_path):
         return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
 
-    # Save Spark output to TEMP/spark
-    spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
-    shutil.copy2(tts_path, spark_output_path)
+    # Use tts_path as the Spark output (no need to copy)
+    spark_output_path = tts_path
 
     # Call RVC processing function
     f0_file = None  # We're not using an F0 curve file in this pipeline
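Fragment numbering is what keeps artifacts from colliding across runs: the streaming caller passes int(session_id), a Unix timestamp, as base_fragment_num, so each sentence in a run gets a distinct file name. A small worked illustration (values invented):

session_id = "1712275200"            # int(time.time()) captured when the run starts
base_fragment_num = int(session_id)

for sentence_index in range(3):
    fragment_num = base_fragment_num + sentence_index
    print(f"./TEMP/spark/fragment_{fragment_num}.wav")
# ./TEMP/spark/fragment_1712275200.wav
# ./TEMP/spark/fragment_1712275201.wav
# ./TEMP/spark/fragment_1712275202.wav

One caveat: because the base is a timestamp in seconds, two runs started less than len(sentences) seconds apart could still overlap.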
@@ -80,7 +139,6 @@ def process_single_sentence(
     rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
     rvc_saved = False
 
-    # Try different ways to save the RVC output based on common formats
    try:
         if isinstance(output_audio, str) and os.path.exists(output_audio):
             # Case 1: output_audio is a file path string
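The removed comment described real variability: depending on the RVC entry point, the conversion output may arrive as a file path or as a (sample_rate, samples) tuple. A hedged sketch of the kind of dispatch the surrounding try block performs (the helper name and Case 2 are assumptions; the diff only shows Case 1):

import os
import shutil

import numpy as np
import soundfile as sf

def save_rvc_output(output_audio, rvc_output_path):
    """Hypothetical helper mirroring the save logic above."""
    if isinstance(output_audio, str) and os.path.exists(output_audio):
        # Case 1: output_audio is a file path string
        shutil.copy2(output_audio, rvc_output_path)
        return True
    if isinstance(output_audio, tuple) and len(output_audio) == 2:
        # Case 2 (assumed): output_audio is a (sample_rate, samples) tuple
        sr, data = output_audio
        sf.write(rvc_output_path, np.asarray(data), sr)
        return True
    return False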
@@ -11,72 +11,6 @@ from datetime import datetime
 from spark.cli.SparkTTS import SparkTTS
 from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
 
-def initialize_model(model_dir, device):
-    """Load the model once at the beginning."""
-    logging.info(f"Loading model from: {model_dir}")
-
-    # Determine appropriate device based on platform and availability
-    if platform.system() == "Darwin":
-        # macOS with MPS support (Apple Silicon)
-        device = torch.device(f"mps:{device}")
-        logging.info(f"Using MPS device: {device}")
-    elif torch.cuda.is_available():
-        # System with CUDA support
-        device = torch.device(f"cuda:{device}")
-        logging.info(f"Using CUDA device: {device}")
-    else:
-        # Fall back to CPU
-        device = torch.device("cpu")
-        logging.info("GPU acceleration not available, using CPU")
-
-    model = SparkTTS(model_dir, device)
-    return model
-
-
-def run_tts(
-    text,
-    model,
-    prompt_text=None,
-    prompt_speech=None,
-    gender=None,
-    pitch=None,
-    speed=None,
-    save_dir="spark/example/results",
-):
-    """Perform TTS inference and save the generated audio."""
-    logging.info(f"Saving audio to: {save_dir}")
-
-    if prompt_text is not None:
-        prompt_text = None if len(prompt_text) <= 1 else prompt_text
-
-    # Ensure the save directory exists
-    os.makedirs(save_dir, exist_ok=True)
-
-    # Generate unique filename using timestamp
-    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
-    save_path = os.path.join(save_dir, f"{timestamp}.wav")
-
-    logging.info("Starting inference...")
-
-    # Perform inference and save the output audio
-    with torch.no_grad():
-        wav = model.inference(
-            text,
-            prompt_speech,
-            prompt_text,
-            gender,
-            pitch,
-            speed,
-        )
-
-    sf.write(save_path, wav, samplerate=16000)
-
-    logging.info(f"Audio saved at: {save_path}")
-
-    return save_path
-
-
 def build_spark_ui():
     model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
     device = 0
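With this commit, both merged_ui/utils.py and merged_ui/streaming_utils.py construct a module-level spark_model at import time, so importing both loads the 0.5B model twice. A lazy, cached getter would preserve the load-once intent without the double cost; a sketch under that assumption (not part of this commit):

from functools import lru_cache

@lru_cache(maxsize=1)
def get_spark_model(model_dir="spark/pretrained_models/Spark-TTS-0.5B", device=0):
    """Load the Spark TTS model on first use, then reuse the same instance."""
    return initialize_model(model_dir, device=device)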