diff --git a/infer-web.py b/infer-web.py index e8da63d..81ba9af 100644 --- a/infer-web.py +++ b/infer-web.py @@ -1,33 +1,24 @@ #!/usr/bin/env python3 import gradio as gr import traceback -from rvc_ui.initialization import now_dir, config, vc +from merged_ui.main import build_merged_ui +from rvc_ui.initialization import config from rvc_ui.main import build_rvc_ui from spark_ui.main import build_spark_ui -def build_unified_ui(): - rvc_ui = build_rvc_ui() # Returns a gr.Blocks instance for RVC WebUI - +def build_standalone_ui(): with gr.Blocks(title="Unified Inference UI") as app: gr.Markdown("## Unified Inference UI: RVC WebUI and Spark TTS") with gr.Tabs(): - with gr.TabItem("RVC WebUI"): - rvc_ui.render() with gr.TabItem("Spark TTS"): - # Instead of calling render() on the Spark UI object, - # we'll directly build it in this context - try: - # Create the Spark UI directly in this tab's context - build_spark_ui() - except Exception as e: - gr.Markdown(f"Error building Spark TTS: {str(e)}") - gr.Markdown(traceback.format_exc()) + build_spark_ui() + with gr.TabItem("RVC WebUI"): + build_rvc_ui() return app if __name__ == "__main__": - app = build_unified_ui() - # Needed for RVC + app = build_merged_ui() if config.iscolab: app.queue(concurrency_count=511, max_size=1022).launch(share=True) else: diff --git a/modules/merged_ui/__index__.py b/modules/merged_ui/__index__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/merged_ui/main.py b/modules/merged_ui/main.py new file mode 100644 index 0000000..361b964 --- /dev/null +++ b/modules/merged_ui/main.py @@ -0,0 +1,284 @@ +import os +import gradio as gr +import torch +from time import sleep + +# Import modules from your packages +from rvc_ui.initialization import now_dir, config, vc +from rvc_ui.main import build_rvc_ui, names, index_paths +from spark_ui.main import build_spark_ui, initialize_model, run_tts +from spark.sparktts.utils.token_parser import LEVELS_MAP_UI + +def build_merged_ui(): + # Initialize the Spark TTS model + model_dir = "spark/pretrained_models/Spark-TTS-0.5B" + device = 0 + spark_model = initialize_model(model_dir, device=device) + + # Create the UI + with gr.Blocks(title="Unified TTS-RVC Pipeline") as merged_ui: + gr.Markdown("## Voice Generation and Conversion Pipeline") + gr.Markdown("Generate speech with Spark TTS and process it through RVC for voice conversion") + + with gr.Tabs(): + with gr.TabItem("TTS-to-RVC Pipeline"): + gr.Markdown("### Step 1: Generate speech with Spark TTS") + + # TTS Generation Section + with gr.Tabs(): + # Voice Clone option + with gr.TabItem("Voice Clone"): + with gr.Row(): + prompt_wav_upload = gr.Audio( + sources="upload", + type="filepath", + label="Reference voice (upload)", + ) + prompt_wav_record = gr.Audio( + sources="microphone", + type="filepath", + label="Reference voice (record)", + ) + + with gr.Row(): + tts_text_input = gr.Textbox( + label="Text to synthesize", + lines=3, + placeholder="Enter text for TTS" + ) + prompt_text_input = gr.Textbox( + label="Text of prompt speech (Optional)", + lines=3, + placeholder="Enter text of the reference audio", + ) + + # Voice Creation option + with gr.TabItem("Voice Creation"): + with gr.Row(): + with gr.Column(): + gender = gr.Radio( + choices=["male", "female"], value="male", label="Gender" + ) + pitch = gr.Slider( + minimum=1, maximum=5, step=1, value=3, label="Pitch" + ) + speed = gr.Slider( + minimum=1, maximum=5, step=1, value=3, label="Speed" + ) + with gr.Column(): + tts_text_input_creation = gr.Textbox( + label="Text to synthesize", + lines=3, + placeholder="Enter text for TTS", + value="Generate speech with this text and then convert the voice with RVC.", + ) + + tts_audio_output = gr.Audio(label="Generated TTS Audio") + + with gr.Row(): + generate_clone_button = gr.Button("Generate with Voice Clone", variant="primary") + generate_create_button = gr.Button("Generate with Voice Creation", variant="primary") + + # Hidden text field to store the TTS audio path + tts_audio_path = gr.Textbox(visible=False) + + gr.Markdown("### Step 2: Convert with RVC") + + # RVC Settings + with gr.Row(): + with gr.Column(): + sid0 = gr.Dropdown( + label="Target voice model:", choices=sorted(names) + ) + vc_transform0 = gr.Number( + label="Transpose (semitones):", + value=0, + ) + f0method0 = gr.Radio( + label="Pitch extraction algorithm:", + choices=( + ["pm", "harvest", "crepe", "rmvpe"] + if config.dml == False + else ["pm", "harvest", "rmvpe"] + ), + value="rmvpe", + interactive=True, + ) + file_index1 = gr.Textbox( + label="Path to feature index file (leave blank for auto):", + placeholder="Leave blank to use dropdown selection", + interactive=True, + ) + file_index2 = gr.Dropdown( + label="Select feature index:", + choices=sorted(index_paths), + interactive=True, + ) + + with gr.Column(): + index_rate1 = gr.Slider( + minimum=0, + maximum=1, + label="Feature search ratio (accent strength):", + value=0.75, + interactive=True, + ) + filter_radius0 = gr.Slider( + minimum=0, + maximum=7, + label="Median filter radius (3+ reduces breathiness):", + value=3, + step=1, + interactive=True, + ) + rms_mix_rate0 = gr.Slider( + minimum=0, + maximum=1, + label="Volume envelope scaling (0=original, 1=constant):", + value=0.25, + interactive=True, + ) + protect0 = gr.Slider( + minimum=0, + maximum=0.5, + label="Consonant protection (0=max, 0.5=disable):", + value=0.33, + step=0.01, + interactive=True, + ) + resample_sr0 = gr.Slider( + minimum=0, + maximum=48000, + label="Output sample rate (0=no resampling):", + value=0, + step=1, + interactive=True, + ) + + # Speaker ID (hidden) + spk_item = gr.Slider( + minimum=0, + maximum=2333, + step=1, + label="Speaker ID:", + value=0, + visible=False, + interactive=True, + ) + + # Process button and outputs + process_button = gr.Button("Process with RVC", variant="primary") + + with gr.Row(): + vc_output1 = gr.Textbox(label="Output information") + vc_output2 = gr.Audio(label="Final converted audio") + + # Function to handle voice clone TTS generation + def voice_clone_tts(text, prompt_text, prompt_wav_upload, prompt_wav_record): + prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record + prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text + + audio_output_path = run_tts( + text, + spark_model, + prompt_text=prompt_text_clean, + prompt_speech=prompt_speech + ) + + return audio_output_path, audio_output_path + + # Function to handle voice creation TTS generation + def voice_creation_tts(text, gender, pitch, speed): + pitch_val = LEVELS_MAP_UI[int(pitch)] + speed_val = LEVELS_MAP_UI[int(speed)] + + audio_output_path = run_tts( + text, + spark_model, + gender=gender, + pitch=pitch_val, + speed=speed_val + ) + + return audio_output_path, audio_output_path + + # Function to process audio with RVC + def process_with_rvc( + tts_path, spk_item, vc_transform, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect + ): + # Make sure we have a TTS file to process + if not tts_path or not os.path.exists(tts_path): + return "No TTS audio generated yet", None + + # Call RVC processing function + f0_file = None # We're not using an F0 curve file in this pipeline + output_info, output_audio = vc.vc_single( + spk_item, tts_path, vc_transform, f0_file, f0method, + file_index1, file_index2, index_rate, filter_radius, + resample_sr, rms_mix_rate, protect + ) + + return output_info, output_audio + + # Connect functions to buttons + generate_clone_button.click( + voice_clone_tts, + inputs=[ + tts_text_input, + prompt_text_input, + prompt_wav_upload, + prompt_wav_record, + ], + outputs=[tts_audio_output, tts_audio_path], + ) + + generate_create_button.click( + voice_creation_tts, + inputs=[ + tts_text_input_creation, + gender, + pitch, + speed, + ], + outputs=[tts_audio_output, tts_audio_path], + ) + + process_button.click( + process_with_rvc, + inputs=[ + tts_audio_path, + spk_item, + vc_transform0, + f0method0, + file_index1, + file_index2, + index_rate1, + filter_radius0, + resample_sr0, + rms_mix_rate0, + protect0, + ], + outputs=[vc_output1, vc_output2], + ) + + def modified_get_vc(sid0_value, protect0_value): + protect1_value = protect0_value + outputs = vc.get_vc(sid0_value, protect0_value, protect1_value) + + if isinstance(outputs, tuple) and len(outputs) >= 3: + return outputs[0], outputs[1], outputs[3] + + return 0, protect0_value, file_index2.choices[0] if file_index2.choices else "" + + sid0.change( + fn=modified_get_vc, + inputs=[sid0, protect0], + outputs=[spk_item, protect0, file_index2], + ) + + return merged_ui + +if __name__ == "__main__": + build_merged_ui() \ No newline at end of file diff --git a/modules/rvc_ui/main.py b/modules/rvc_ui/main.py index 80c4f68..a1be99b 100644 --- a/modules/rvc_ui/main.py +++ b/modules/rvc_ui/main.py @@ -38,7 +38,7 @@ F0GPUVisible = config.dml == False # Build Gradio UI def build_rvc_ui(): - with gr.Blocks(title="RVC WebUI") as rvc_ui: + with gr.Blocks(title="RVC WebUI") as app: gr.Markdown("## RVC WebUI") gr.Markdown( value="This software is open source under the MIT license. The author does not have any control over the software. Users who use the software and distribute the sounds exported by the software are solely responsible.
If you do not agree with this clause, you cannot use or reference any codes and files within the software package. See the root directory Agreement-LICENSE.txt for details." @@ -326,7 +326,7 @@ def build_rvc_ui(): ], api_name="infer_change_voice", ) - return rvc_ui + return app if __name__ == "__main__": diff --git a/modules/spark_ui/main.py b/modules/spark_ui/main.py index 2a76253..faa9529 100644 --- a/modules/spark_ui/main.py +++ b/modules/spark_ui/main.py @@ -1,3 +1,4 @@ +# spark_ui/main.py import os import torch import soundfile as sf @@ -120,7 +121,7 @@ def build_spark_ui(): ) return audio_output_path - with gr.Blocks() as demo: + with gr.Blocks() as app: # Use HTML for centered title gr.HTML('

Spark-TTS by SparkAudio

') with gr.Tabs(): @@ -204,7 +205,7 @@ def build_spark_ui(): outputs=[audio_output], ) - return demo + return app if __name__ == "__main__": build_spark_ui() \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index ac9ba55..702630d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4417,4 +4417,4 @@ watchdog = ["watchdog (>=2.3)"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "58fd044b5e4a2d7ba9fbb0291188be2eb15788e9c5ef70024b2effa5b9915681" +content-hash = "e73427a82e6d9b8531424dc6552247cca063044c9bd2aa9410f2a37bee9e8d3b" diff --git a/pyproject.toml b/pyproject.toml index d4a948f..d2987d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,9 @@ from = "modules" include = "spark_ui" from = "modules" +[[tool.poetry.packages]] +include = "merged_ui" +from = "modules" [tool.poetry.dependencies] python = ">=3.11,<3.12" @@ -36,38 +39,38 @@ torch-directml = "^0.2.5.dev240914" # --------------------------------------------------------------------------- fairseq = { git = "https://github.com/One-sixth/fairseq.git" } joblib = ">=1.1.0" -numba = "*" -llvmlite = "*" -Cython = "*" +numba = ">=0.56.0" +llvmlite = ">=0.39.0" +Cython = ">=0.29.0" numpy = ">=1.0,<2.0" -scipy = "*" +scipy = ">=1.9.0" librosa = "==0.10.2" -faiss-cpu = "*" -gradio = "3.50.0" +faiss-cpu = ">=1.7.0" +gradio = "==3.50.0" soundfile = "0.12.1" ffmpeg-python = ">=0.2.0" matplotlib = ">=3.7.0" matplotlib-inline = ">=0.1.3" praat-parselmouth = ">=0.4.2" -tensorboardX = "*" -tensorboard = "*" +tensorboardX = ">=2.5.0" +tensorboard = ">=2.10.0" Pillow = ">=9.1.1" -scikit-learn = "*" +scikit-learn = ">=1.0.0" tqdm = ">=4.63.1" uvicorn = ">=0.21.1" pyworld = "==0.3.2" -onnxruntime = { version = "*", markers = "sys_platform == 'darwin'" } -onnxruntime-gpu = { version = "*", markers = "sys_platform != 'darwin'" } +onnxruntime = { version = ">=1.13.0", markers = "sys_platform == 'darwin'" } +onnxruntime-gpu = { version = ">=1.13.0", markers = "sys_platform != 'darwin'" } torchcrepe = "==0.0.23" fastapi = "==0.88" -torchfcpe = "*" +torchfcpe = ">=0.0.1" ffmpy = "==0.3.1" python-dotenv = ">=1.0.0" -av = "*" -autoflake = "^2.3.1" +av = ">=9.0.0" +autoflake = "2.3.1" einops = "0.8.1" einx = "0.3.0" -transformers = "^4.49.0" +transformers = "4.49.0" [tool.poetry.group.dev.dependencies] black = "^25.1.0"