mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-05 04:08:58 +08:00
211 lines
7.0 KiB
Python
211 lines
7.0 KiB
Python
# spark_ui/main.py
|
|
import os
|
|
import torch
|
|
import soundfile as sf
|
|
import logging
|
|
import argparse
|
|
import gradio as gr
|
|
import platform
|
|
|
|
from datetime import datetime
|
|
from spark.cli.SparkTTS import SparkTTS
|
|
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
|
|
|
|
|
def initialize_model(model_dir, device):
    """Load the Spark-TTS model once, on the best device that is actually usable.

    Args:
        model_dir: Path of the pretrained model directory; forwarded to
            ``SparkTTS`` unchanged.
        device: Accelerator index used to build the ``mps:<device>`` or
            ``cuda:<device>`` device string. Ignored when falling back to CPU.

    Returns:
        The ``SparkTTS`` instance constructed with ``model_dir`` and the
        selected ``torch.device``.
    """
    logging.info(f"Loading model from: {model_dir}")

    # Being on macOS does not guarantee a working MPS backend (Intel Macs,
    # CPU-only torch builds), so check availability instead of assuming it.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_usable = (
        platform.system() == "Darwin"
        and mps_backend is not None
        and mps_backend.is_available()
    )

    if mps_usable:
        # macOS with MPS support (Apple Silicon)
        device = torch.device(f"mps:{device}")
        logging.info(f"Using MPS device: {device}")
    elif torch.cuda.is_available():
        # System with CUDA support
        device = torch.device(f"cuda:{device}")
        logging.info(f"Using CUDA device: {device}")
    else:
        # Fall back to CPU
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    model = SparkTTS(model_dir, device)
    return model
|
|
|
|
|
|
def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir="spark/example/results",
):
    """Perform TTS inference and save the generated audio as a WAV file.

    Args:
        text: Text to synthesise.
        model: Object exposing ``inference(text, prompt_speech, prompt_text,
            gender, pitch, speed)``.
        prompt_text: Optional transcript of the prompt audio. A value of one
            character or less is treated as absent.
        prompt_speech: Optional reference audio, forwarded to the model as is.
        gender: Optional voice gender, forwarded to the model as is.
        pitch: Optional pitch setting, forwarded to the model as is.
        speed: Optional speed setting, forwarded to the model as is.
        save_dir: Directory for the output file; created if missing.

    Returns:
        Path of the written ``.wav`` file.
    """
    logging.info(f"Saving audio to: {save_dir}")

    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    # Ensure the save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Name the file after the current time. Microseconds (%f) are included
    # because a one-second resolution lets two requests started within the
    # same second pick the same path and overwrite each other's output.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")

    # Perform inference without tracking gradients
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )

    # The output is written at a fixed 16 kHz sample rate.
    sf.write(save_path, wav, samplerate=16000)

    logging.info(f"Audio saved at: {save_path}")

    return save_path
|
|
|
|
|
|
def build_spark_ui(model_dir="spark/pretrained_models/Spark-TTS-0.5B", device=0):
    """Load the Spark-TTS model and assemble the Gradio interface around it.

    The interface has two tabs: "Voice Clone" (synthesis guided by an uploaded
    or recorded prompt audio) and "Voice Creation" (synthesis controlled by
    gender, pitch and speed). The app is only constructed here, not launched.

    Args:
        model_dir: Directory of the pretrained model to load.
        device: Accelerator index handed to ``initialize_model``.

    Returns:
        The ``gr.Blocks`` application.
    """
    # Initialize model once; both callbacks below close over it.
    model = initialize_model(model_dir, device=device)

    # Define callback function for voice cloning
    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """
        Gradio callback to clone voice using text and optional prompt speech.
        - text: The input text to be synthesised.
        - prompt_text: Additional textual info for the prompt (optional).
        - prompt_wav_upload/prompt_wav_record: Audio files used as reference;
          the uploaded file takes precedence over the recording.
        """
        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
        # Treat a missing (None), empty or single-character transcript as
        # "no transcript" instead of failing on len(None).
        if prompt_text and len(prompt_text) >= 2:
            prompt_text_clean = prompt_text
        else:
            prompt_text_clean = None

        audio_output_path = run_tts(
            text,
            model,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )
        return audio_output_path

    # Define callback function for creating new voices
    def voice_creation(text, gender, pitch, speed):
        """
        Gradio callback to create a synthetic voice with adjustable parameters.
        - text: The input text for synthesis.
        - gender: 'male' or 'female'.
        - pitch/speed: Slider values, converted to int and looked up in
          LEVELS_MAP_UI.
        """
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        audio_output_path = run_tts(
            text,
            model,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )
        return audio_output_path

    with gr.Blocks() as app:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                clone_audio_output = gr.Audio(
                    label="Generated Audio", streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                # Separate name from the clone tab's output so the two
                # components are not confused with each other.
                creation_audio_output = gr.Audio(
                    label="Generated Audio", streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )

    return app
|
|
|
|
if __name__ == "__main__":
    # build_spark_ui() only constructs and returns the Blocks app; without
    # launch() the script would load the model and exit without serving it.
    build_spark_ui().launch()