Run inference in 1 button click

This commit is contained in:
VSlobolinskyi 2025-03-21 17:07:09 +02:00
parent 7317afc9c0
commit 4287b7d577

View File

@ -16,73 +16,38 @@ def build_merged_ui():
spark_model = initialize_model(model_dir, device=device)
# Create the UI
with gr.Blocks(title="Unified TTS-RVC Pipeline") as merged_ui:
with gr.Blocks(title="Unified TTS-RVC Pipeline") as app:
gr.Markdown("## Voice Generation and Conversion Pipeline")
gr.Markdown("Generate speech with Spark TTS and process it through RVC for voice conversion")
with gr.Tabs():
with gr.TabItem("TTS-to-RVC Pipeline"):
gr.Markdown("### Step 1: Generate speech with Spark TTS")
gr.Markdown("### Generate speech with Spark TTS and convert with RVC")
# TTS Generation Section
with gr.Tabs():
# Voice Clone option
with gr.TabItem("Voice Clone"):
with gr.Row():
prompt_wav_upload = gr.Audio(
sources="upload",
type="filepath",
label="Reference voice (upload)",
)
prompt_wav_record = gr.Audio(
sources="microphone",
type="filepath",
label="Reference voice (record)",
)
with gr.Row():
tts_text_input = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS"
)
prompt_text_input = gr.Textbox(
label="Text of prompt speech (Optional)",
lines=3,
placeholder="Enter text of the reference audio",
)
# Voice Creation option
with gr.TabItem("Voice Creation"):
with gr.Row():
with gr.Column():
gender = gr.Radio(
choices=["male", "female"], value="male", label="Gender"
)
pitch = gr.Slider(
minimum=1, maximum=5, step=1, value=3, label="Pitch"
)
speed = gr.Slider(
minimum=1, maximum=5, step=1, value=3, label="Speed"
)
with gr.Column():
tts_text_input_creation = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS",
value="Generate speech with this text and then convert the voice with RVC.",
)
tts_audio_output = gr.Audio(label="Generated TTS Audio")
with gr.Row():
generate_clone_button = gr.Button("Generate with Voice Clone", variant="primary")
generate_create_button = gr.Button("Generate with Voice Creation", variant="primary")
# Hidden text field to store the TTS audio path
tts_audio_path = gr.Textbox(visible=False)
gr.Markdown("### Step 2: Convert with RVC")
prompt_wav_upload = gr.Audio(
sources="upload",
type="filepath",
label="Reference voice (upload)",
)
prompt_wav_record = gr.Audio(
sources="microphone",
type="filepath",
label="Reference voice (record)",
)
with gr.Row():
tts_text_input = gr.Textbox(
label="Text to synthesize",
lines=3,
placeholder="Enter text for TTS"
)
prompt_text_input = gr.Textbox(
label="Text of prompt speech (Optional)",
lines=3,
placeholder="Enter text of the reference audio",
)
# RVC Settings
with gr.Row():
@ -166,51 +131,34 @@ def build_merged_ui():
interactive=True,
)
# Process button and outputs
process_button = gr.Button("Process with RVC", variant="primary")
# Combined process button and outputs
generate_with_rvc_button = gr.Button("Generate with RVC", variant="primary")
with gr.Row():
vc_output1 = gr.Textbox(label="Output information")
vc_output2 = gr.Audio(label="Final converted audio")
# Function to handle voice clone TTS generation
def voice_clone_tts(text, prompt_text, prompt_wav_upload, prompt_wav_record):
# Function to handle combined TTS and RVC processing
def generate_and_process_with_rvc(
text, prompt_text, prompt_wav_upload, prompt_wav_record,
spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
):
# First generate TTS audio
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
audio_output_path = run_tts(
tts_path = run_tts(
text,
spark_model,
prompt_text=prompt_text_clean,
prompt_speech=prompt_speech
)
return audio_output_path, audio_output_path
# Function to handle voice creation TTS generation
def voice_creation_tts(text, gender, pitch, speed):
pitch_val = LEVELS_MAP_UI[int(pitch)]
speed_val = LEVELS_MAP_UI[int(speed)]
audio_output_path = run_tts(
text,
spark_model,
gender=gender,
pitch=pitch_val,
speed=speed_val
)
return audio_output_path, audio_output_path
# Function to process audio with RVC
def process_with_rvc(
tts_path, spk_item, vc_transform, f0method,
file_index1, file_index2, index_rate, filter_radius,
resample_sr, rms_mix_rate, protect
):
# Make sure we have a TTS file to process
if not tts_path or not os.path.exists(tts_path):
return "No TTS audio generated yet", None
return "Failed to generate TTS audio", None
# Call RVC processing function
f0_file = None # We're not using an F0 curve file in this pipeline
@ -222,33 +170,14 @@ def build_merged_ui():
return output_info, output_audio
# Connect functions to buttons
generate_clone_button.click(
voice_clone_tts,
# Connect function to button
generate_with_rvc_button.click(
generate_and_process_with_rvc,
inputs=[
tts_text_input,
prompt_text_input,
prompt_wav_upload,
prompt_wav_record,
],
outputs=[tts_audio_output, tts_audio_path],
)
generate_create_button.click(
voice_creation_tts,
inputs=[
tts_text_input_creation,
gender,
pitch,
speed,
],
outputs=[tts_audio_output, tts_audio_path],
)
process_button.click(
process_with_rvc,
inputs=[
tts_audio_path,
spk_item,
vc_transform0,
f0method0,
@ -278,7 +207,7 @@ def build_merged_ui():
outputs=[spk_item, protect0, file_index2],
)
return merged_ui
return app
if __name__ == "__main__":
build_merged_ui()