Run inference in 1 button click

This commit is contained in:
VSlobolinskyi 2025-03-21 17:07:09 +02:00
parent 7317afc9c0
commit 4287b7d577

View File

@ -16,73 +16,38 @@ def build_merged_ui():
spark_model = initialize_model(model_dir, device=device) spark_model = initialize_model(model_dir, device=device)
# Create the UI # Create the UI
with gr.Blocks(title="Unified TTS-RVC Pipeline") as merged_ui: with gr.Blocks(title="Unified TTS-RVC Pipeline") as app:
gr.Markdown("## Voice Generation and Conversion Pipeline") gr.Markdown("## Voice Generation and Conversion Pipeline")
gr.Markdown("Generate speech with Spark TTS and process it through RVC for voice conversion") gr.Markdown("Generate speech with Spark TTS and process it through RVC for voice conversion")
with gr.Tabs(): with gr.Tabs():
with gr.TabItem("TTS-to-RVC Pipeline"): with gr.TabItem("TTS-to-RVC Pipeline"):
gr.Markdown("### Step 1: Generate speech with Spark TTS") gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
# TTS Generation Section # TTS Generation Section
with gr.Tabs():
# Voice Clone option
with gr.TabItem("Voice Clone"):
with gr.Row():
prompt_wav_upload = gr.Audio(
sources="upload",
type="filepath",
label="Reference voice (upload)",
)
prompt_wav_record = gr.Audio(
sources="microphone",
type="filepath",
label="Reference voice (record)",
)
with gr.Row():
tts_text_input = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS"
)
prompt_text_input = gr.Textbox(
label="Text of prompt speech (Optional)",
lines=3,
placeholder="Enter text of the reference audio",
)
# Voice Creation option
with gr.TabItem("Voice Creation"):
with gr.Row():
with gr.Column():
gender = gr.Radio(
choices=["male", "female"], value="male", label="Gender"
)
pitch = gr.Slider(
minimum=1, maximum=5, step=1, value=3, label="Pitch"
)
speed = gr.Slider(
minimum=1, maximum=5, step=1, value=3, label="Speed"
)
with gr.Column():
tts_text_input_creation = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS",
value="Generate speech with this text and then convert the voice with RVC.",
)
tts_audio_output = gr.Audio(label="Generated TTS Audio")
with gr.Row(): with gr.Row():
generate_clone_button = gr.Button("Generate with Voice Clone", variant="primary") prompt_wav_upload = gr.Audio(
generate_create_button = gr.Button("Generate with Voice Creation", variant="primary") sources="upload",
type="filepath",
# Hidden text field to store the TTS audio path label="Reference voice (upload)",
tts_audio_path = gr.Textbox(visible=False) )
prompt_wav_record = gr.Audio(
gr.Markdown("### Step 2: Convert with RVC") sources="microphone",
type="filepath",
label="Reference voice (record)",
)
with gr.Row():
tts_text_input = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS"
)
prompt_text_input = gr.Textbox(
label="Text of prompt speech (Optional)",
lines=3,
placeholder="Enter text of the reference audio",
)
# RVC Settings # RVC Settings
with gr.Row(): with gr.Row():
@ -166,51 +131,34 @@ def build_merged_ui():
interactive=True, interactive=True,
) )
# Process button and outputs # Combined process button and outputs
process_button = gr.Button("Process with RVC", variant="primary") generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
with gr.Row(): with gr.Row():
vc_output1 = gr.Textbox(label="Output information") vc_output1 = gr.Textbox(label="Output information")
vc_output2 = gr.Audio(label="Final converted audio") vc_output2 = gr.Audio(label="Final converted audio")
# Function to handle voice clone TTS generation # Function to handle combined TTS and RVC processing
def voice_clone_tts(text, prompt_text, prompt_wav_upload, prompt_wav_record): def generate_and_process_with_rvc(
text, prompt_text, prompt_wav_upload, prompt_wav_record,
spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
):
# First generate TTS audio
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
audio_output_path = run_tts( tts_path = run_tts(
text, text,
spark_model, spark_model,
prompt_text=prompt_text_clean, prompt_text=prompt_text_clean,
prompt_speech=prompt_speech prompt_speech=prompt_speech
) )
return audio_output_path, audio_output_path
# Function to handle voice creation TTS generation
def voice_creation_tts(text, gender, pitch, speed):
pitch_val = LEVELS_MAP_UI[int(pitch)]
speed_val = LEVELS_MAP_UI[int(speed)]
audio_output_path = run_tts(
text,
spark_model,
gender=gender,
pitch=pitch_val,
speed=speed_val
)
return audio_output_path, audio_output_path
# Function to process audio with RVC
def process_with_rvc(
tts_path, spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
):
# Make sure we have a TTS file to process # Make sure we have a TTS file to process
if not tts_path or not os.path.exists(tts_path): if not tts_path or not os.path.exists(tts_path):
return "No TTS audio generated yet", None return "Failed to generate TTS audio", None
# Call RVC processing function # Call RVC processing function
f0_file = None # We're not using an F0 curve file in this pipeline f0_file = None # We're not using an F0 curve file in this pipeline
@ -222,33 +170,14 @@ def build_merged_ui():
return output_info, output_audio return output_info, output_audio
# Connect functions to buttons # Connect function to button
generate_clone_button.click( generate_with_rvc_button.click(
voice_clone_tts, generate_and_process_with_rvc,
inputs=[ inputs=[
tts_text_input, tts_text_input,
prompt_text_input, prompt_text_input,
prompt_wav_upload, prompt_wav_upload,
prompt_wav_record, prompt_wav_record,
],
outputs=[tts_audio_output, tts_audio_path],
)
generate_create_button.click(
voice_creation_tts,
inputs=[
tts_text_input_creation,
gender,
pitch,
speed,
],
outputs=[tts_audio_output, tts_audio_path],
)
process_button.click(
process_with_rvc,
inputs=[
tts_audio_path,
spk_item, spk_item,
vc_transform0, vc_transform0,
f0method0, f0method0,
@ -278,7 +207,7 @@ def build_merged_ui():
outputs=[spk_item, protect0, file_index2], outputs=[spk_item, protect0, file_index2],
) )
return merged_ui return app
if __name__ == "__main__": if __name__ == "__main__":
build_merged_ui() build_merged_ui()