# spark_ui/main.py
"""Gradio web UI for Spark-TTS.

Provides two tabs:

* **Voice Clone** - synthesize text in the voice of an uploaded or recorded
  reference clip (optionally together with its transcript).
* **Voice Creation** - synthesize text with a synthetic voice controlled by
  gender, pitch and speed.

Run this module as a script to load the model and launch the web server.
"""

import argparse
import logging
import os
import platform
from datetime import datetime

import gradio as gr
import soundfile as sf
import torch

from spark.cli.SparkTTS import SparkTTS
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI

# Default location of the pretrained checkpoint (relative path).
DEFAULT_MODEL_DIR = "spark/pretrained_models/Spark-TTS-0.5B"
# Default directory that generated .wav files are written to.
DEFAULT_SAVE_DIR = "spark/example/results"
# Sample rate the generated waveform is written with.
DEFAULT_SAMPLE_RATE = 16000


def initialize_model(model_dir, device):
    """Load the Spark-TTS model once, on the best available device.

    Device preference: MPS (macOS, only if available), then CUDA, then CPU.

    Args:
        model_dir: Path to the pretrained Spark-TTS model directory.
        device: Accelerator index (e.g. ``0``) used to build the MPS/CUDA
            device; ignored when falling back to CPU.

    Returns:
        The loaded ``SparkTTS`` instance.
    """
    logging.info("Loading model from: %s", model_dir)

    # Being on macOS is not enough: Intel Macs and PyTorch builds without
    # MPS support report "Darwin" but cannot use the MPS backend.
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        device = torch.device(f"mps:{device}")
        logging.info("Using MPS device: %s", device)
    elif torch.cuda.is_available():
        device = torch.device(f"cuda:{device}")
        logging.info("Using CUDA device: %s", device)
    else:
        device = torch.device("cpu")
        logging.info("GPU acceleration not available, using CPU")

    return SparkTTS(model_dir, device)


def run_tts(
    text,
    model,
    prompt_text=None,
    prompt_speech=None,
    gender=None,
    pitch=None,
    speed=None,
    save_dir=DEFAULT_SAVE_DIR,
    sample_rate=DEFAULT_SAMPLE_RATE,
):
    """Run TTS inference and save the generated audio as a .wav file.

    Args:
        text: Text to synthesize.
        model: A loaded ``SparkTTS`` model.
        prompt_text: Optional transcript of the reference clip. Strings of
            length 0 or 1 are treated as absent.
        prompt_speech: Optional path to the reference clip (voice cloning).
        gender: Optional gender control (voice creation).
        pitch: Optional pitch level (voice creation).
        speed: Optional speed level (voice creation).
        save_dir: Directory to write the output to; created if missing.
        sample_rate: Sample rate used when writing the .wav file.

    Returns:
        Path of the written .wav file.
    """
    logging.info("Saving audio to: %s", save_dir)

    # Treat an empty or single-character transcript as "no transcript".
    if prompt_text is not None and len(prompt_text) <= 1:
        prompt_text = None

    os.makedirs(save_dir, exist_ok=True)

    # Microsecond resolution so that requests arriving within the same second
    # do not overwrite each other's output file.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    save_path = os.path.join(save_dir, f"{timestamp}.wav")

    logging.info("Starting inference...")
    with torch.no_grad():
        wav = model.inference(
            text,
            prompt_speech,
            prompt_text,
            gender,
            pitch,
            speed,
        )
        sf.write(save_path, wav, samplerate=sample_rate)

    logging.info("Audio saved at: %s", save_path)
    return save_path


def build_spark_ui(model_dir=DEFAULT_MODEL_DIR, device=0):
    """Load the model and build the Spark-TTS Gradio app.

    Args:
        model_dir: Path to the pretrained Spark-TTS model directory.
        device: Accelerator index passed to ``initialize_model``.

    Returns:
        The ``gr.Blocks`` app. It is not launched; call ``.launch()`` on it
        to start serving.
    """
    # Load the model once; both callbacks below close over it.
    model = initialize_model(model_dir, device=device)

    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
        """Gradio callback: synthesize ``text`` in the voice of a reference clip.

        Args:
            text: Text to synthesize.
            prompt_text: Optional transcript of the reference clip; values
                shorter than two characters (or None) are ignored.
            prompt_wav_upload: Filepath of an uploaded reference clip, if any.
            prompt_wav_record: Filepath of a microphone recording, if any.

        Returns:
            Path of the generated .wav file.

        Raises:
            gr.Error: If neither an uploaded nor a recorded clip is provided.
        """
        # An uploaded file takes precedence over a microphone recording.
        prompt_speech = prompt_wav_upload or prompt_wav_record
        if not prompt_speech:
            # Show a readable message in the UI rather than passing None to
            # model.inference as the reference clip.
            raise gr.Error("Please upload or record a prompt audio clip.")

        # Guard against None as well as too-short transcripts.
        prompt_text_clean = (
            prompt_text if prompt_text and len(prompt_text) >= 2 else None
        )
        return run_tts(
            text,
            model,
            prompt_text=prompt_text_clean,
            prompt_speech=prompt_speech,
        )

    def voice_creation(text, gender, pitch, speed):
        """Gradio callback: synthesize ``text`` with a parametric voice.

        Args:
            text: Text to synthesize.
            gender: ``"male"`` or ``"female"`` (from the radio control).
            pitch: Slider value 1-5, mapped through ``LEVELS_MAP_UI``.
            speed: Slider value 1-5, mapped through ``LEVELS_MAP_UI``.

        Returns:
            Path of the generated .wav file.
        """
        # Cast slider values to int before the LEVELS_MAP_UI lookup.
        pitch_val = LEVELS_MAP_UI[int(pitch)]
        speed_val = LEVELS_MAP_UI[int(speed)]
        return run_tts(
            text, model, gender=gender, pitch=pitch_val, speed=speed_val
        )

    with gr.Blocks() as app:
        # Use HTML for centered title
        gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')

        with gr.Tabs():
            # Voice Clone Tab
            with gr.TabItem("Voice Clone"):
                gr.Markdown(
                    "### Upload reference audio or recording (上传参考音频或者录音)"
                )

                with gr.Row():
                    prompt_wav_upload = gr.Audio(
                        sources="upload",
                        type="filepath",
                        label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
                    )
                    prompt_wav_record = gr.Audio(
                        sources="microphone",
                        type="filepath",
                        label="Record the prompt audio file.",
                    )

                with gr.Row():
                    text_input = gr.Textbox(
                        label="Text", lines=3, placeholder="Enter text here"
                    )
                    prompt_text_input = gr.Textbox(
                        label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
                        lines=3,
                        placeholder="Enter text of the prompt speech.",
                    )

                audio_output_clone = gr.Audio(
                    label="Generated Audio", streaming=True
                )

                generate_button_clone = gr.Button("Generate")

                generate_button_clone.click(
                    voice_clone,
                    inputs=[
                        text_input,
                        prompt_text_input,
                        prompt_wav_upload,
                        prompt_wav_record,
                    ],
                    outputs=[audio_output_clone],
                )

            # Voice Creation Tab
            with gr.TabItem("Voice Creation"):
                gr.Markdown(
                    "### Create your own voice based on the following parameters"
                )

                with gr.Row():
                    with gr.Column():
                        gender = gr.Radio(
                            choices=["male", "female"], value="male", label="Gender"
                        )
                        pitch = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Pitch"
                        )
                        speed = gr.Slider(
                            minimum=1, maximum=5, step=1, value=3, label="Speed"
                        )
                    with gr.Column():
                        text_input_creation = gr.Textbox(
                            label="Input Text",
                            lines=3,
                            placeholder="Enter text here",
                            value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
                        )
                        create_button = gr.Button("Create Voice")

                audio_output_creation = gr.Audio(
                    label="Generated Audio", streaming=True
                )
                create_button.click(
                    voice_creation,
                    inputs=[text_input_creation, gender, pitch, speed],
                    outputs=[audio_output_creation],
                )

    return app


def parse_arguments():
    """Parse command-line options for launching the web UI."""
    parser = argparse.ArgumentParser(description="Spark-TTS Gradio web UI")
    parser.add_argument(
        "--model_dir",
        default=DEFAULT_MODEL_DIR,
        help="Path to the pretrained Spark-TTS model directory.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="Accelerator (MPS/CUDA) device index.",
    )
    parser.add_argument(
        "--server_name",
        default=None,
        help="Host to bind; None lets Gradio pick its default.",
    )
    parser.add_argument(
        "--server_port",
        type=int,
        default=None,
        help="Port to bind; None lets Gradio pick its default.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Without this, the logging.info calls above are not shown.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    args = parse_arguments()
    spark_app = build_spark_ui(model_dir=args.model_dir, device=args.device)
    # Building the Blocks is not enough; launch() actually starts the server.
    spark_app.launch(server_name=args.server_name, server_port=args.server_port)