2025-03-21 19:17:49 +02:00

145 lines
5.3 KiB
Python

# spark_ui/main.py
import os
import torch
import soundfile as sf
import logging
import argparse
import gradio as gr
import platform
from datetime import datetime
from spark.cli.SparkTTS import SparkTTS
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
def build_spark_ui(model_dir="spark/pretrained_models/Spark-TTS-0.5B", device=0):
    """Build the Gradio Blocks interface for Spark-TTS.

    The model is initialised once via ``initialize_model`` and shared by the
    two tab callbacks ("Voice Clone" and "Voice Creation"), each of which
    delegates synthesis to ``run_tts``.

    Args:
        model_dir: Model directory forwarded to ``initialize_model``. Defaults
            to the path that was previously hard-coded here.
        device: Device value forwarded unchanged to ``initialize_model``.
            Defaults to ``0``, the previously hard-coded value.

    Returns:
        The constructed ``gr.Blocks`` app. It is NOT launched here; the caller
        decides whether to launch it or embed it.
    """
    # Initialise the model once; both callbacks close over it.
    model = initialize_model(model_dir, device=device)

    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """Gradio callback: clone a voice from reference audio.

        - text: The input text to be synthesised.
        - prompt_text: Transcript of the prompt speech (optional).
        - prompt_wav_upload / prompt_wav_record: Reference audio file paths.

        Returns whatever ``run_tts`` returns (bound to the audio output).
        """
        # An uploaded file takes precedence over a microphone recording.
        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
        # A missing (None) or shorter-than-2-character transcript counts as
        # "not provided"; guarding None avoids a TypeError from len(None).
        if prompt_text is None or len(prompt_text) < 2:
            prompt_text_clean = None
        else:
            prompt_text_clean = prompt_text
        return run_tts(
            text,
            model,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )

    def voice_creation(text, gender, pitch, speed):
        """Gradio callback: synthesise a voice from adjustable parameters.

        - text: The input text for synthesis.
        - gender: 'male' or 'female' (the radio choices below).
        - pitch / speed: Slider values, cast to int and looked up in
          LEVELS_MAP_UI.

        Returns whatever ``run_tts`` returns (bound to the audio output).
        """
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        return run_tts(
            text,
            model,
            gender=gender,
            pitch=pitch_val,
            speed=speed_val,
        )

    with gr.Blocks() as app:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )
                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )
                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )
                clone_audio_output = gr.Audio(
                    label="Generated Audio", streaming=True
                )
                generate_button_clone = gr.Button("Generate")
                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[clone_audio_output],
                )
            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )
                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")
                # Separate name from the clone tab's output so the two
                # components are not confused via a rebound variable.
                creation_audio_output = gr.Audio(
                    label="Generated Audio", streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[creation_audio_output],
                )
    return app
if __name__ == "__main__":
    # build_spark_ui() only constructs and returns the Blocks app; without an
    # explicit launch() the script would build the UI and exit immediately
    # without ever serving it.
    build_spark_ui().launch()