From 4287b7d57783225ff20f6e2764df401d6f958eac Mon Sep 17 00:00:00 2001 From: VSlobolinskyi Date: Fri, 21 Mar 2025 17:07:09 +0200 Subject: [PATCH] Run inference in 1 button click --- modules/merged_ui/main.py | 151 ++++++++++---------------------------- 1 file changed, 40 insertions(+), 111 deletions(-) diff --git a/modules/merged_ui/main.py b/modules/merged_ui/main.py index 361b964..6c48ba5 100644 --- a/modules/merged_ui/main.py +++ b/modules/merged_ui/main.py @@ -16,73 +16,38 @@ def build_merged_ui(): spark_model = initialize_model(model_dir, device=device) # Create the UI - with gr.Blocks(title="Unified TTS-RVC Pipeline") as merged_ui: + with gr.Blocks(title="Unified TTS-RVC Pipeline") as app: gr.Markdown("## Voice Generation and Conversion Pipeline") gr.Markdown("Generate speech with Spark TTS and process it through RVC for voice conversion") with gr.Tabs(): with gr.TabItem("TTS-to-RVC Pipeline"): - gr.Markdown("### Step 1: Generate speech with Spark TTS") + gr.Markdown("### Generate speech with Spark TTS and convert with RVC") # TTS Generation Section - with gr.Tabs(): - # Voice Clone option - with gr.TabItem("Voice Clone"): - with gr.Row(): - prompt_wav_upload = gr.Audio( - sources="upload", - type="filepath", - label="Reference voice (upload)", - ) - prompt_wav_record = gr.Audio( - sources="microphone", - type="filepath", - label="Reference voice (record)", - ) - - with gr.Row(): - tts_text_input = gr.Textbox( - label="Text to synthesize", - lines=3, - placeholder="Enter text for TTS" - ) - prompt_text_input = gr.Textbox( - label="Text of prompt speech (Optional)", - lines=3, - placeholder="Enter text of the reference audio", - ) - - # Voice Creation option - with gr.TabItem("Voice Creation"): - with gr.Row(): - with gr.Column(): - gender = gr.Radio( - choices=["male", "female"], value="male", label="Gender" - ) - pitch = gr.Slider( - minimum=1, maximum=5, step=1, value=3, label="Pitch" - ) - speed = gr.Slider( - minimum=1, maximum=5, step=1, value=3, label="Speed" - ) - with gr.Column(): - tts_text_input_creation = gr.Textbox( - label="Text to synthesize", - lines=3, - placeholder="Enter text for TTS", - value="Generate speech with this text and then convert the voice with RVC.", - ) - - tts_audio_output = gr.Audio(label="Generated TTS Audio") - with gr.Row(): - generate_clone_button = gr.Button("Generate with Voice Clone", variant="primary") - generate_create_button = gr.Button("Generate with Voice Creation", variant="primary") - - # Hidden text field to store the TTS audio path - tts_audio_path = gr.Textbox(visible=False) - - gr.Markdown("### Step 2: Convert with RVC") + prompt_wav_upload = gr.Audio( + sources="upload", + type="filepath", + label="Reference voice (upload)", + ) + prompt_wav_record = gr.Audio( + sources="microphone", + type="filepath", + label="Reference voice (record)", + ) + + with gr.Row(): + tts_text_input = gr.Textbox( + label="Text to synthesize", + lines=3, + placeholder="Enter text for TTS" + ) + prompt_text_input = gr.Textbox( + label="Text of prompt speech (Optional)", + lines=3, + placeholder="Enter text of the reference audio", + ) # RVC Settings with gr.Row(): @@ -166,51 +131,34 @@ def build_merged_ui(): interactive=True, ) - # Process button and outputs - process_button = gr.Button("Process with RVC", variant="primary") + # Combined process button and outputs + generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary") with gr.Row(): vc_output1 = gr.Textbox(label="Output information") vc_output2 = gr.Audio(label="Final converted audio") - # Function to handle voice clone TTS generation - def voice_clone_tts(text, prompt_text, prompt_wav_upload, prompt_wav_record): + # Function to handle combined TTS and RVC processing + def generate_and_process_with_rvc( + text, prompt_text, prompt_wav_upload, prompt_wav_record, + spk_item, vc_transform, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect + ): + # First generate TTS audio prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text - audio_output_path = run_tts( + tts_path = run_tts( text, spark_model, prompt_text=prompt_text_clean, prompt_speech=prompt_speech ) - return audio_output_path, audio_output_path - - # Function to handle voice creation TTS generation - def voice_creation_tts(text, gender, pitch, speed): - pitch_val = LEVELS_MAP_UI[int(pitch)] - speed_val = LEVELS_MAP_UI[int(speed)] - - audio_output_path = run_tts( - text, - spark_model, - gender=gender, - pitch=pitch_val, - speed=speed_val - ) - - return audio_output_path, audio_output_path - - # Function to process audio with RVC - def process_with_rvc( - tts_path, spk_item, vc_transform, f0method, - file_index1, file_index2, index_rate, filter_radius, - resample_sr, rms_mix_rate, protect - ): # Make sure we have a TTS file to process if not tts_path or not os.path.exists(tts_path): - return "No TTS audio generated yet", None + return "Failed to generate TTS audio", None # Call RVC processing function f0_file = None # We're not using an F0 curve file in this pipeline @@ -222,33 +170,14 @@ def build_merged_ui(): return output_info, output_audio - # Connect functions to buttons - generate_clone_button.click( - voice_clone_tts, + # Connect function to button + generate_with_rvc_button.click( + generate_and_process_with_rvc, inputs=[ tts_text_input, prompt_text_input, prompt_wav_upload, prompt_wav_record, - ], - outputs=[tts_audio_output, tts_audio_path], - ) - - generate_create_button.click( - voice_creation_tts, - inputs=[ - tts_text_input_creation, - gender, - pitch, - speed, - ], - outputs=[tts_audio_output, tts_audio_path], - ) - - process_button.click( - process_with_rvc, - inputs=[ - tts_audio_path, spk_item, vc_transform0, f0method0, @@ -278,7 +207,7 @@ def build_merged_ui(): outputs=[spk_item, protect0, file_index2], ) - return merged_ui + return app if __name__ == "__main__": build_merged_ui() \ No newline at end of file