Stream functionality preparation

VSlobolinskyi 2025-03-21 19:17:49 +02:00
parent a18a2c6b94
commit 284264fdbe
4 changed files with 205 additions and 99 deletions

View File

@@ -1,7 +1,8 @@
import gradio as gr
# Import modules from your packages
from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
from merged_ui.utils import modified_get_vc
from merged_ui.streaming_utils import generate_and_process_with_rvc_streaming
from rvc_ui.initialization import config
from rvc_ui.main import names, index_paths
@@ -17,7 +18,7 @@ def build_merged_ui():
with gr.Tabs():
with gr.TabItem("TTS-to-RVC Pipeline"):
gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
gr.Markdown("*Note: For multi-sentence text, each sentence will be processed separately and then combined.*")
gr.Markdown("*Note: Sentences will begin playing as soon as the first one is ready*")
# TTS Generation Section
with gr.Row():
@@ -130,12 +131,14 @@ def build_merged_ui():
generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
with gr.Row():
# Status textbox for progress information, updated incrementally during processing
vc_output1 = gr.Textbox(label="Output information", lines=10)
vc_output2 = gr.Audio(label="Final concatenated audio")
# Streaming audio output to support incremental updates
vc_output2 = gr.Audio(label="Audio (updates as sentences are processed)")
# Connect generate function to button
# Connect generate function to button with streaming output
generate_with_rvc_button.click(
generate_and_process_with_rvc,
generate_and_process_with_rvc_streaming,
inputs=[
tts_text_input,
prompt_text_input,

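Because generate_and_process_with_rvc_streaming is a generator, Gradio pushes a fresh update to the bound outputs on every yield. A minimal, self-contained sketch of this mechanism (the component names and toy callback are illustrative, not part of this commit):

import time
import gradio as gr

def count_up(n):
    """Toy generator callback: each yield re-renders the bound output."""
    lines = []
    for i in range(int(n)):
        time.sleep(0.5)  # stand-in for per-sentence TTS + RVC work
        lines.append(f"step {i + 1}/{int(n)} done")
        yield "\n".join(lines)

with gr.Blocks() as demo:
    steps = gr.Number(value=3, label="Steps")
    status = gr.Textbox(label="Status", lines=5)
    gr.Button("Run").click(count_up, inputs=steps, outputs=status)

demo.queue()  # older Gradio versions need the queue enabled for streaming
demo.launch()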
View File

@@ -0,0 +1,111 @@
import os
import time
# Import and reuse existing helpers; importing merged_ui.utils also loads
# the shared Spark TTS model once, so it is not re-initialized here
from merged_ui.utils import (
split_into_sentences,
process_single_sentence,
concatenate_audio_files,
)
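concatenate_audio_files is imported from merged_ui.utils and its body is outside this diff. Given how the streaming loop uses it (a list of fragment paths, an output path, and a boolean result), a plausible sketch using pydub, which utils already imports; this is an assumption, not the commit's actual implementation:

from pydub import AudioSegment

def concatenate_audio_files(fragment_paths, output_path):
    """Assumed contract: join WAV fragments end-to-end; return True on success."""
    if not fragment_paths:
        return False
    try:
        combined = AudioSegment.empty()
        for path in fragment_paths:
            combined += AudioSegment.from_wav(path)
        combined.export(output_path, format="wav")
        return True
    except Exception:
        return False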
def generate_and_process_with_rvc_streaming(
text, prompt_text, prompt_wav_upload, prompt_wav_record,
spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
):
"""
Stream process TTS and RVC, yielding audio updates as sentences are processed
This is a generator function that yields (status_text, audio_path) tuples
"""
# Ensure TEMP directories exist
os.makedirs("./TEMP/spark", exist_ok=True)
os.makedirs("./TEMP/rvc", exist_ok=True)
os.makedirs("./TEMP/stream", exist_ok=True)
# Split text into sentences
sentences = split_into_sentences(text)
if not sentences:
yield "No valid text to process.", None
return
# Get timestamp to create unique session ID for this run
session_id = str(int(time.time()))
# Process reference speech
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
# Initialize status
status_messages = [f"Starting to process {len(sentences)} sentences..."]
# Create a list to track processed fragments
processed_fragments = []
# Create a temporary directory for streaming
stream_dir = f"./TEMP/stream/{session_id}"
os.makedirs(stream_dir, exist_ok=True)
# Yield initial status
yield "\n".join(status_messages), None
# Process each sentence and update the stream
for i, sentence in enumerate(sentences):
# Add current sentence to status
current_msg = f"Processing sentence {i+1}/{len(sentences)}: {sentence[:30]}..."
status_messages.append(current_msg)
yield "\n".join(status_messages), None
# Process this sentence
spark_path, rvc_path, success, info = process_single_sentence(
i, sentence, prompt_speech, prompt_text_clean,
spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect,
int(session_id) # Use session ID as base fragment number
)
# Update status with processing result
status_messages[-1] = info
if success and rvc_path and os.path.exists(rvc_path):
processed_fragments.append(rvc_path)
# Concatenate all fragments processed so far into a streaming update file
stream_path = os.path.join(stream_dir, f"stream_update_{i+1}.wav")
concatenate_success = concatenate_audio_files(processed_fragments, stream_path)
if concatenate_success:
# Yield the updated status and the current stream path
yield "\n".join(status_messages), stream_path
else:
# If concatenation failed, just yield the most recent fragment
yield "\n".join(status_messages), rvc_path
else:
# If processing failed, update the status but keep the most recent audio (if any)
yield "\n".join(status_messages), processed_fragments[-1] if processed_fragments else None
# Final streaming update with completion message
if processed_fragments:
# Create final output file
final_output_path = os.path.join(stream_dir, "final_output.wav")
concatenate_success = concatenate_audio_files(processed_fragments, final_output_path)
if concatenate_success:
status_messages.append(f"\nAll {len(sentences)} sentences processed successfully!")
yield "\n".join(status_messages), final_output_path
else:
status_messages.append("\nWarning: Failed to create final concatenated file.")
yield "\n".join(status_messages), processed_fragments[-1]
else:
status_messages.append("\nNo sentences were successfully processed.")
yield "\n".join(status_messages), None

View File

@@ -1,20 +1,87 @@
from datetime import datetime
import logging
import os
import platform
import shutil
import re
import numpy as np
from time import sleep
import soundfile as sf
from pydub import AudioSegment
import torch
# Import modules from your packages
from spark.cli.SparkTTS import SparkTTS
from rvc_ui.initialization import vc
from spark_ui.main import initialize_model, run_tts
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
def initialize_model(model_dir, device):
"""Load the model once at the beginning."""
logging.info(f"Loading model from: {model_dir}")
# Determine the appropriate device based on platform and availability
if platform.system() == "Darwin":
# macOS with MPS support (Apple Silicon)
device = torch.device(f"mps:{device}")
logging.info(f"Using MPS device: {device}")
elif torch.cuda.is_available():
# System with CUDA support
device = torch.device(f"cuda:{device}")
logging.info(f"Using CUDA device: {device}")
else:
# Fall back to CPU
device = torch.device("cpu")
logging.info("GPU acceleration not available, using CPU")
model = SparkTTS(model_dir, device)
return model
# Initialize the Spark TTS model once at module level, after its definition,
# to avoid reinitializing it on every call
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
device = 0
spark_model = initialize_model(model_dir, device=device)
def run_tts(
text,
prompt_text=None,
prompt_speech=None,
gender=None,
pitch=None,
speed=None,
save_dir="TEMP/spark", # Updated default save directory
save_filename=None, # New parameter to specify filename
):
"""Perform TTS inference and save the generated audio."""
model = spark_model # Reuse the module-level model instead of reloading it per call
logging.info(f"Saving audio to: {save_dir}")
if prompt_text is not None:
prompt_text = None if len(prompt_text) <= 1 else prompt_text
# Ensure the save directory exists
os.makedirs(save_dir, exist_ok=True)
# Determine the save path based on save_filename if provided; otherwise, use a timestamp
if save_filename:
save_path = os.path.join(save_dir, save_filename)
else:
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
save_path = os.path.join(save_dir, f"{timestamp}.wav")
logging.info("Starting inference...")
# Perform inference and save the output audio
with torch.no_grad():
wav = model.inference(
text,
prompt_speech,
prompt_text,
gender,
pitch,
speed,
)
sf.write(save_path, wav, samplerate=16000)
logging.info(f"Audio saved at: {save_path}")
return save_path
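With save_filename set, the output path is fully deterministic, which is what lets the pipeline address fragments by number instead of scanning the directory. A hypothetical call (the reference clip and fragment number are illustrative):

fragment_path = run_tts(
    "Hello there.",
    prompt_speech="ref.wav",                 # placeholder reference clip
    save_dir="./TEMP/spark",
    save_filename="fragment_1742572669.wav",
)
# fragment_path == "./TEMP/spark/fragment_1742572669.wav"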
def split_into_sentences(text):
"""
@@ -40,33 +107,25 @@ def process_single_sentence(
base_fragment_num
):
"""
Process a single sentence through the TTS and RVC pipeline
Args:
sentence_index (int): Index of the sentence in the original text
sentence (str): The sentence text to process
... (other parameters are the same as generate_and_process_with_rvc)
Returns:
tuple: (spark_output_path, rvc_output_path, success, info_message)
Process a single sentence through the TTS and RVC pipeline.
"""
fragment_num = base_fragment_num + sentence_index
# Generate TTS audio for this sentence
# Generate TTS audio for this sentence, saving directly to the correct location
tts_path = run_tts(
sentence,
spark_model,
prompt_text=prompt_text_clean,
prompt_speech=prompt_speech
prompt_speech=prompt_speech,
save_dir="./TEMP/spark",
save_filename=f"fragment_{fragment_num}.wav"
)
# Make sure we have a TTS file to process
if not tts_path or not os.path.exists(tts_path):
return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
# Save Spark output to TEMP/spark
spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
shutil.copy2(tts_path, spark_output_path)
# Use the tts_path as the Spark output (no need to copy)
spark_output_path = tts_path
# Call RVC processing function
f0_file = None # We're not using an F0 curve file in this pipeline
@@ -80,7 +139,6 @@ def process_single_sentence(
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
rvc_saved = False
# Try different ways to save the RVC output based on common formats
try:
if isinstance(output_audio, str) and os.path.exists(output_audio):
# Case 1: output_audio is a file path string

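The diff is truncated inside this try block. A self-contained sketch of what the save logic presumably covers; the tuple and raw-array cases, and the fallback sample rate, are assumptions about vc_single's possible return shapes, not something this diff confirms:

import os
import shutil
import numpy as np
import soundfile as sf

def save_rvc_output(output_audio, rvc_output_path, fallback_sr=44100):
    """Normalize common RVC return shapes into a WAV file; True on success."""
    if isinstance(output_audio, str) and os.path.exists(output_audio):
        shutil.copy2(output_audio, rvc_output_path)            # Case 1: file path string
    elif isinstance(output_audio, tuple) and len(output_audio) == 2:
        sample_rate, samples = output_audio                    # Case 2: (sr, samples) tuple
        sf.write(rvc_output_path, samples, sample_rate)
    elif isinstance(output_audio, np.ndarray):
        sf.write(rvc_output_path, output_audio, fallback_sr)   # Case 3: raw samples
    else:
        return False
    return True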
View File

@@ -11,72 +11,6 @@ from datetime import datetime
from spark.cli.SparkTTS import SparkTTS
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
def initialize_model(model_dir, device):
"""Load the model once at the beginning."""
logging.info(f"Loading model from: {model_dir}")
# Determine appropriate device based on platform and availability
if platform.system() == "Darwin":
# macOS with MPS support (Apple Silicon)
device = torch.device(f"mps:{device}")
logging.info(f"Using MPS device: {device}")
elif torch.cuda.is_available():
# System with CUDA support
device = torch.device(f"cuda:{device}")
logging.info(f"Using CUDA device: {device}")
else:
# Fall back to CPU
device = torch.device("cpu")
logging.info("GPU acceleration not available, using CPU")
model = SparkTTS(model_dir, device)
return model
def run_tts(
text,
model,
prompt_text=None,
prompt_speech=None,
gender=None,
pitch=None,
speed=None,
save_dir="spark/example/results",
):
"""Perform TTS inference and save the generated audio."""
logging.info(f"Saving audio to: {save_dir}")
if prompt_text is not None:
prompt_text = None if len(prompt_text) <= 1 else prompt_text
# Ensure the save directory exists
os.makedirs(save_dir, exist_ok=True)
# Generate unique filename using timestamp
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
save_path = os.path.join(save_dir, f"{timestamp}.wav")
logging.info("Starting inference...")
# Perform inference and save the output audio
with torch.no_grad():
wav = model.inference(
text,
prompt_speech,
prompt_text,
gender,
pitch,
speed,
)
sf.write(save_path, wav, samplerate=16000)
logging.info(f"Audio saved at: {save_path}")
return save_path
def build_spark_ui():
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
device = 0