mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-04 03:38:58 +08:00
Decouple merged utils from spark_ui
This commit is contained in:
parent
931870e527
commit
87e06f30ec
@ -1,20 +1,87 @@
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import re
|
||||
import numpy as np
|
||||
from time import sleep
|
||||
import soundfile as sf
|
||||
from pydub import AudioSegment
|
||||
import torch
|
||||
|
||||
# Import modules from your packages
|
||||
from spark.cli.SparkTTS import SparkTTS
|
||||
from rvc_ui.initialization import vc
|
||||
from spark_ui.main import initialize_model, run_tts
|
||||
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
||||
|
||||
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
|
||||
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
||||
device = 0
|
||||
spark_model = initialize_model(model_dir, device=device)
|
||||
|
||||
def initialize_model(model_dir, device):
|
||||
"""Load the model once at the beginning."""
|
||||
logging.info(f"Loading model from: {model_dir}")
|
||||
|
||||
# Determine appropriate device based on platform and availability
|
||||
if platform.system() == "Darwin":
|
||||
# macOS with MPS support (Apple Silicon)
|
||||
device = torch.device(f"mps:{device}")
|
||||
logging.info(f"Using MPS device: {device}")
|
||||
elif torch.cuda.is_available():
|
||||
# System with CUDA support
|
||||
device = torch.device(f"cuda:{device}")
|
||||
logging.info(f"Using CUDA device: {device}")
|
||||
else:
|
||||
# Fall back to CPU
|
||||
device = torch.device("cpu")
|
||||
logging.info("GPU acceleration not available, using CPU")
|
||||
|
||||
model = SparkTTS(model_dir, device)
|
||||
return model
|
||||
|
||||
|
||||
def run_tts(
|
||||
text,
|
||||
prompt_text=None,
|
||||
prompt_speech=None,
|
||||
gender=None,
|
||||
pitch=None,
|
||||
speed=None,
|
||||
save_dir="TEMP/spark", # Updated default save directory
|
||||
save_filename=None, # New parameter to specify filename
|
||||
):
|
||||
"""Perform TTS inference and save the generated audio."""
|
||||
model = initialize_model(model_dir, device=device)
|
||||
logging.info(f"Saving audio to: {save_dir}")
|
||||
|
||||
if prompt_text is not None:
|
||||
prompt_text = None if len(prompt_text) <= 1 else prompt_text
|
||||
|
||||
# Ensure the save directory exists
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Determine the save path based on save_filename if provided; otherwise, use a timestamp
|
||||
if save_filename:
|
||||
save_path = os.path.join(save_dir, save_filename)
|
||||
else:
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
save_path = os.path.join(save_dir, f"{timestamp}.wav")
|
||||
|
||||
logging.info("Starting inference...")
|
||||
|
||||
# Perform inference and save the output audio
|
||||
with torch.no_grad():
|
||||
wav = model.inference(
|
||||
text,
|
||||
prompt_speech,
|
||||
prompt_text,
|
||||
gender,
|
||||
pitch,
|
||||
speed,
|
||||
)
|
||||
sf.write(save_path, wav, samplerate=16000)
|
||||
|
||||
logging.info(f"Audio saved at: {save_path}")
|
||||
return save_path
|
||||
|
||||
def split_into_sentences(text):
|
||||
"""
|
||||
@ -40,34 +107,26 @@ def process_single_sentence(
|
||||
base_fragment_num
|
||||
):
|
||||
"""
|
||||
Process a single sentence through the TTS and RVC pipeline
|
||||
|
||||
Args:
|
||||
sentence_index (int): Index of the sentence in the original text
|
||||
sentence (str): The sentence text to process
|
||||
... (other parameters are the same as generate_and_process_with_rvc)
|
||||
|
||||
Returns:
|
||||
tuple: (spark_output_path, rvc_output_path, success, info_message)
|
||||
Process a single sentence through the TTS and RVC pipeline.
|
||||
"""
|
||||
fragment_num = base_fragment_num + sentence_index
|
||||
|
||||
# Generate TTS audio for this sentence
|
||||
|
||||
# Generate TTS audio for this sentence, saving directly to the correct location
|
||||
tts_path = run_tts(
|
||||
sentence,
|
||||
spark_model,
|
||||
prompt_text=prompt_text_clean,
|
||||
prompt_speech=prompt_speech
|
||||
prompt_speech=prompt_speech,
|
||||
save_dir="./TEMP/spark",
|
||||
save_filename=f"fragment_{fragment_num}.wav"
|
||||
)
|
||||
|
||||
|
||||
# Make sure we have a TTS file to process
|
||||
if not tts_path or not os.path.exists(tts_path):
|
||||
return None, None, False, f"Failed to generate TTS audio for sentence: {sentence}"
|
||||
|
||||
# Save Spark output to TEMP/spark
|
||||
spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
|
||||
shutil.copy2(tts_path, spark_output_path)
|
||||
|
||||
|
||||
# Use the tts_path as the Spark output (no need to copy)
|
||||
spark_output_path = tts_path
|
||||
|
||||
# Call RVC processing function
|
||||
f0_file = None # We're not using an F0 curve file in this pipeline
|
||||
output_info, output_audio = vc.vc_single(
|
||||
@ -75,12 +134,11 @@ def process_single_sentence(
|
||||
file_index1, file_index2, index_rate, filter_radius,
|
||||
resample_sr, rms_mix_rate, protect
|
||||
)
|
||||
|
||||
|
||||
# Save RVC output to TEMP/rvc directory
|
||||
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
|
||||
rvc_saved = False
|
||||
|
||||
# Try different ways to save the RVC output based on common formats
|
||||
|
||||
try:
|
||||
if isinstance(output_audio, str) and os.path.exists(output_audio):
|
||||
# Case 1: output_audio is a file path string
|
||||
@ -99,7 +157,7 @@ def process_single_sentence(
|
||||
rvc_saved = True
|
||||
except Exception as e:
|
||||
output_info += f"\nError saving RVC output: {str(e)}"
|
||||
|
||||
|
||||
# Prepare info message
|
||||
info_message = f"Sentence {sentence_index+1}: {sentence[:30]}{'...' if len(sentence) > 30 else ''}\n"
|
||||
info_message += f" - Spark output: {spark_output_path}\n"
|
||||
@ -107,7 +165,7 @@ def process_single_sentence(
|
||||
info_message += f" - RVC output: {rvc_output_path}"
|
||||
else:
|
||||
info_message += f" - Could not save RVC output to {rvc_output_path}"
|
||||
|
||||
|
||||
return spark_output_path, rvc_output_path, rvc_saved, info_message
|
||||
|
||||
def concatenate_audio_files(file_paths, output_path, sample_rate=44100):
|
||||
@ -225,4 +283,4 @@ def modified_get_vc(sid0_value, protect0_value, file_index2_component):
|
||||
if isinstance(outputs, tuple) and len(outputs) >= 3:
|
||||
return outputs[0], outputs[1], outputs[3]
|
||||
|
||||
return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else ""
|
||||
return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else ""
|
Loading…
x
Reference in New Issue
Block a user