mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-04-05 04:08:58 +08:00
Split merged_ui
This commit is contained in:
parent
4287b7d577
commit
9a0af2de15
@ -1,20 +1,15 @@
|
|||||||
import os
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
# Import modules from your packages
|
# Import modules from your packages
|
||||||
from rvc_ui.initialization import now_dir, config, vc
|
from merged_ui.utils import generate_and_process_with_rvc, modified_get_vc
|
||||||
from rvc_ui.main import build_rvc_ui, names, index_paths
|
from rvc_ui.initialization import config
|
||||||
from spark_ui.main import build_spark_ui, initialize_model, run_tts
|
from rvc_ui.main import names, index_paths
|
||||||
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
||||||
|
|
||||||
def build_merged_ui():
|
def build_merged_ui():
|
||||||
# Initialize the Spark TTS model
|
"""
|
||||||
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
Build the combined TTS-RVC UI interface using Gradio
|
||||||
device = 0
|
"""
|
||||||
spark_model = initialize_model(model_dir, device=device)
|
|
||||||
|
|
||||||
# Create the UI
|
# Create the UI
|
||||||
with gr.Blocks(title="Unified TTS-RVC Pipeline") as app:
|
with gr.Blocks(title="Unified TTS-RVC Pipeline") as app:
|
||||||
gr.Markdown("## Voice Generation and Conversion Pipeline")
|
gr.Markdown("## Voice Generation and Conversion Pipeline")
|
||||||
@ -138,39 +133,7 @@ def build_merged_ui():
|
|||||||
vc_output1 = gr.Textbox(label="Output information")
|
vc_output1 = gr.Textbox(label="Output information")
|
||||||
vc_output2 = gr.Audio(label="Final converted audio")
|
vc_output2 = gr.Audio(label="Final converted audio")
|
||||||
|
|
||||||
# Function to handle combined TTS and RVC processing
|
# Connect generate function to button
|
||||||
def generate_and_process_with_rvc(
|
|
||||||
text, prompt_text, prompt_wav_upload, prompt_wav_record,
|
|
||||||
spk_item, vc_transform, f0method,
|
|
||||||
file_index1, file_index2, index_rate, filter_radius,
|
|
||||||
resample_sr, rms_mix_rate, protect
|
|
||||||
):
|
|
||||||
# First generate TTS audio
|
|
||||||
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
|
|
||||||
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
|
|
||||||
|
|
||||||
tts_path = run_tts(
|
|
||||||
text,
|
|
||||||
spark_model,
|
|
||||||
prompt_text=prompt_text_clean,
|
|
||||||
prompt_speech=prompt_speech
|
|
||||||
)
|
|
||||||
|
|
||||||
# Make sure we have a TTS file to process
|
|
||||||
if not tts_path or not os.path.exists(tts_path):
|
|
||||||
return "Failed to generate TTS audio", None
|
|
||||||
|
|
||||||
# Call RVC processing function
|
|
||||||
f0_file = None # We're not using an F0 curve file in this pipeline
|
|
||||||
output_info, output_audio = vc.vc_single(
|
|
||||||
spk_item, tts_path, vc_transform, f0_file, f0method,
|
|
||||||
file_index1, file_index2, index_rate, filter_radius,
|
|
||||||
resample_sr, rms_mix_rate, protect
|
|
||||||
)
|
|
||||||
|
|
||||||
return output_info, output_audio
|
|
||||||
|
|
||||||
# Connect function to button
|
|
||||||
generate_with_rvc_button.click(
|
generate_with_rvc_button.click(
|
||||||
generate_and_process_with_rvc,
|
generate_and_process_with_rvc,
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -192,17 +155,9 @@ def build_merged_ui():
|
|||||||
outputs=[vc_output1, vc_output2],
|
outputs=[vc_output1, vc_output2],
|
||||||
)
|
)
|
||||||
|
|
||||||
def modified_get_vc(sid0_value, protect0_value):
|
# Connect modified_get_vc function for dropdown change
|
||||||
protect1_value = protect0_value
|
|
||||||
outputs = vc.get_vc(sid0_value, protect0_value, protect1_value)
|
|
||||||
|
|
||||||
if isinstance(outputs, tuple) and len(outputs) >= 3:
|
|
||||||
return outputs[0], outputs[1], outputs[3]
|
|
||||||
|
|
||||||
return 0, protect0_value, file_index2.choices[0] if file_index2.choices else ""
|
|
||||||
|
|
||||||
sid0.change(
|
sid0.change(
|
||||||
fn=modified_get_vc,
|
fn=lambda sid0_val, protect0_val: modified_get_vc(sid0_val, protect0_val, file_index2),
|
||||||
inputs=[sid0, protect0],
|
inputs=[sid0, protect0],
|
||||||
outputs=[spk_item, protect0, file_index2],
|
outputs=[spk_item, protect0, file_index2],
|
||||||
)
|
)
|
||||||
|
104
modules/merged_ui/utils.py
Normal file
104
modules/merged_ui/utils.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# Import modules from your packages
|
||||||
|
from rvc_ui.initialization import vc
|
||||||
|
from spark_ui.main import initialize_model, run_tts
|
||||||
|
from spark.sparktts.utils.token_parser import LEVELS_MAP_UI
|
||||||
|
|
||||||
|
# Initialize the Spark TTS model (moved outside function to avoid reinitializing)
|
||||||
|
model_dir = "spark/pretrained_models/Spark-TTS-0.5B"
|
||||||
|
device = 0
|
||||||
|
spark_model = initialize_model(model_dir, device=device)
|
||||||
|
|
||||||
|
def generate_and_process_with_rvc(
|
||||||
|
text, prompt_text, prompt_wav_upload, prompt_wav_record,
|
||||||
|
spk_item, vc_transform, f0method,
|
||||||
|
file_index1, file_index2, index_rate, filter_radius,
|
||||||
|
resample_sr, rms_mix_rate, protect
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Handle combined TTS and RVC processing and save outputs to TEMP directories
|
||||||
|
"""
|
||||||
|
# Ensure TEMP directories exist
|
||||||
|
os.makedirs("./TEMP/spark", exist_ok=True)
|
||||||
|
os.makedirs("./TEMP/rvc", exist_ok=True)
|
||||||
|
|
||||||
|
# Get next fragment number
|
||||||
|
fragment_num = 1
|
||||||
|
while (os.path.exists(f"./TEMP/spark/fragment_{fragment_num}.wav") or
|
||||||
|
os.path.exists(f"./TEMP/rvc/fragment_{fragment_num}.wav")):
|
||||||
|
fragment_num += 1
|
||||||
|
|
||||||
|
# First generate TTS audio
|
||||||
|
prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
|
||||||
|
prompt_text_clean = None if not prompt_text or len(prompt_text) < 2 else prompt_text
|
||||||
|
|
||||||
|
tts_path = run_tts(
|
||||||
|
text,
|
||||||
|
spark_model,
|
||||||
|
prompt_text=prompt_text_clean,
|
||||||
|
prompt_speech=prompt_speech
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure we have a TTS file to process
|
||||||
|
if not tts_path or not os.path.exists(tts_path):
|
||||||
|
return "Failed to generate TTS audio", None
|
||||||
|
|
||||||
|
# Save Spark output to TEMP/spark
|
||||||
|
spark_output_path = f"./TEMP/spark/fragment_{fragment_num}.wav"
|
||||||
|
shutil.copy2(tts_path, spark_output_path)
|
||||||
|
|
||||||
|
# Call RVC processing function
|
||||||
|
f0_file = None # We're not using an F0 curve file in this pipeline
|
||||||
|
output_info, output_audio = vc.vc_single(
|
||||||
|
spk_item, tts_path, vc_transform, f0_file, f0method,
|
||||||
|
file_index1, file_index2, index_rate, filter_radius,
|
||||||
|
resample_sr, rms_mix_rate, protect
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save RVC output to TEMP/rvc directory
|
||||||
|
rvc_output_path = f"./TEMP/rvc/fragment_{fragment_num}.wav"
|
||||||
|
rvc_saved = False
|
||||||
|
|
||||||
|
# Try different ways to save the RVC output based on common formats
|
||||||
|
try:
|
||||||
|
if isinstance(output_audio, str) and os.path.exists(output_audio):
|
||||||
|
# Case 1: output_audio is a file path string
|
||||||
|
shutil.copy2(output_audio, rvc_output_path)
|
||||||
|
rvc_saved = True
|
||||||
|
elif isinstance(output_audio, tuple) and len(output_audio) >= 2:
|
||||||
|
# Case 2: output_audio might be (sample_rate, audio_data)
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(rvc_output_path, output_audio[1], output_audio[0])
|
||||||
|
rvc_saved = True
|
||||||
|
except Exception as inner_e:
|
||||||
|
output_info += f"\nFailed to save RVC tuple format: {str(inner_e)}"
|
||||||
|
elif hasattr(output_audio, 'name') and os.path.exists(output_audio.name):
|
||||||
|
# Case 3: output_audio might be a file-like object
|
||||||
|
shutil.copy2(output_audio.name, rvc_output_path)
|
||||||
|
rvc_saved = True
|
||||||
|
except Exception as e:
|
||||||
|
output_info += f"\nError saving RVC output: {str(e)}"
|
||||||
|
|
||||||
|
# Add file paths to output info
|
||||||
|
output_info += f"\nSpark output saved to: {spark_output_path}"
|
||||||
|
if rvc_saved:
|
||||||
|
output_info += f"\nRVC output saved to: {rvc_output_path}"
|
||||||
|
else:
|
||||||
|
output_info += f"\nCould not automatically save RVC output to {rvc_output_path}"
|
||||||
|
|
||||||
|
return output_info, output_audio
|
||||||
|
|
||||||
|
def modified_get_vc(sid0_value, protect0_value, file_index2_component):
|
||||||
|
"""
|
||||||
|
Modified function to get voice conversion parameters
|
||||||
|
"""
|
||||||
|
protect1_value = protect0_value
|
||||||
|
outputs = vc.get_vc(sid0_value, protect0_value, protect1_value)
|
||||||
|
|
||||||
|
if isinstance(outputs, tuple) and len(outputs) >= 3:
|
||||||
|
return outputs[0], outputs[1], outputs[3]
|
||||||
|
|
||||||
|
return 0, protect0_value, file_index2_component.choices[0] if file_index2_component.choices else ""
|
Loading…
x
Reference in New Issue
Block a user