From 37d2d27638a0e9cc07d5cd1ba4ca2b7e8d86cf7f Mon Sep 17 00:00:00 2001 From: VSlobolinskyi Date: Thu, 20 Mar 2025 22:49:54 +0200 Subject: [PATCH] Integrate spark into the project --- .gitignore | 153 +--- infer-web.py | 22 +- modules/spark_ui/__init__.py | 0 modules/spark_ui/main.py | 210 +++++ poetry.lock | 213 ++++- pyproject.toml | 14 +- spark/cli/SparkTTS.py | 236 +++++ spark/cli/inference.py | 116 +++ spark/runtime/triton_trtllm/README.md | 45 + spark/runtime/triton_trtllm/client_grpc.py | 482 ++++++++++ spark/runtime/triton_trtllm/client_http.py | 165 ++++ .../model_repo/audio_tokenizer/1/model.py | 137 +++ .../model_repo/audio_tokenizer/config.pbtxt | 58 ++ .../model_repo/spark_tts/1/model.py | 311 +++++++ .../model_repo/spark_tts/config.pbtxt | 65 ++ .../model_repo/tensorrt_llm/1/.gitkeep | 0 .../model_repo/tensorrt_llm/config.pbtxt | 857 ++++++++++++++++++ .../model_repo/vocoder/1/model.py | 106 +++ .../model_repo/vocoder/config.pbtxt | 53 ++ .../scripts/convert_checkpoint.py | 335 +++++++ .../triton_trtllm/scripts/fill_template.py | 70 ++ spark/sparktts/models/audio_tokenizer.py | 163 ++++ spark/sparktts/models/bicodec.py | 247 +++++ spark/sparktts/modules/blocks/layers.py | 73 ++ spark/sparktts/modules/blocks/samper.py | 115 +++ spark/sparktts/modules/blocks/vocos.py | 373 ++++++++ .../modules/encoder_decoder/feat_decoder.py | 115 +++ .../modules/encoder_decoder/feat_encoder.py | 105 +++ .../modules/encoder_decoder/wave_generator.py | 88 ++ .../modules/fsq/finite_scalar_quantization.py | 251 +++++ spark/sparktts/modules/fsq/residual_fsq.py | 355 ++++++++ spark/sparktts/modules/speaker/ecapa_tdnn.py | 267 ++++++ .../modules/speaker/perceiver_encoder.py | 360 ++++++++ .../modules/speaker/pooling_layers.py | 298 ++++++ .../modules/speaker/speaker_encoder.py | 136 +++ .../modules/vq/factorized_vector_quantize.py | 187 ++++ spark/sparktts/utils/__init__.py | 0 spark/sparktts/utils/audio.py | 271 ++++++ spark/sparktts/utils/file.py | 221 +++++ spark/sparktts/utils/parse_options.sh | 97 ++ spark/sparktts/utils/token_parser.py | 187 ++++ tools/download_assets.py | 18 +- 42 files changed, 7405 insertions(+), 170 deletions(-) create mode 100644 modules/spark_ui/__init__.py create mode 100644 modules/spark_ui/main.py create mode 100644 spark/cli/SparkTTS.py create mode 100644 spark/cli/inference.py create mode 100644 spark/runtime/triton_trtllm/README.md create mode 100644 spark/runtime/triton_trtllm/client_grpc.py create mode 100644 spark/runtime/triton_trtllm/client_http.py create mode 100644 spark/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py create mode 100644 spark/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt create mode 100644 spark/runtime/triton_trtllm/model_repo/spark_tts/1/model.py create mode 100644 spark/runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt create mode 100644 spark/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep create mode 100644 spark/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt create mode 100644 spark/runtime/triton_trtllm/model_repo/vocoder/1/model.py create mode 100644 spark/runtime/triton_trtllm/model_repo/vocoder/config.pbtxt create mode 100644 spark/runtime/triton_trtllm/scripts/convert_checkpoint.py create mode 100644 spark/runtime/triton_trtllm/scripts/fill_template.py create mode 100644 spark/sparktts/models/audio_tokenizer.py create mode 100644 spark/sparktts/models/bicodec.py create mode 100644 spark/sparktts/modules/blocks/layers.py create mode 100644 spark/sparktts/modules/blocks/samper.py 
create mode 100644 spark/sparktts/modules/blocks/vocos.py create mode 100644 spark/sparktts/modules/encoder_decoder/feat_decoder.py create mode 100644 spark/sparktts/modules/encoder_decoder/feat_encoder.py create mode 100644 spark/sparktts/modules/encoder_decoder/wave_generator.py create mode 100644 spark/sparktts/modules/fsq/finite_scalar_quantization.py create mode 100644 spark/sparktts/modules/fsq/residual_fsq.py create mode 100644 spark/sparktts/modules/speaker/ecapa_tdnn.py create mode 100644 spark/sparktts/modules/speaker/perceiver_encoder.py create mode 100644 spark/sparktts/modules/speaker/pooling_layers.py create mode 100644 spark/sparktts/modules/speaker/speaker_encoder.py create mode 100644 spark/sparktts/modules/vq/factorized_vector_quantize.py create mode 100644 spark/sparktts/utils/__init__.py create mode 100644 spark/sparktts/utils/audio.py create mode 100644 spark/sparktts/utils/file.py create mode 100644 spark/sparktts/utils/parse_options.sh create mode 100644 spark/sparktts/utils/token_parser.py diff --git a/.gitignore b/.gitignore index 70de05b..cfbb4eb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,20 +1,13 @@ -# ========= Common (applies to the entire repository) ========= -# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class - -# Virtual environments +*.so +.Python .env .venv env/ venv/ ENV/ -env.bak/ -venv.bak/ - -# Distribution / packaging -.Python build/ develop-eggs/ dist/ @@ -27,51 +20,27 @@ parts/ sdist/ var/ wheels/ -share/python-wheels/ *.egg-info/ .installed.cfg *.egg -MANIFEST - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ .pytest_cache/ -cover/ +.coverage +htmlcov/ -# Translations -*.mo -*.pot - -# IDE and editor settings -# PyCharm +# IDE settings .idea/ -# VSCode .vscode/ +*.swp +*.swo -# Jupyter Notebook +# Jupyter .ipynb_checkpoints -# Others +# Logs and temporary files *.log -*.spec -*.manifest +.cache -# ========= Exclusions specific to the RVC Inference Module ========= -# Directories generated by RVC or for runtime usage: +# RVC specific /TEMP /opt /tools/aria2c/ @@ -89,106 +58,10 @@ rmvpe.onnx ffmpeg.* ffprobe.* -# ========= Exclusions for the Spark Repository ========= -# (Since Spark files will be moved under ./spark, prefix these rules with "spark/") -# Byte-compiled / optimized / DLL files in spark -spark/__pycache__/ -spark/*.py[cod] -spark/*$py.class - -# Directories and files generated in Spark +# Spark specific spark/pretrained_models/ spark/results/ spark/demo/ spark/.gradio/ - -# Distribution/packaging for Spark -spark/.Python -spark/build/ -spark/develop-eggs/ -spark/dist/ -spark/downloads/ -spark/eggs/ -spark/.eggs/ -spark/lib/ -spark/lib64/ -spark/parts/ -spark/sdist/ -spark/var/ -spark/wheels/ -spark/share/python-wheels/ -spark/*.egg-info/ -spark/.installed.cfg -spark/*.egg -spark/MANIFEST spark/webui_test.py - -# PyInstaller (for Spark) -spark/*.manifest -spark/*.spec - -# Installer logs for Spark -spark/pip-log.txt -spark/pip-delete-this-directory.txt - -# Unit test / coverage reports for Spark -spark/htmlcov/ -spark/.tox/ -spark/.nox/ -spark/.coverage -spark/.coverage.* -spark/.cache -spark/nosetests.xml -spark/coverage.xml -spark/*.cover -spark/*.py,cover -spark/.hypothesis/ -spark/.pytest_cache/ -spark/cover/ - -# Translations (Spark) -spark/*.mo -spark/*.pot - -# Django/Flask/other web framework logs for Spark (if any) -spark/*.log 
-spark/local_settings.py -spark/db.sqlite3 -spark/db.sqlite3-journal - -# Flask and Scrapy caches for Spark -spark/instance/ -spark/.webassets-cache -spark/.scrapy - -# Sphinx documentation build for Spark -spark/docs/_build/ - -# PyBuilder / PEP582 for Spark -spark/.pybuilder/ -spark/target/ -spark/__pypackages__/ - -# Celery / SageMath for Spark -spark/celerybeat-schedule -spark/celerybeat.pid -spark/*.sage.py - -# IDE settings for Spark (if desired) -spark/.idea/ - -# MkDocs for Spark -spark/site/ - -# Type checker caches for Spark -spark/.mypy_cache/ -spark.dmypy.json -spark/dmypy.json -spark/.pyre/ -spark/.pytype/ - -# Cython debug symbols for Spark -spark/cython_debug/ - -# PyPI configuration for Spark -spark/.pypirc +spark/example \ No newline at end of file diff --git a/infer-web.py b/infer-web.py index 348fbee..e8da63d 100644 --- a/infer-web.py +++ b/infer-web.py @@ -1,23 +1,27 @@ #!/usr/bin/env python3 import gradio as gr +import traceback from rvc_ui.initialization import now_dir, config, vc from rvc_ui.main import build_rvc_ui -# from spark_ui.main import build_spark_ui +from spark_ui.main import build_spark_ui def build_unified_ui(): - # Build each sub-UI rvc_ui = build_rvc_ui() # Returns a gr.Blocks instance for RVC WebUI - # spark_ui = build_spark_ui() # Returns a gr.Blocks instance for Spark TTS with gr.Blocks(title="Unified Inference UI") as app: gr.Markdown("## Unified Inference UI: RVC WebUI and Spark TTS") with gr.Tabs(): with gr.TabItem("RVC WebUI"): - # Render the RVC UI components rvc_ui.render() - # with gr.TabItem("Spark TTS"): - # # Render the Spark UI components - # spark_ui.render() + with gr.TabItem("Spark TTS"): + # Instead of calling render() on the Spark UI object, + # we'll directly build it in this context + try: + # Create the Spark UI directly in this tab's context + build_spark_ui() + except Exception as e: + gr.Markdown(f"Error building Spark TTS: {str(e)}") + gr.Markdown(traceback.format_exc()) return app @@ -32,6 +36,4 @@ if __name__ == "__main__": inbrowser=not config.noautoopen, server_port=config.listen_port, quiet=True, - ) - - + ) \ No newline at end of file diff --git a/modules/spark_ui/__init__.py b/modules/spark_ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/spark_ui/main.py b/modules/spark_ui/main.py new file mode 100644 index 0000000..2a76253 --- /dev/null +++ b/modules/spark_ui/main.py @@ -0,0 +1,210 @@ +import os +import torch +import soundfile as sf +import logging +import argparse +import gradio as gr +import platform + +from datetime import datetime +from spark.cli.SparkTTS import SparkTTS +from spark.sparktts.utils.token_parser import LEVELS_MAP_UI + + +def initialize_model(model_dir, device): + """Load the model once at the beginning.""" + logging.info(f"Loading model from: {model_dir}") + + # Determine appropriate device based on platform and availability + if platform.system() == "Darwin": + # macOS with MPS support (Apple Silicon) + device = torch.device(f"mps:{device}") + logging.info(f"Using MPS device: {device}") + elif torch.cuda.is_available(): + # System with CUDA support + device = torch.device(f"cuda:{device}") + logging.info(f"Using CUDA device: {device}") + else: + # Fall back to CPU + device = torch.device("cpu") + logging.info("GPU acceleration not available, using CPU") + + model = SparkTTS(model_dir, device) + return model + + +def run_tts( + text, + model, + prompt_text=None, + prompt_speech=None, + gender=None, + pitch=None, + speed=None, + save_dir="spark/example/results", +): + 
"""Perform TTS inference and save the generated audio.""" + logging.info(f"Saving audio to: {save_dir}") + + if prompt_text is not None: + prompt_text = None if len(prompt_text) <= 1 else prompt_text + + # Ensure the save directory exists + os.makedirs(save_dir, exist_ok=True) + + # Generate unique filename using timestamp + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + save_path = os.path.join(save_dir, f"{timestamp}.wav") + + logging.info("Starting inference...") + + # Perform inference and save the output audio + with torch.no_grad(): + wav = model.inference( + text, + prompt_speech, + prompt_text, + gender, + pitch, + speed, + ) + + sf.write(save_path, wav, samplerate=16000) + + logging.info(f"Audio saved at: {save_path}") + + return save_path + + +def build_spark_ui(): + model_dir = "spark/pretrained_models/Spark-TTS-0.5B" + device = 0 + # Initialize model + model = initialize_model(model_dir, device=device) + + # Define callback function for voice cloning + def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record): + """ + Gradio callback to clone voice using text and optional prompt speech. + - text: The input text to be synthesised. + - prompt_text: Additional textual info for the prompt (optional). + - prompt_wav_upload/prompt_wav_record: Audio files used as reference. + """ + prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record + prompt_text_clean = None if len(prompt_text) < 2 else prompt_text + + audio_output_path = run_tts( + text, + model, + prompt_text=prompt_text_clean, + prompt_speech=prompt_speech + ) + return audio_output_path + + # Define callback function for creating new voices + def voice_creation(text, gender, pitch, speed): + """ + Gradio callback to create a synthetic voice with adjustable parameters. + - text: The input text for synthesis. + - gender: 'male' or 'female'. + - pitch/speed: Ranges mapped by LEVELS_MAP_UI. + """ + pitch_val = LEVELS_MAP_UI[int(pitch)] + speed_val = LEVELS_MAP_UI[int(speed)] + audio_output_path = run_tts( + text, + model, + gender=gender, + pitch=pitch_val, + speed=speed_val + ) + return audio_output_path + + with gr.Blocks() as demo: + # Use HTML for centered title + gr.HTML('

<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>
') + with gr.Tabs(): + # Voice Clone Tab + with gr.TabItem("Voice Clone"): + gr.Markdown( + "### Upload reference audio or recording (上传参考音频或者录音)" + ) + + with gr.Row(): + prompt_wav_upload = gr.Audio( + sources="upload", + type="filepath", + label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.", + ) + prompt_wav_record = gr.Audio( + sources="microphone", + type="filepath", + label="Record the prompt audio file.", + ) + + with gr.Row(): + text_input = gr.Textbox( + label="Text", lines=3, placeholder="Enter text here" + ) + prompt_text_input = gr.Textbox( + label="Text of prompt speech (Optional; recommended for cloning in the same language.)", + lines=3, + placeholder="Enter text of the prompt speech.", + ) + + audio_output = gr.Audio( + label="Generated Audio", streaming=True + ) + + generate_buttom_clone = gr.Button("Generate") + + generate_buttom_clone.click( + voice_clone, + inputs=[ + text_input, + prompt_text_input, + prompt_wav_upload, + prompt_wav_record, + ], + outputs=[audio_output], + ) + + # Voice Creation Tab + with gr.TabItem("Voice Creation"): + gr.Markdown( + "### Create your own voice based on the following parameters" + ) + + with gr.Row(): + with gr.Column(): + gender = gr.Radio( + choices=["male", "female"], value="male", label="Gender" + ) + pitch = gr.Slider( + minimum=1, maximum=5, step=1, value=3, label="Pitch" + ) + speed = gr.Slider( + minimum=1, maximum=5, step=1, value=3, label="Speed" + ) + with gr.Column(): + text_input_creation = gr.Textbox( + label="Input Text", + lines=3, + placeholder="Enter text here", + value="You can generate a customized voice by adjusting parameters such as pitch and speed.", + ) + create_button = gr.Button("Create Voice") + + audio_output = gr.Audio( + label="Generated Audio", streaming=True + ) + create_button.click( + voice_creation, + inputs=[text_input_creation, gender, pitch, speed], + outputs=[audio_output], + ) + + return demo + +if __name__ == "__main__": + build_spark_ui() \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index bebac28..ac9ba55 100644 --- a/poetry.lock +++ b/poetry.lock @@ -785,6 +785,27 @@ files = [ {file = "einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84"}, ] +[[package]] +name = "einx" +version = "0.3.0" +description = "Universal Tensor Operations in Einstein-Inspired Notation for Python" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "einx-0.3.0-py3-none-any.whl", hash = "sha256:367d62bab8dbb8c4937308512abb6f746cc0920990589892ba0d281356d39345"}, + {file = "einx-0.3.0.tar.gz", hash = "sha256:17ff87c6a0f68ab358c1da489f00e95f1de106fd12ff17d0fb3e210aaa1e5f8c"}, +] + +[package.dependencies] +frozendict = "*" +numpy = "*" +sympy = "*" + +[package.extras] +keras = ["keras (>=3)"] +torch = ["torch (>=2)"] + [[package]] name = "fairseq" version = "0.12.3" @@ -1001,6 +1022,55 @@ ufo = ["fs (>=2.2.0,<3)"] unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] +[[package]] +name = "frozendict" +version = "2.4.6" +description = "A simple immutable dictionary" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "frozendict-2.4.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c3a05c0a50cab96b4bb0ea25aa752efbfceed5ccb24c007612bc63e51299336f"}, + {file = 
"frozendict-2.4.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f5b94d5b07c00986f9e37a38dd83c13f5fe3bf3f1ccc8e88edea8fe15d6cd88c"}, + {file = "frozendict-2.4.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4c789fd70879ccb6289a603cdebdc4953e7e5dea047d30c1b180529b28257b5"}, + {file = "frozendict-2.4.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da6a10164c8a50b34b9ab508a9420df38f4edf286b9ca7b7df8a91767baecb34"}, + {file = "frozendict-2.4.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9a8a43036754a941601635ea9c788ebd7a7efbed2becba01b54a887b41b175b9"}, + {file = "frozendict-2.4.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c9905dcf7aa659e6a11b8051114c9fa76dfde3a6e50e6dc129d5aece75b449a2"}, + {file = "frozendict-2.4.6-cp310-cp310-win_amd64.whl", hash = "sha256:323f1b674a2cc18f86ab81698e22aba8145d7a755e0ac2cccf142ee2db58620d"}, + {file = "frozendict-2.4.6-cp310-cp310-win_arm64.whl", hash = "sha256:eabd21d8e5db0c58b60d26b4bb9839cac13132e88277e1376970172a85ee04b3"}, + {file = "frozendict-2.4.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:eddabeb769fab1e122d3a6872982c78179b5bcc909fdc769f3cf1964f55a6d20"}, + {file = "frozendict-2.4.6-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:377a65be0a700188fc21e669c07de60f4f6d35fae8071c292b7df04776a1c27b"}, + {file = "frozendict-2.4.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce1e9217b85eec6ba9560d520d5089c82dbb15f977906eb345d81459723dd7e3"}, + {file = "frozendict-2.4.6-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:7291abacf51798d5ffe632771a69c14fb423ab98d63c4ccd1aa382619afe2f89"}, + {file = "frozendict-2.4.6-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:e72fb86e48811957d66ffb3e95580af7b1af1e6fbd760ad63d7bd79b2c9a07f8"}, + {file = "frozendict-2.4.6-cp36-cp36m-win_amd64.whl", hash = "sha256:622301b1c29c4f9bba633667d592a3a2b093cb408ba3ce578b8901ace3931ef3"}, + {file = "frozendict-2.4.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a4e3737cb99ed03200cd303bdcd5514c9f34b29ee48f405c1184141bd68611c9"}, + {file = "frozendict-2.4.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49ffaf09241bc1417daa19362a2241a4aa435f758fd4375c39ce9790443a39cd"}, + {file = "frozendict-2.4.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d69418479bfb834ba75b0e764f058af46ceee3d655deb6a0dd0c0c1a5e82f09"}, + {file = "frozendict-2.4.6-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:c131f10c4d3906866454c4e89b87a7e0027d533cce8f4652aa5255112c4d6677"}, + {file = "frozendict-2.4.6-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:fc67cbb3c96af7a798fab53d52589752c1673027e516b702ab355510ddf6bdff"}, + {file = "frozendict-2.4.6-cp37-cp37m-win_amd64.whl", hash = "sha256:7730f8ebe791d147a1586cbf6a42629351d4597773317002181b66a2da0d509e"}, + {file = "frozendict-2.4.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:807862e14b0e9665042458fde692c4431d660c4219b9bb240817f5b918182222"}, + {file = "frozendict-2.4.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:9647c74efe3d845faa666d4853cfeabbaee403b53270cabfc635b321f770e6b8"}, + {file = "frozendict-2.4.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:665fad3f0f815aa41294e561d98dbedba4b483b3968e7e8cab7d728d64b96e33"}, + {file = "frozendict-2.4.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f42e6b75254ea2afe428ad6d095b62f95a7ae6d4f8272f0bd44a25dddd20f67"}, + 
{file = "frozendict-2.4.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:02331541611f3897f260900a1815b63389654951126e6e65545e529b63c08361"}, + {file = "frozendict-2.4.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:18d50a2598350b89189da9150058191f55057581e40533e470db46c942373acf"}, + {file = "frozendict-2.4.6-cp38-cp38-win_amd64.whl", hash = "sha256:1b4a3f8f6dd51bee74a50995c39b5a606b612847862203dd5483b9cd91b0d36a"}, + {file = "frozendict-2.4.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a76cee5c4be2a5d1ff063188232fffcce05dde6fd5edd6afe7b75b247526490e"}, + {file = "frozendict-2.4.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba5ef7328706db857a2bdb2c2a17b4cd37c32a19c017cff1bb7eeebc86b0f411"}, + {file = "frozendict-2.4.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:669237c571856be575eca28a69e92a3d18f8490511eff184937283dc6093bd67"}, + {file = "frozendict-2.4.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0aaa11e7c472150efe65adbcd6c17ac0f586896096ab3963775e1c5c58ac0098"}, + {file = "frozendict-2.4.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b8f2829048f29fe115da4a60409be2130e69402e29029339663fac39c90e6e2b"}, + {file = "frozendict-2.4.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:94321e646cc39bebc66954a31edd1847d3a2a3483cf52ff051cd0996e7db07db"}, + {file = "frozendict-2.4.6-cp39-cp39-win_amd64.whl", hash = "sha256:74b6b26c15dddfefddeb89813e455b00ebf78d0a3662b89506b4d55c6445a9f4"}, + {file = "frozendict-2.4.6-cp39-cp39-win_arm64.whl", hash = "sha256:7088102345d1606450bd1801a61139bbaa2cb0d805b9b692f8d81918ea835da6"}, + {file = "frozendict-2.4.6-py311-none-any.whl", hash = "sha256:d065db6a44db2e2375c23eac816f1a022feb2fa98cbb50df44a9e83700accbea"}, + {file = "frozendict-2.4.6-py312-none-any.whl", hash = "sha256:49344abe90fb75f0f9fdefe6d4ef6d4894e640fadab71f11009d52ad97f370b9"}, + {file = "frozendict-2.4.6-py313-none-any.whl", hash = "sha256:7134a2bb95d4a16556bb5f2b9736dceb6ea848fa5b6f3f6c2d6dba93b44b4757"}, + {file = "frozendict-2.4.6.tar.gz", hash = "sha256:df7cd16470fbd26fc4969a208efadc46319334eb97def1ddf48919b351192b8e"}, +] + [[package]] name = "fsspec" version = "2024.6.1" @@ -3470,6 +3540,44 @@ dev = ["lxml-stubs", "mypy", "pytest", "types-tabulate", "wheel"] ja = ["ipadic (>=1.0,<2.0)", "mecab-python3 (>=1.0.5,<=1.0.6)"] ko = ["mecab-ko (>=1.0.0,<=1.0.1)", "mecab-ko-dic (>=1.0,<2.0)"] +[[package]] +name = "safetensors" +version = "0.5.3" +description = "" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073"}, + {file = "safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9"}, + 
{file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04"}, + {file = "safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace"}, + {file = "safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11"}, + {file = "safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +mlx = ["mlx (>=0.0.9)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.18.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + [[package]] name = "scikit-learn" version = "1.5.1" @@ -3794,6 +3902,39 @@ files = [ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, ] +[[package]] +name = "tokenizers" +version = "0.21.1" +description = "" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41"}, + {file = "tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3"}, + {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f"}, + {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf"}, + {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8"}, + {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0"}, + {file = 
"tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c"}, + {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a"}, + {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf"}, + {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6"}, + {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d"}, + {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f"}, + {file = "tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3"}, + {file = "tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382"}, + {file = "tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + [[package]] name = "torch" version = "2.4.1" @@ -4020,6 +4161,76 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "transformers" +version = "4.49.0" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.9.0" +groups = ["main"] +files = [ + {file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"}, + {file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.26.0,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.21,<0.22" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.26.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +benchmark = ["optimum-benchmark (>=0.3.0)"] +codecarbon = ["codecarbon (>=2.8.1)"] +deepspeed = ["accelerate 
(>=0.26.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", 
"scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +ruff = ["ruff (==0.5.1)"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +tiktoken = ["blobfile", "tiktoken"] +timm = ["timm (<=1.0.11)"] +tokenizers = ["tokenizers (>=0.21,<0.22)"] +torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] +video = ["av"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + [[package]] name = "triton" version = "3.0.0" @@ -4206,4 +4417,4 @@ watchdog = ["watchdog (>=2.3)"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "614c8a809ba49bd299e21b84f25dd78e4979c773f97fe107274fe18bac7309ed" +content-hash = "58fd044b5e4a2d7ba9fbb0291188be2eb15788e9c5ef70024b2effa5b9915681" diff --git a/pyproject.toml b/pyproject.toml index 4f5cfaa..28dd012 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -12,6 +12,10 @@ repository = "https://github.com/VSlobolinskyi/spark-rvc-inference-module" include = "rvc_ui" from = "modules" +[[tool.poetry.packages]] +include = "spark_ui" +from = "modules" + [tool.poetry.dependencies] python = ">=3.11,<3.12" @@ -22,12 +26,12 @@ joblib = ">=1.1.0" numba = "*" llvmlite = "*" Cython = "*" -numpy = "*" +numpy = ">=1.0,<2.0" scipy = "*" librosa = "==0.10.2" faiss-cpu = "*" gradio = "3.50.0" -soundfile = ">=0.12.1" +soundfile = "0.12.1" ffmpeg-python = ">=0.2.0" matplotlib = ">=3.7.0" matplotlib-inline = ">=0.1.3" @@ -47,6 +51,9 @@ torchfcpe = "*" ffmpy = "==0.3.1" python-dotenv = ">=1.0.0" av = "*" +autoflake = "^2.3.1" +einops = "0.8.1" +einx = "0.3.0" # --------------------------------------------------------------------------- # --- NVIDIA GPU configuration --- @@ -65,7 +72,8 @@ torch-directml = "^0.2.5.dev240914" # --------------------------------------------------------------------------- # Depndenciees for temp_tools -autoflake = "^2.3.1" +transformers = "^4.49.0" + [tool.poetry.group.dev.dependencies] black = "^25.1.0" diff --git a/spark/cli/SparkTTS.py b/spark/cli/SparkTTS.py new file mode 100644 index 0000000..9df3d28 --- /dev/null +++ b/spark/cli/SparkTTS.py @@ -0,0 +1,236 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import torch +from typing import Tuple +from pathlib import Path +from transformers import AutoTokenizer, AutoModelForCausalLM + +from spark.sparktts.utils.file import load_config +from spark.sparktts.models.audio_tokenizer import BiCodecTokenizer +from spark.sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP + + +class SparkTTS: + """ + Spark-TTS for text-to-speech generation. + """ + + def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")): + """ + Initializes the SparkTTS model with the provided configurations and device. + + Args: + model_dir (Path): Directory containing the model and config files. + device (torch.device): The device (CPU/GPU) to run the model on. + """ + self.device = device + self.model_dir = model_dir + self.configs = load_config(f"{model_dir}/config.yaml") + self.sample_rate = self.configs["sample_rate"] + self._initialize_inference() + + def _initialize_inference(self): + """Initializes the tokenizer, model, and audio tokenizer for inference.""" + self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM") + self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM") + self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device) + self.model.to(self.device) + + def process_prompt( + self, + text: str, + prompt_speech_path: Path, + prompt_text: str = None, + ) -> Tuple[str, torch.Tensor]: + """ + Process input for voice cloning. + + Args: + text (str): The text input to be converted to speech. + prompt_speech_path (Path): Path to the audio file used as a prompt. 
+ prompt_text (str, optional): Transcript of the prompt audio. + + Return: + Tuple[str, torch.Tensor]: Input prompt; global tokens + """ + + global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize( + prompt_speech_path + ) + global_tokens = "".join( + [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] + ) + + # Prepare the input tokens for the model + if prompt_text is not None: + semantic_tokens = "".join( + [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] + ) + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + prompt_text, + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + "<|start_semantic_token|>", + semantic_tokens, + ] + else: + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + ] + + inputs = "".join(inputs) + + return inputs, global_token_ids + + def process_prompt_control( + self, + gender: str, + pitch: str, + speed: str, + text: str, + ): + """ + Process input for voice creation. + + Args: + gender (str): female | male. + pitch (str): very_low | low | moderate | high | very_high + speed (str): very_low | low | moderate | high | very_high + text (str): The text input to be converted to speech. + + Return: + str: Input prompt + """ + assert gender in GENDER_MAP.keys() + assert pitch in LEVELS_MAP.keys() + assert speed in LEVELS_MAP.keys() + + gender_id = GENDER_MAP[gender] + pitch_level_id = LEVELS_MAP[pitch] + speed_level_id = LEVELS_MAP[speed] + + pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" + speed_label_tokens = f"<|speed_label_{speed_level_id}|>" + gender_tokens = f"<|gender_{gender_id}|>" + + attribte_tokens = "".join( + [gender_tokens, pitch_label_tokens, speed_label_tokens] + ) + + control_tts_inputs = [ + TASK_TOKEN_MAP["controllable_tts"], + "<|start_content|>", + text, + "<|end_content|>", + "<|start_style_label|>", + attribte_tokens, + "<|end_style_label|>", + ] + + return "".join(control_tts_inputs) + + @torch.no_grad() + def inference( + self, + text: str, + prompt_speech_path: Path = None, + prompt_text: str = None, + gender: str = None, + pitch: str = None, + speed: str = None, + temperature: float = 0.8, + top_k: float = 50, + top_p: float = 0.95, + ) -> torch.Tensor: + """ + Performs inference to generate speech from text, incorporating prompt audio and/or text. + + Args: + text (str): The text input to be converted to speech. + prompt_speech_path (Path): Path to the audio file used as a prompt. + prompt_text (str, optional): Transcript of the prompt audio. + gender (str): female | male. + pitch (str): very_low | low | moderate | high | very_high + speed (str): very_low | low | moderate | high | very_high + temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. + top_k (float, optional): Top-k sampling parameter. Default is 50. + top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. + + Returns: + torch.Tensor: Generated waveform as a tensor. 
+ """ + if gender is not None: + prompt = self.process_prompt_control(gender, pitch, speed, text) + + else: + prompt, global_token_ids = self.process_prompt( + text, prompt_speech_path, prompt_text + ) + model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) + + # Generate speech using the model + generated_ids = self.model.generate( + **model_inputs, + max_new_tokens=3000, + do_sample=True, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + + # Trim the output tokens to remove the input tokens + generated_ids = [ + output_ids[len(input_ids) :] + for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + # Decode the generated tokens into text + predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + # Extract semantic token IDs from the generated text + pred_semantic_ids = ( + torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]) + .long() + .unsqueeze(0) + ) + + if gender is not None: + global_token_ids = ( + torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]) + .long() + .unsqueeze(0) + .unsqueeze(0) + ) + + # Convert semantic tokens back to waveform + wav = self.audio_tokenizer.detokenize( + global_token_ids.to(self.device).squeeze(0), + pred_semantic_ids.to(self.device), + ) + + return wav \ No newline at end of file diff --git a/spark/cli/inference.py b/spark/cli/inference.py new file mode 100644 index 0000000..1c87d5f --- /dev/null +++ b/spark/cli/inference.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import argparse +import torch +import soundfile as sf +import logging +from datetime import datetime +import platform + +from spark.cli.SparkTTS import SparkTTS + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Run TTS inference.") + + parser.add_argument( + "--model_dir", + type=str, + default="spark/pretrained_models/Spark-TTS-0.5B", + help="Path to the model directory", + ) + parser.add_argument( + "--save_dir", + type=str, + default="spark/example/results", + help="Directory to save generated audio files", + ) + parser.add_argument("--device", type=int, default=0, help="CUDA device number") + parser.add_argument( + "--text", type=str, required=True, help="Text for TTS generation" + ) + parser.add_argument("--prompt_text", type=str, help="Transcript of prompt audio") + parser.add_argument( + "--prompt_speech_path", + type=str, + help="Path to the prompt audio file", + ) + parser.add_argument("--gender", choices=["male", "female"]) + parser.add_argument( + "--pitch", choices=["very_low", "low", "moderate", "high", "very_high"] + ) + parser.add_argument( + "--speed", choices=["very_low", "low", "moderate", "high", "very_high"] + ) + return parser.parse_args() + + +def run_tts(args): + """Perform TTS inference and save the generated audio.""" + logging.info(f"Using model from: {args.model_dir}") + logging.info(f"Saving audio to: {args.save_dir}") + + # Ensure the save directory exists + os.makedirs(args.save_dir, exist_ok=True) + + # Convert device argument to torch.device + if platform.system() == "Darwin" and torch.backends.mps.is_available(): + # macOS with MPS support (Apple Silicon) + device = torch.device(f"mps:{args.device}") + logging.info(f"Using MPS device: {device}") + elif torch.cuda.is_available(): + # System with CUDA support + device = torch.device(f"cuda:{args.device}") + logging.info(f"Using CUDA device: {device}") + else: + # Fall back to CPU + device = torch.device("cpu") + logging.info("GPU acceleration not available, using CPU") + + # Initialize the model + model = SparkTTS(args.model_dir, device) + + # Generate unique filename using timestamp + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + save_path = os.path.join(args.save_dir, f"{timestamp}.wav") + + logging.info("Starting inference...") + + # Perform inference and save the output audio + with torch.no_grad(): + wav = model.inference( + args.text, + args.prompt_speech_path, + prompt_text=args.prompt_text, + gender=args.gender, + pitch=args.pitch, + speed=args.speed, + ) + sf.write(save_path, wav, samplerate=16000) + + logging.info(f"Audio saved at: {save_path}") + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + + args = parse_args() + run_tts(args) diff --git a/spark/runtime/triton_trtllm/README.md b/spark/runtime/triton_trtllm/README.md new file mode 100644 index 0000000..babbd56 --- /dev/null +++ b/spark/runtime/triton_trtllm/README.md @@ -0,0 +1,45 @@ +## Nvidia Triton Inference Serving Best Practice for Spark TTS + +### Quick Start +Directly launch the service using docker compose. +```sh +docker compose up +``` + +### Build Image +Build the docker image from scratch. +```sh +docker build . 
-f Dockerfile.server -t soar97/triton-spark-tts:25.02 +``` + +### Create Docker Container +```sh +your_mount_dir=/mnt:/mnt +docker run -it --name "spark-tts-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-spark-tts:25.02 +``` + +### Export Models to TensorRT-LLM and Launch Server +Inside docker container, we would follow the official guide of TensorRT-LLM to build TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen). + +```sh +bash run.sh 0 3 +``` +### Simple HTTP client +```sh +python3 client_http.py +``` + +### Benchmark using Dataset +```sh +num_task=2 +python3 client_grpc.py --num-tasks $num_task --huggingface-dataset yuekai/seed_tts --split-name wenetspeech4tts +``` + +### Benchmark Results +Decoding on a single L20 GPU, using 26 different prompt_audio/target_text pairs, total audio duration 169 secs. + +| Model | Note | Concurrency | Avg Latency | RTF | +|-------|-----------|-----------------------|---------|--| +| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 1 | 876.24 ms | 0.1362| +| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 2 | 920.97 ms | 0.0737| +| Spark-TTS-0.5B | [Code Commit](https://github.com/SparkAudio/Spark-TTS/tree/4d769ff782a868524f29e0be851ca64f8b22ebf1/runtime/triton_trtllm) | 4 | 1611.51 ms | 0.0704| \ No newline at end of file diff --git a/spark/runtime/triton_trtllm/client_grpc.py b/spark/runtime/triton_trtllm/client_grpc.py new file mode 100644 index 0000000..0aabb47 --- /dev/null +++ b/spark/runtime/triton_trtllm/client_grpc.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# 2023 Nvidia (authors: Yuekai Zhang) +# 2023 Recurrent.ai (authors: Songtao Shi) +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script supports to load dataset from huggingface and sends it to the server +for decoding, in parallel. 
+ +Usage: +# For offline Spark-TTS-0.5B +# huggingface dataset + num_task=2 + python3 client_grpc.py \ + --server-addr localhost \ + --model-name spark_tts \ + --num-tasks $num_task \ + --huggingface-dataset yuekai/seed_tts \ + --split-name wenetspeech4tts \ + --log-dir ./log_concurrent_tasks_${num_task} +""" + +import argparse +import asyncio +import json + +import os +import time +import types +from pathlib import Path + +import numpy as np +import soundfile as sf +import tritonclient +import tritonclient.grpc.aio as grpcclient +from tritonclient.utils import np_to_triton_dtype + + + +def write_triton_stats(stats, summary_file): + with open(summary_file, "w") as summary_f: + model_stats = stats["model_stats"] + # write a note, the log is from triton_client.get_inference_statistics(), to better human readability + summary_f.write( + "The log is parsing from triton_client.get_inference_statistics(), to better human readability. \n" + ) + summary_f.write("To learn more about the log, please refer to: \n") + summary_f.write( + "1. https://github.com/triton-inference-server/server/blob/main/docs/user_guide/metrics.md \n" + ) + summary_f.write( + "2. https://github.com/triton-inference-server/server/issues/5374 \n\n" + ) + summary_f.write( + "To better improve throughput, we always would like let requests wait in the queue for a while, and then execute them with a larger batch size. \n" + ) + summary_f.write( + "However, there is a trade-off between the increased queue time and the increased batch size. \n" + ) + summary_f.write( + "You may change 'max_queue_delay_microseconds' and 'preferred_batch_size' in the model configuration file to achieve this. \n" + ) + summary_f.write( + "See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#delayed-batching for more details. 
\n\n" + ) + for model_state in model_stats: + if "last_inference" not in model_state: + continue + summary_f.write(f"model name is {model_state['name']} \n") + model_inference_stats = model_state["inference_stats"] + total_queue_time_s = int(model_inference_stats["queue"]["ns"]) / 1e9 + total_infer_time_s = int(model_inference_stats["compute_infer"]["ns"]) / 1e9 + total_input_time_s = int(model_inference_stats["compute_input"]["ns"]) / 1e9 + total_output_time_s = ( + int(model_inference_stats["compute_output"]["ns"]) / 1e9 + ) + summary_f.write( + f"queue time {total_queue_time_s:<5.2f} s, compute infer time {total_infer_time_s:<5.2f} s, compute input time {total_input_time_s:<5.2f} s, compute output time {total_output_time_s:<5.2f} s \n" # noqa + ) + model_batch_stats = model_state["batch_stats"] + for batch in model_batch_stats: + batch_size = int(batch["batch_size"]) + compute_input = batch["compute_input"] + compute_output = batch["compute_output"] + compute_infer = batch["compute_infer"] + batch_count = int(compute_infer["count"]) + assert ( + compute_infer["count"] + == compute_output["count"] + == compute_input["count"] + ) + compute_infer_time_ms = int(compute_infer["ns"]) / 1e6 + compute_input_time_ms = int(compute_input["ns"]) / 1e6 + compute_output_time_ms = int(compute_output["ns"]) / 1e6 + summary_f.write( + f"execuate inference with batch_size {batch_size:<2} total {batch_count:<5} times, total_infer_time {compute_infer_time_ms:<9.2f} ms, avg_infer_time {compute_infer_time_ms:<9.2f}/{batch_count:<5}={compute_infer_time_ms/batch_count:.2f} ms, avg_infer_time_per_sample {compute_infer_time_ms:<9.2f}/{batch_count:<5}/{batch_size}={compute_infer_time_ms/batch_count/batch_size:.2f} ms \n" # noqa + ) + # summary_f.write( + # f"input {compute_input_time_ms:<9.2f} ms, avg {compute_input_time_ms/batch_count:.2f} ms, " # noqa + # ) + # summary_f.write( + # f"output {compute_output_time_ms:<9.2f} ms, avg {compute_output_time_ms/batch_count:.2f} ms \n" # noqa + # ) + + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=8001, + help="Grpc port of the triton server, default is 8001", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default=None, + help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="", + help="", + ) + + parser.add_argument( + "--huggingface-dataset", + type=str, + default="yuekai/seed_tts", + help="dataset name in huggingface dataset hub", + ) + + parser.add_argument( + "--split-name", + type=str, + default="wenetspeech4tts", + choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"], + help="dataset split name, default is 'test'", + ) + + parser.add_argument( + "--manifest-path", + type=str, + default=None, + help="Path to the manifest dir which includes wav.scp trans.txt files.", + ) + + parser.add_argument( + "--model-name", + type=str, + default="f5_tts", + choices=[ + "f5_tts", "spark_tts" + ], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--num-tasks", + type=int, + default=1, + help="Number of concurrent tasks for sending", + ) + + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="Controls how frequently we print the log.", + ) + + parser.add_argument( + "--compute-wer", + action="store_true", + default=False, + help="""True to compute WER. + """, + ) + + parser.add_argument( + "--log-dir", + type=str, + required=False, + default="./tmp", + help="log directory", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="Inference batch_size per request for offline mode.", + ) + + return parser.parse_args() + + +def load_audio(wav_path, target_sample_rate=16000): + assert target_sample_rate == 16000, "hard coding in server" + if isinstance(wav_path, dict): + waveform = wav_path["array"] + sample_rate = wav_path["sampling_rate"] + else: + waveform, sample_rate = sf.read(wav_path) + if sample_rate != target_sample_rate: + from scipy.signal import resample + num_samples = int(len(waveform) * (target_sample_rate / sample_rate)) + waveform = resample(waveform, num_samples) + return waveform, target_sample_rate + +async def send( + manifest_item_list: list, + name: str, + triton_client: tritonclient.grpc.aio.InferenceServerClient, + protocol_client: types.ModuleType, + log_interval: int, + model_name: str, + padding_duration: int = None, + audio_save_dir: str = "./", +): + total_duration = 0.0 + results = [] + latency_data = [] + task_id = int(name[5:]) + + print(f"manifest_item_list: {manifest_item_list}") + for i, item in enumerate(manifest_item_list): + if i % log_interval == 0: + print(f"{name}: {i}/{len(manifest_item_list)}") + waveform, sample_rate = load_audio(item["audio_filepath"], target_sample_rate=16000) + duration = len(waveform) / sample_rate + lengths = np.array([[len(waveform)]], dtype=np.int32) + + reference_text, target_text = item["reference_text"], item["target_text"] + + estimated_target_duration = duration / len(reference_text) * len(target_text) + + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(duration) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + inputs = [ + protocol_client.InferInput( + "reference_wav", samples.shape, np_to_triton_dtype(samples.dtype) + ), + 
protocol_client.InferInput( + "reference_wav_len", lengths.shape, np_to_triton_dtype(lengths.dtype) + ), + protocol_client.InferInput("reference_text", [1, 1], "BYTES"), + protocol_client.InferInput("target_text", [1, 1], "BYTES") + ] + inputs[0].set_data_from_numpy(samples) + inputs[1].set_data_from_numpy(lengths) + + input_data_numpy = np.array([reference_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[2].set_data_from_numpy(input_data_numpy) + + input_data_numpy = np.array([target_text], dtype=object) + input_data_numpy = input_data_numpy.reshape((1, 1)) + inputs[3].set_data_from_numpy(input_data_numpy) + + outputs = [protocol_client.InferRequestedOutput("waveform")] + + sequence_id = 100000000 + i + task_id * 10 + start = time.time() + response = await triton_client.infer( + model_name, inputs, request_id=str(sequence_id), outputs=outputs + ) + + audio = response.as_numpy("waveform").reshape(-1) + + end = time.time() - start + + audio_save_path = os.path.join( + audio_save_dir, f"{item['target_audio_path']}.wav" + ) + sf.write(audio_save_path, audio, 16000, "PCM_16") + + latency_data.append((end, estimated_target_duration)) + total_duration += estimated_target_duration + + return total_duration, latency_data + +def load_manifests(manifest_path): + with open(manifest_path, "r") as f: + manifest_list = [] + for line in f: + assert len(line.strip().split("|")) == 4 + utt, prompt_text, prompt_wav, gt_text = line.strip().split("|") + utt = Path(utt).stem + # gt_wav = os.path.join(os.path.dirname(manifest_path), "wavs", utt + ".wav") + if not os.path.isabs(prompt_wav): + prompt_wav = os.path.join(os.path.dirname(manifest_path), prompt_wav) + manifest_list.append( + { + "audio_filepath": prompt_wav, + "reference_text": prompt_text, + "target_text": gt_text, + "target_audio_path": utt + } + ) + return manifest_list + + +def split_data(data, k): + n = len(data) + if n < k: + print( + f"Warning: the length of the input list ({n}) is less than k ({k}). Setting k to {n}." 
+ ) + k = n + + quotient = n // k + remainder = n % k + + result = [] + start = 0 + for i in range(k): + if i < remainder: + end = start + quotient + 1 + else: + end = start + quotient + + result.append(data[start:end]) + start = end + + return result + + +async def main(): + args = get_args() + url = f"{args.server_addr}:{args.server_port}" + + triton_client = grpcclient.InferenceServerClient(url=url, verbose=False) + protocol_client = grpcclient + + if args.reference_audio: + args.num_tasks = 1 + args.log_interval = 1 + manifest_item_list = [ + { + "reference_text": args.reference_text, + "target_text": args.target_text, + "audio_filepath": args.reference_audio, + "target_audio_path": "test", + } + ] + elif args.huggingface_dataset: + import datasets + + dataset = datasets.load_dataset( + args.huggingface_dataset, + split=args.split_name, + trust_remote_code=True, + ) + manifest_item_list = [] + for i in range(len(dataset)): + manifest_item_list.append( + { + "audio_filepath": dataset[i]["prompt_audio"], + "reference_text": dataset[i]["prompt_text"], + "target_audio_path": dataset[i]["id"], + "target_text": dataset[i]["target_text"], + } + ) + else: + manifest_item_list = load_manifests(args.manifest_path) + + args.num_tasks = min(args.num_tasks, len(manifest_item_list)) + manifest_item_list = split_data(manifest_item_list, args.num_tasks) + + os.makedirs(args.log_dir, exist_ok=True) + tasks = [] + start_time = time.time() + for i in range(args.num_tasks): + task = asyncio.create_task( + send( + manifest_item_list[i], + name=f"task-{i}", + triton_client=triton_client, + protocol_client=protocol_client, + log_interval=args.log_interval, + model_name=args.model_name, + audio_save_dir=args.log_dir, + padding_duration=None, + ) + ) + tasks.append(task) + + ans_list = await asyncio.gather(*tasks) + + end_time = time.time() + elapsed = end_time - start_time + + + total_duration = 0.0 + latency_data = [] + for ans in ans_list: + total_duration += ans[0] + latency_data += ans[1] + + rtf = elapsed / total_duration + + s = f"RTF: {rtf:.4f}\n" + s += f"total_duration: {total_duration:.3f} seconds\n" + s += f"({total_duration/3600:.2f} hours)\n" + s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n" + + latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] + latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 + latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 + s += f"latency_variance: {latency_variance:.2f}\n" + s += f"latency_50_percentile_ms: {np.percentile(latency_list, 50) * 1000.0:.2f}\n" + s += f"latency_90_percentile_ms: {np.percentile(latency_list, 90) * 1000.0:.2f}\n" + s += f"latency_95_percentile_ms: {np.percentile(latency_list, 95) * 1000.0:.2f}\n" + s += f"latency_99_percentile_ms: {np.percentile(latency_list, 99) * 1000.0:.2f}\n" + s += f"average_latency_ms: {latency_ms:.2f}\n" + + print(s) + if args.manifest_path: + name = Path(args.manifest_path).stem + elif args.split_name: + name = args.split_name + with open(f"{args.log_dir}/rtf-{name}.txt", "w") as f: + f.write(s) + + stats = await triton_client.get_inference_statistics(model_name="", as_json=True) + write_triton_stats(stats, f"{args.log_dir}/stats_summary-{name}.txt") + + metadata = await triton_client.get_model_config(model_name=args.model_name, as_json=True) + with open(f"{args.log_dir}/model_config-{name}.json", "w") as f: + json.dump(metadata, f, indent=4) +if __name__ == "__main__": + asyncio.run(main()) diff --git 
a/spark/runtime/triton_trtllm/client_http.py b/spark/runtime/triton_trtllm/client_http.py new file mode 100644 index 0000000..a0ebc0a --- /dev/null +++ b/spark/runtime/triton_trtllm/client_http.py @@ -0,0 +1,165 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import requests +import soundfile as sf +import json +import numpy as np +import argparse + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-url", + type=str, + default="localhost:8000", + help="Address of the server", + ) + + parser.add_argument( + "--reference-audio", + type=str, + default="../../spark/example/prompt_audio.wav", + help="Path to a single audio file. 
It can't be specified at the same time with --manifest-dir", + ) + + parser.add_argument( + "--reference-text", + type=str, + default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。", + help="", + ) + + parser.add_argument( + "--target-text", + type=str, + default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。", + help="", + ) + + parser.add_argument( + "--model-name", + type=str, + default="spark_tts", + choices=[ + "f5_tts", "spark_tts" + ], + help="triton model_repo module name to request: transducer for k2, attention_rescoring for wenet offline, streaming_wenet for wenet streaming, infer_pipeline for paraformer large offline", + ) + + parser.add_argument( + "--output-audio", + type=str, + default="output.wav", + help="Path to save the output audio", + ) + return parser.parse_args() + +def prepare_request( + waveform, + reference_text, + target_text, + sample_rate=16000, + padding_duration: int = None, + audio_save_dir: str = "./", +): + assert len(waveform.shape) == 1, "waveform should be 1D" + lengths = np.array([[len(waveform)]], dtype=np.int32) + if padding_duration: + # padding to nearset 10 seconds + samples = np.zeros( + ( + 1, + padding_duration + * sample_rate + * ((int(duration) // padding_duration) + 1), + ), + dtype=np.float32, + ) + + samples[0, : len(waveform)] = waveform + else: + samples = waveform + + samples = samples.reshape(1, -1).astype(np.float32) + + data = { + "inputs":[ + { + "name": "reference_wav", + "shape": samples.shape, + "datatype": "FP32", + "data": samples.tolist() + }, + { + "name": "reference_wav_len", + "shape": lengths.shape, + "datatype": "INT32", + "data": lengths.tolist(), + }, + { + "name": "reference_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [reference_text] + }, + { + "name": "target_text", + "shape": [1, 1], + "datatype": "BYTES", + "data": [target_text] + } + ] + } + + return data + +if __name__ == "__main__": + args = get_args() + server_url = args.server_url + if not server_url.startswith(("http://", "https://")): + server_url = f"http://{server_url}" + + url = f"{server_url}/v2/models/{args.model_name}/infer" + waveform, sr = sf.read(args.reference_audio) + assert sr == 16000, "sample rate hardcoded in server" + + samples = np.array(waveform, dtype=np.float32) + data = prepare_request(samples, args.reference_text, args.target_text) + + rsp = requests.post( + url, + headers={"Content-Type": "application/json"}, + json=data, + verify=False, + params={"request_id": '0'} + ) + result = rsp.json() + audio = result["outputs"][0]["data"] + audio = np.array(audio, dtype=np.float32) + sf.write(args.output_audio, audio, 16000, "PCM_16") \ No newline at end of file diff --git a/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py b/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py new file mode 100644 index 0000000..7da820f --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py @@ -0,0 +1,137 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +import os +import numpy as np + +from spark.sparktts.models.audio_tokenizer import BiCodecTokenizer + +class TritonPythonModel: + """Triton Python model for audio tokenization. + + This model takes reference audio input and extracts semantic and global tokens + using BiCodec tokenizer. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + + # Initialize tokenizer + self.device = torch.device("cuda") + self.audio_tokenizer = BiCodecTokenizer(model_params["model_dir"], + device=self.device) + + def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: + """Extract reference audio clip for speaker embedding. + + Args: + wav: Input waveform array + + Returns: + Reference clip of fixed duration + """ + SAMPLE_RATE = 16000 + REF_SEGMENT_DURATION = 6 # seconds + LATENT_HOP_LENGTH = 320 + + ref_segment_length = ( + int(SAMPLE_RATE * REF_SEGMENT_DURATION) + // LATENT_HOP_LENGTH + * LATENT_HOP_LENGTH + ) + wav_length = len(wav) + + if ref_segment_length > wav_length: + # Repeat and truncate if input is too short + repeat_times = ref_segment_length // wav_length + 1 + wav = np.tile(wav, repeat_times) + + return wav[:ref_segment_length] + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing tokenized outputs + """ + reference_wav_list = [] + reference_wav_ref_clip_list = [] + + # Process each request in batch + for request in requests: + # Extract input tensors + wav_array = pb_utils.get_input_tensor_by_name( + request, "reference_wav").as_numpy() + wav_len = pb_utils.get_input_tensor_by_name( + request, "reference_wav_len").as_numpy().item() + + # Prepare inputs + wav = wav_array[:, :wav_len].squeeze(0) + reference_wav_list.append(wav) + + wav_ref_clip = self.get_ref_clip(wav) + reference_wav_ref_clip_list.append(torch.from_numpy(wav_ref_clip)) + + # Batch process through tokenizer + ref_wav_clip_tensor = torch.stack(reference_wav_ref_clip_list, dim=0) + wav2vec2_features = self.audio_tokenizer.extract_wav2vec2_features( + reference_wav_list) + + audio_tokenizer_input = { + "ref_wav": ref_wav_clip_tensor.to(self.device), + "feat": wav2vec2_features.to(self.device), + } + semantic_tokens, global_tokens = self.audio_tokenizer.model.tokenize( + audio_tokenizer_input) + + # Prepare responses + responses = [] + for i in range(len(requests)): + global_tokens_tensor = pb_utils.Tensor.from_dlpack( + "global_tokens", to_dlpack(global_tokens[i])) + semantic_tokens_tensor = pb_utils.Tensor.from_dlpack( + "semantic_tokens", to_dlpack(semantic_tokens[i])) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[global_tokens_tensor, semantic_tokens_tensor]) + responses.append(inference_response) + + return responses diff --git a/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt b/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt new file mode 100644 index 0000000..1022ed1 --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt @@ -0,0 +1,58 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: "audio_tokenizer" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + } +] +output [ + { + name: "global_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "semantic_tokens" + data_type: TYPE_INT32 + dims: [-1] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/spark/runtime/triton_trtllm/model_repo/spark_tts/1/model.py b/spark/runtime/triton_trtllm/model_repo/spark_tts/1/model.py new file mode 100644 index 0000000..e43191a --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/spark_tts/1/model.py @@ -0,0 +1,311 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os +import re +from typing import Dict, List, Tuple, Optional, Union + +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer + +from spark.sparktts.utils.token_parser import TASK_TOKEN_MAP + +def process_prompt( + text: str, + prompt_text: Optional[str] = None, + global_token_ids: torch.Tensor = None, + semantic_token_ids: torch.Tensor = None, +) -> Tuple[str, torch.Tensor]: + """ + Process input for voice cloning. + + Args: + text: The text input to be converted to speech. + prompt_text: Transcript of the prompt audio. + global_token_ids: Global token IDs extracted from reference audio. + semantic_token_ids: Semantic token IDs extracted from reference audio. + + Returns: + Tuple containing the formatted input prompt and global token IDs. + """ + # Convert global tokens to string format + global_tokens = "".join( + [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] + ) + + + # Prepare the input tokens for the model + if prompt_text is not None: + # Include semantic tokens when prompt text is provided + semantic_tokens = "".join( + [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] + ) + + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + prompt_text, + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + "<|start_semantic_token|>", + semantic_tokens, + ] + else: + # Without prompt text, exclude semantic tokens + inputs = [ + TASK_TOKEN_MAP["tts"], + "<|start_content|>", + text, + "<|end_content|>", + "<|start_global_token|>", + global_tokens, + "<|end_global_token|>", + ] + + # Join all input components into a single string + inputs = "".join(inputs) + return inputs, global_token_ids + + +class TritonPythonModel: + """Triton Python model for Spark TTS. + + This model orchestrates the end-to-end TTS pipeline by coordinating + between audio tokenizer, LLM, and vocoder components. 
+ """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {k: v["string_value"] for k, v in parameters.items()} + + # Initialize tokenizer + llm_tokenizer_dir = model_params["llm_tokenizer_dir"] + self.tokenizer = AutoTokenizer.from_pretrained(llm_tokenizer_dir) + self.device = torch.device("cuda") + self.decoupled = False + + def forward_llm(self, input_ids): + """ + Prepares the response from the language model based on the provided + inputs. Creates a `pb_utils.InferenceRequest` object with passed + `llm_request_inputs` to send to a decoupled TensorRTLLM model. + For each response from the language model: + - Checks for errors and raise an exception if any are found. + - Extracts the "output_ids" tensor from the response. + - Determines the finish reason based on the presence of the + end-of-sequence token or reaching the maximum length. + - Appends the generated token IDs to `output_ids`. + - If the finish reason is determined, decodes the output IDs to text + and prepares the final response. + + The final response includes the generated text, finish reason, + completion tokens, prompt tokens, and total tokens. + + Parameters + ---------- + - llm_request_inputs (dict): A dictionary containing the inputs for the language model. + + Returns + ------- + - pb_utils.InferenceResponse: The response object containing the generated text and additional metadata. + """ + # convert input_ids to numpy, with shape [1, sequence_length] + input_ids = input_ids.cpu().numpy() + max_tokens = 512 + input_dict = { + "request_output_len": np.array([[max_tokens]], dtype=np.int32), + "end_id": np.array([[self.tokenizer.eos_token_id]], dtype=np.int32), + "pad_id": np.array([[self.tokenizer.pad_token_id]], dtype=np.int32), + "streaming": np.array([[self.decoupled]], dtype=np.bool_), + "runtime_top_p": np.array([[0.95]], dtype=np.float32), + "runtime_top_k": np.array([[50]], dtype=np.int32), + "temperature": np.array([[0.8]], dtype=np.float32), + "input_ids": input_ids, + "input_lengths": np.array([[input_ids.shape[1]]], dtype=np.int32), + } + + # Convert inputs to Triton tensors + input_tensor_list = [ + pb_utils.Tensor(k, v) for k, v in input_dict.items() + ] + + # Create and execute inference request + llm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + requested_output_names=["output_ids", "sequence_length"], + inputs=input_tensor_list, + ) + + llm_response = llm_request.exec(decoupled=self.decoupled) + if llm_response.has_error(): + raise pb_utils.TritonModelException(llm_response.error().message()) + + # Extract and process output + output_ids = pb_utils.get_output_tensor_by_name( + llm_response, "output_ids").as_numpy() + seq_lens = pb_utils.get_output_tensor_by_name( + llm_response, "sequence_length").as_numpy() + + # Get actual output IDs up to the sequence length + actual_output_ids = output_ids[0][0][:seq_lens[0][0]] + + return actual_output_ids + + def forward_audio_tokenizer(self, wav, wav_len): + """Forward pass through the audio tokenizer component. 
+ + Args: + wav: Input waveform tensor + wav_len: Waveform length tensor + + Returns: + Tuple of global and semantic tokens + """ + inference_request = pb_utils.InferenceRequest( + model_name='audio_tokenizer', + requested_output_names=['global_tokens', 'semantic_tokens'], + inputs=[wav, wav_len] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output tensors + global_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'global_tokens') + global_tokens = torch.utils.dlpack.from_dlpack(global_tokens.to_dlpack()).cpu() + + semantic_tokens = pb_utils.get_output_tensor_by_name(inference_response, 'semantic_tokens') + semantic_tokens = torch.utils.dlpack.from_dlpack(semantic_tokens.to_dlpack()).cpu() + + return global_tokens, semantic_tokens + + def forward_vocoder(self, global_token_ids: torch.Tensor, pred_semantic_ids: torch.Tensor) -> torch.Tensor: + """Forward pass through the vocoder component. + + Args: + global_token_ids: Global token IDs tensor + pred_semantic_ids: Predicted semantic token IDs tensor + + Returns: + Generated waveform tensor + """ + # Convert tensors to Triton format + global_token_ids_tensor = pb_utils.Tensor.from_dlpack("global_tokens", to_dlpack(global_token_ids)) + pred_semantic_ids_tensor = pb_utils.Tensor.from_dlpack("semantic_tokens", to_dlpack(pred_semantic_ids)) + + # Create and execute inference request + inference_request = pb_utils.InferenceRequest( + model_name='vocoder', + requested_output_names=['waveform'], + inputs=[global_token_ids_tensor, pred_semantic_ids_tensor] + ) + + inference_response = inference_request.exec() + if inference_response.has_error(): + raise pb_utils.TritonModelException(inference_response.error().message()) + + # Extract and convert output waveform + waveform = pb_utils.get_output_tensor_by_name(inference_response, 'waveform') + waveform = torch.utils.dlpack.from_dlpack(waveform.to_dlpack()).cpu() + + return waveform + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated audio + """ + responses = [] + + for request in requests: + # Extract input tensors + wav = pb_utils.get_input_tensor_by_name(request, "reference_wav") + wav_len = pb_utils.get_input_tensor_by_name(request, "reference_wav_len") + + # Process reference audio through audio tokenizer + global_tokens, semantic_tokens = self.forward_audio_tokenizer(wav, wav_len) + + # Extract text inputs + reference_text = pb_utils.get_input_tensor_by_name(request, "reference_text").as_numpy() + reference_text = reference_text[0][0].decode('utf-8') + + target_text = pb_utils.get_input_tensor_by_name(request, "target_text").as_numpy() + target_text = target_text[0][0].decode('utf-8') + + # Prepare prompt for LLM + prompt, global_token_ids = process_prompt( + text=target_text, + prompt_text=reference_text, + global_token_ids=global_tokens, + semantic_token_ids=semantic_tokens, + ) + + + # Tokenize prompt for LLM + model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) + input_ids = model_inputs.input_ids.to(torch.int32) + + # Generate semantic tokens with LLM + generated_ids = self.forward_llm(input_ids) + + # Decode and extract semantic token IDs from generated text + predicted_text = self.tokenizer.batch_decode([generated_ids], skip_special_tokens=True)[0] + pred_semantic_ids = ( + torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicted_text)]) + .unsqueeze(0).to(torch.int32) + ) + + + # Generate audio with vocoder + audio = self.forward_vocoder( + global_token_ids.to(self.device), + pred_semantic_ids.to(self.device), + ) + + # Prepare response + audio_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(audio)) + inference_response = pb_utils.InferenceResponse(output_tensors=[audio_tensor]) + responses.append(inference_response) + + return responses diff --git a/spark/runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt b/spark/runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt new file mode 100644 index 0000000..6a2f492 --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/spark_tts/config.pbtxt @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
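+# This config drives the Python BLS entry point defined in
+# model_repo/spark_tts/1/model.py above: it sends the reference audio to the
+# "audio_tokenizer" model, generates semantic tokens through "tensorrt_llm", and
+# renders the waveform with "vocoder". The ${...} placeholders below are
+# string.Template variables and are expected to be substituted before the model
+# repository is loaded (scripts/fill_template.py in this patch performs such
+# substitution); concrete values such as the batch size, tokenizer path and
+# instance count depend on the deployment.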
+ +name: "spark_tts" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "llm_tokenizer_dir", + value: {string_value:"${llm_tokenizer_dir}"} + } +] + +input [ + { + name: "reference_wav" + data_type: TYPE_FP32 + dims: [-1] + optional: True + }, + { + name: "reference_wav_len" + data_type: TYPE_INT32 + dims: [1] + optional: True + }, + { + name: "reference_text" + data_type: TYPE_STRING + dims: [1] + }, + { + name: "target_text" + data_type: TYPE_STRING + dims: [1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: ${bls_instance_num} + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/spark/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep b/spark/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/spark/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt b/spark/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 0000000..3da7894 --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,857 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "tensorrt_llm" +backend: "${triton_backend}" +max_batch_size: ${triton_max_batch_size} + +model_transaction_policy { + decoupled: ${decoupled_mode} +} + +dynamic_batching { + preferred_batch_size: [ ${triton_max_batch_size} ] + max_queue_delay_microseconds: ${max_queue_delay_microseconds} + default_queue_policy: { max_queue_size: ${max_queue_size} } +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_input_features" + data_type: ${encoder_input_features_data_type} + dims: [ -1, -1 ] + allow_ragged_batch: true + optional: true + }, + { + name: "encoder_output_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "num_return_sequences" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "decoder_input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + reshape: { shape: [ ] } + }, + { + name: "draft_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "draft_acceptance_threshold" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_min" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_decay" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p_reset_ids" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "early_stopping" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: 
"beam_search_diversity_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_perf_metrics" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "exclude_input_in_output" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_table_extra_ids" + data_type: TYPE_UINT64 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # cross_attention_mask shape `[bs, seq_len, num_images*num_tiles]` + { + name: "cross_attention_mask" + data_type: TYPE_BOOL + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # Mrope param when mrope is used + { + name: "mrope_rotary_cos_sin" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + }, + { + name: "mrope_position_deltas" + data_type: TYPE_INT64 + dims: [ 1 ] + optional: true + }, + # the unique task ID for the given LoRA. + # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given. + # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`. + # If the cache is full the oldest LoRA will be evicted to make space for new ones. An error is returned if `lora_task_id` is not cached. + { + name: "lora_task_id" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. 
+ { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + # skip_cross_attn_blocks shape `[bs, 1]`, only used in mllama + { + name: "skip_cross_attn_blocks" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_starts" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_ends" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_priorities" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_token_range_durations_ms" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_priority" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "retention_decode_duration_ms" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide_type" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "guided_decoding_guide" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_window_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_ngram_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "lookahead_verification_set_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + allow_ragged_batch: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: ${logits_datatype} + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: ${logits_datatype} + dims: [ -1, -1, -1 ] + }, + { + name: "batch_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "sequence_index" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "context_phase_params" + data_type: TYPE_UINT8 + dims: [ -1 ] + }, + { + name: "kv_cache_alloc_new_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "kv_cache_reused_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: 
"kv_cache_alloc_total_blocks" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "arrival_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_scheduled_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "first_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "last_token_time_ns" + data_type: TYPE_INT64 + dims: [ 1 ] + }, + { + name: "acceptance_rate" + data_type: TYPE_FP32 + dims: [ 1 ] + }, + { + name: "total_accepted_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "total_draft_tokens" + data_type: TYPE_INT32 + dims: [ 1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "${max_beam_width}" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "${batching_strategy}" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "${engine_dir}" + } +} +parameters: { + key: "encoder_model_path" + value: { + string_value: "${encoder_engine_dir}" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} +parameters: { + key: "sink_token_length" + value: { + string_value: "${sink_token_length}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "${batch_scheduler_policy}" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "${kv_cache_free_gpu_mem_fraction}" + } +} +parameters: { + key: "cross_kv_cache_fraction" + value: { + string_value: "${cross_kv_cache_fraction}" + } +} +parameters: { + key: "kv_cache_host_memory_bytes" + value: { + string_value: "${kv_cache_host_memory_bytes}" + } +} +# kv_cache_onboard_blocks is for internal implementation. 
+parameters: { + key: "kv_cache_onboard_blocks" + value: { + string_value: "${kv_cache_onboard_blocks}" + } +} +# enable_trt_overlap is deprecated and doesn't have any effect on the runtime +# parameters: { +# key: "enable_trt_overlap" +# value: { +# string_value: "${enable_trt_overlap}" +# } +# } +parameters: { + key: "exclude_input_in_output" + value: { + string_value: "${exclude_input_in_output}" + } +} +parameters: { + key: "cancellation_check_period_ms" + value: { + string_value: "${cancellation_check_period_ms}" + } +} +parameters: { + key: "stats_check_period_ms" + value: { + string_value: "${stats_check_period_ms}" + } +} +parameters: { + key: "iter_stats_max_iterations" + value: { + string_value: "${iter_stats_max_iterations}" + } +} +parameters: { + key: "request_stats_max_iterations" + value: { + string_value: "${request_stats_max_iterations}" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" + } +} +parameters: { + key: "participant_ids" + value: { + string_value: "${participant_ids}" + } +} +parameters: { + key: "lora_cache_optimal_adapter_size" + value: { + string_value: "${lora_cache_optimal_adapter_size}" + } +} +parameters: { + key: "lora_cache_max_adapter_size" + value: { + string_value: "${lora_cache_max_adapter_size}" + } +} +parameters: { + key: "lora_cache_gpu_memory_fraction" + value: { + string_value: "${lora_cache_gpu_memory_fraction}" + } +} +parameters: { + key: "lora_cache_host_memory_bytes" + value: { + string_value: "${lora_cache_host_memory_bytes}" + } +} +parameters: { + key: "lora_prefetch_dir" + value: { + string_value: "${lora_prefetch_dir}" + } +} +parameters: { + key: "decoding_mode" + value: { + string_value: "${decoding_mode}" + } +} +parameters: { + key: "executor_worker_path" + value: { + string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker" + } +} +parameters: { + key: "lookahead_window_size" + value: { + string_value: "${lookahead_window_size}" + } +} +parameters: { + key: "lookahead_ngram_size" + value: { + string_value: "${lookahead_ngram_size}" + } +} +parameters: { + key: "lookahead_verification_set_size" + value: { + string_value: "${lookahead_verification_set_size}" + } +} +parameters: { + key: "medusa_choices" + value: { + string_value: "${medusa_choices}" + } +} +parameters: { + key: "eagle_choices" + value: { + string_value: "${eagle_choices}" + } +} +parameters: { + key: "gpu_weights_percent" + value: { + string_value: "${gpu_weights_percent}" + } +} +parameters: { + key: "enable_context_fmha_fp32_acc" + value: { + string_value: "${enable_context_fmha_fp32_acc}" + } +} +parameters: { + key: "multi_block_mode" + value: { + string_value: "${multi_block_mode}" + } +} +parameters: { + key: "cuda_graph_mode" + value: { + string_value: "${cuda_graph_mode}" + } +} +parameters: { + key: "cuda_graph_cache_size" + value: { + string_value: "${cuda_graph_cache_size}" + } +} +parameters: { + key: "speculative_decoding_fast_logits" + value: { + string_value: "${speculative_decoding_fast_logits}" + } +} +parameters: { + key: "tokenizer_dir" + value: { + string_value: "${tokenizer_dir}" + } +} +parameters: { + key: "guided_decoding_backend" + value: { + string_value: 
"${guided_decoding_backend}" + } +} +parameters: { + key: "xgrammar_tokenizer_info_path" + value: { + string_value: "${xgrammar_tokenizer_info_path}" + } +} diff --git a/spark/runtime/triton_trtllm/model_repo/vocoder/1/model.py b/spark/runtime/triton_trtllm/model_repo/vocoder/1/model.py new file mode 100644 index 0000000..608d390 --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/vocoder/1/model.py @@ -0,0 +1,106 @@ +# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os +import logging +from typing import List, Dict + +import torch +from torch.utils.dlpack import to_dlpack + +import triton_python_backend_utils as pb_utils + +from spark.sparktts.models.bicodec import BiCodec + +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class TritonPythonModel: + """Triton Python model for vocoder. + + This model takes global and semantic tokens as input and generates audio waveforms + using the BiCodec vocoder. + """ + + def initialize(self, args): + """Initialize the model. + + Args: + args: Dictionary containing model configuration + """ + # Parse model parameters + parameters = json.loads(args['model_config'])['parameters'] + model_params = {key: value["string_value"] for key, value in parameters.items()} + model_dir = model_params["model_dir"] + + # Initialize device and vocoder + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Initializing vocoder from {model_dir} on {self.device}") + + self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec") + del self.vocoder.encoder, self.vocoder.postnet + self.vocoder.eval().to(self.device) # Set model to evaluation mode + + logger.info("Vocoder initialized successfully") + + + def execute(self, requests): + """Execute inference on the batched requests. 
+ + Args: + requests: List of inference requests + + Returns: + List of inference responses containing generated waveforms + """ + global_tokens_list, semantic_tokens_list = [], [] + + # Process each request in batch + for request in requests: + global_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy() + semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy() + global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device)) + semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device)) + + # Concatenate tokens for batch processing + global_tokens = torch.cat(global_tokens_list, dim=0) + semantic_tokens = torch.cat(semantic_tokens_list, dim=0) + + + # Generate waveforms + with torch.no_grad(): + wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1)) + + # Prepare responses + responses = [] + for i in range(len(requests)): + wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i])) + inference_response = pb_utils.InferenceResponse(output_tensors=[wav_tensor]) + responses.append(inference_response) + + return responses diff --git a/spark/runtime/triton_trtllm/model_repo/vocoder/config.pbtxt b/spark/runtime/triton_trtllm/model_repo/vocoder/config.pbtxt new file mode 100644 index 0000000..5a99868 --- /dev/null +++ b/spark/runtime/triton_trtllm/model_repo/vocoder/config.pbtxt @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
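+# The vocoder consumes the global/semantic token IDs produced by the
+# audio_tokenizer and tensorrt_llm stages and returns the synthesized waveform.
+# Note that the instance_group below is KIND_CPU even though model.py above moves
+# BiCodec to CUDA when a GPU is available: GPU placement is handled inside the
+# Python model rather than by Triton's instance_group.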
+ +name: "vocoder" +backend: "python" +max_batch_size: ${triton_max_batch_size} +dynamic_batching { + max_queue_delay_microseconds: ${max_queue_delay_microseconds} +} +parameters [ + { + key: "model_dir", + value: {string_value:"${model_dir}"} + } +] + +input [ + { + name: "global_tokens" + data_type: TYPE_INT32 + dims: [-1] + }, + { + name: "semantic_tokens" + data_type: TYPE_INT32 + dims: [-1] + } +] +output [ + { + name: "waveform" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] \ No newline at end of file diff --git a/spark/runtime/triton_trtllm/scripts/convert_checkpoint.py b/spark/runtime/triton_trtllm/scripts/convert_checkpoint.py new file mode 100644 index 0000000..85f948f --- /dev/null +++ b/spark/runtime/triton_trtllm/scripts/convert_checkpoint.py @@ -0,0 +1,335 @@ +import argparse +import os +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed + +from transformers import AutoConfig + +import tensorrt_llm +from tensorrt_llm._utils import release_gc +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import QWenForCausalLM +from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm.quantization import QuantAlgo + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--model_dir', type=str, default=None, required=True) + parser.add_argument('--tp_size', + type=int, + default=1, + help='N-way tensor parallelism size') + parser.add_argument('--pp_size', + type=int, + default=1, + help='N-way pipeline parallelism size') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'float16', 'bfloat16', 'float32'], + help= + "The data type for the model weights and activations if not quantized. " + "If 'auto', the data type is automatically inferred from the source model; " + "however, if the source dtype is float32, it is converted to float16.") + parser.add_argument( + '--use_weight_only', + default=False, + action="store_true", + help='Quantize weights for the various GEMMs to INT4/INT8.' + 'See --weight_only_precision to set the precision') + parser.add_argument( + '--disable_weight_only_quant_plugin', + default=False, + action="store_true", + help= + 'By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--weight_only_precision', + const='int8', + type=str, + nargs='?', + default='int8', + choices=['int8', 'int4', 'int4_gptq'], + help= + 'Define the precision for the weights when using weight-only quantization.' + 'You must also use --use_weight_only for that argument to have an impact.' + ) + parser.add_argument( + '--calib_dataset', + type=str, + default='ccdv/cnn_dailymail', + help= + "The huggingface dataset name or the local directory of the dataset for calibration." + ) + parser.add_argument( + "--smoothquant", + "-sq", + type=float, + default=None, + help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" + " to Smoothquant the model, and output int8 weights." + " A good first try is 0.5. Must be in [0, 1]") + parser.add_argument( + '--per_channel', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor for the GEMM\'s result. 
' + 'per_channel instead uses a different static scaling factor for each channel. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--per_token', + action="store_true", + default=False, + help= + 'By default, we use a single static scaling factor to scale activations in the int8 range. ' + 'per_token chooses at run time, and for each token, a custom scaling factor. ' + 'The latter is usually more accurate, but a little slower.') + parser.add_argument( + '--int8_kv_cache', + default=False, + action="store_true", + help= + 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' + ) + parser.add_argument( + '--per_group', + default=False, + action="store_true", + help= + 'By default, we use a single static scaling factor to scale weights in the int4 range. ' + 'per_group chooses at run time, and for each group, a custom scaling factor. ' + 'The flag is built for GPTQ/AWQ quantization.') + + parser.add_argument('--group_size', + type=int, + default=128, + help='Group size used in GPTQ quantization.') + + parser.add_argument("--load_model_on_cpu", action="store_true") + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='The path to save the TensorRT-LLM checkpoint') + parser.add_argument( + '--workers', + type=int, + default=1, + help='The number of workers for converting checkpoint in parallel') + parser.add_argument( + '--moe_tp_size', + type=int, + default=-1, + help= + 'N-way tensor parallelism size for MOE, default is tp_size, which will do tp-only for MoE' + ) + parser.add_argument( + '--moe_ep_size', + type=int, + default=-1, + help= + 'N-way expert parallelism size for MOE, default is 1, which will do tp-only for MoE' + ) + args = parser.parse_args() + return args + + +def args_to_quant_config(args: argparse.Namespace) -> QuantConfig: + '''return config dict with quantization info based on the command line args + ''' + quant_config = QuantConfig() + if args.use_weight_only: + if args.weight_only_precision == 'int8': + quant_config.quant_algo = QuantAlgo.W8A16 + elif args.weight_only_precision == 'int4': + quant_config.quant_algo = QuantAlgo.W4A16 + elif args.smoothquant: + quant_config.smoothquant_val = args.smoothquant + if args.per_channel: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN + else: + if args.per_token: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN + else: + quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN + + if args.int8_kv_cache: + quant_config.kv_cache_quant_algo = QuantAlgo.INT8 + + if args.weight_only_precision == 'int4_gptq': + quant_config.group_size = args.group_size + quant_config.has_zero_point = True + quant_config.pre_quant_scale = False + quant_config.quant_algo = QuantAlgo.W4A16_GPTQ + + return quant_config + + +def 
update_quant_config_from_hf(quant_config, hf_config,
+                                override_fields) -> tuple[QuantConfig, dict]:
+    hf_config_dict = hf_config.to_dict()
+    if hf_config_dict.get('quantization_config'):
+        # update the quant_algo, and clamp_val.
+        if hf_config_dict['quantization_config'].get('quant_method') == 'awq':
+            logger.info(
+                "Load quantization configs from huggingface model_config.")
+            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
+            quant_config.group_size = hf_config_dict['quantization_config'].get(
+                'group_size', 128)
+            quant_config.has_zero_point = hf_config_dict[
+                'quantization_config'].get('zero_point', False)
+            override_fields.update({"use_autoawq": True})
+        elif hf_config_dict['quantization_config'].get(
+                'quant_method') == 'gptq':
+            logger.info(
+                "Load quantization configs from huggingface model_config.")
+            desc_act = hf_config_dict['quantization_config'].get(
+                'desc_act', False)
+            if desc_act:
+                raise ValueError("GPTQ with desc_act=True is not implemented!")
+            quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
+            quant_config.group_size = hf_config_dict['quantization_config'].get(
+                'group_size', 128)
+            quant_config.has_zero_point = hf_config_dict[
+                'quantization_config'].get('sym', False)
+    return quant_config, override_fields
+
+
+def args_to_build_options(args):
+    return {
+        'use_parallel_embedding': args.use_parallel_embedding,
+        'embedding_sharding_dim': args.embedding_sharding_dim,
+        'disable_weight_only_quant_plugin':
+        args.disable_weight_only_quant_plugin
+    }
+
+
+def convert_and_save_hf(args):
+    model_dir = args.model_dir
+    world_size = args.tp_size * args.pp_size
+    # Need to convert the cli args to key-value pairs and override them in the generated config dict.
+    # Ideally these fields would be moved out of the config and passed into the build API;
+    # they are kept here for compatibility until the refactor is done.
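+    # Note (descriptive only): with args_to_build_options() as defined above, the
+    # override keys collected here are 'use_parallel_embedding',
+    # 'embedding_sharding_dim' and 'disable_weight_only_quant_plugin'.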
+ override_fields = {} + override_fields.update(args_to_build_options(args)) + quant_config = args_to_quant_config(args) + + try: + hf_config = AutoConfig.from_pretrained(model_dir, + trust_remote_code=True) + quant_config, override_fields = update_quant_config_from_hf( + quant_config, hf_config, override_fields) + except: + logger.warning("AutoConfig cannot load the huggingface config.") + + if args.smoothquant is not None or args.int8_kv_cache: + mapping = Mapping( + world_size=world_size, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size, + ) + QWenForCausalLM.quantize(args.model_dir, + args.output_dir, + dtype=args.dtype, + mapping=mapping, + quant_config=quant_config, + calib_dataset=args.calib_dataset, + **override_fields) + else: + + def convert_and_save_rank(args, rank): + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size, + moe_tp_size=args.moe_tp_size, + moe_ep_size=args.moe_ep_size) + qwen = QWenForCausalLM.from_hugging_face(model_dir, + args.dtype, + mapping=mapping, + quant_config=quant_config, + **override_fields) + qwen.save_checkpoint(args.output_dir, save_config=(rank == 0)) + del qwen + + execute(args.workers, [convert_and_save_rank] * world_size, args) + release_gc() + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) + else: + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] + exceptions = [] + for future in as_completed(futures): + try: + future.result() + except Exception as e: + traceback.print_exc() + exceptions.append(e) + assert len( + exceptions + ) == 0, "Checkpoint conversion failed, please check error log." + + +def main(): + print(tensorrt_llm.__version__) + args = parse_arguments() + + if (args.moe_tp_size == -1 and args.moe_ep_size == -1): + # moe default to tp-only + args.moe_tp_size = args.tp_size + args.moe_ep_size = 1 + elif (args.moe_tp_size == -1): + args.moe_tp_size = args.tp_size // args.moe_ep_size + elif (args.moe_ep_size == -1): + args.moe_ep_size = args.tp_size // args.moe_tp_size + assert (args.moe_tp_size * args.moe_ep_size == args.tp_size + ), "moe_tp_size * moe_ep_size must equal to tp_size" + + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + assert args.model_dir is not None + convert_and_save_hf(args) + + tok = time.time() + t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) + print(f'Total time of converting checkpoints: {t}') + + +if __name__ == '__main__': + main() diff --git a/spark/runtime/triton_trtllm/scripts/fill_template.py b/spark/runtime/triton_trtllm/scripts/fill_template.py new file mode 100644 index 0000000..5c629f7 --- /dev/null +++ b/spark/runtime/triton_trtllm/scripts/fill_template.py @@ -0,0 +1,70 @@ +#! /usr/bin/env python3 +from argparse import ArgumentParser +from string import Template + + +def split(string, delimiter): + """Split a string using delimiter. Supports escaping. + + Args: + string (str): The string to split. + delimiter (str): The delimiter to split the string with. + + Returns: + list: A list of strings. 
+ """ + result = [] + current = "" + escape = False + for char in string: + if escape: + current += char + escape = False + elif char == delimiter: + result.append(current) + current = "" + elif char == "\\": + escape = True + else: + current += char + result.append(current) + return result + + +def main(file_path, substitutions, in_place): + with open(file_path) as f: + pbtxt = Template(f.read()) + + sub_dict = { + "max_queue_size": 0, + 'max_queue_delay_microseconds': 0, + } + for sub in split(substitutions, ","): + key, value = split(sub, ":") + sub_dict[key] = value + + assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}." + + pbtxt = pbtxt.safe_substitute(sub_dict) + + if in_place: + with open(file_path, "w") as f: + f.write(pbtxt) + else: + print(pbtxt) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("file_path", help="path of the .pbtxt to modify") + parser.add_argument( + "substitutions", + help= + "substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..." + ) + parser.add_argument("--in_place", + "-i", + action="store_true", + help="do the operation in-place") + args = parser.parse_args() + main(**vars(args)) diff --git a/spark/sparktts/models/audio_tokenizer.py b/spark/sparktts/models/audio_tokenizer.py new file mode 100644 index 0000000..fcb7699 --- /dev/null +++ b/spark/sparktts/models/audio_tokenizer.py @@ -0,0 +1,163 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import numpy as np + +from pathlib import Path +from typing import Any, Dict, Tuple +from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model + +from spark.sparktts.utils.file import load_config +from spark.sparktts.utils.audio import load_audio +from spark.sparktts.models.bicodec import BiCodec + + +class BiCodecTokenizer: + """BiCodec tokenizer for handling audio input and tokenization.""" + + def __init__(self, model_dir: Path, device: torch.device = None, **kwargs): + super().__init__() + """ + Args: + model_dir: Path to the model directory. + device: Device to run the model on (default is GPU if available). 
+ """ + self.device = device + self.model_dir = model_dir + self.config = load_config(f"{model_dir}/config.yaml") + self._initialize_model() + + def _initialize_model(self): + """Load and initialize the BiCodec model and Wav2Vec2 feature extractor.""" + self.model = BiCodec.load_from_checkpoint(f"{self.model_dir}/BiCodec").to( + self.device + ) + self.processor = Wav2Vec2FeatureExtractor.from_pretrained( + f"{self.model_dir}/wav2vec2-large-xlsr-53" + ) + self.feature_extractor = Wav2Vec2Model.from_pretrained( + f"{self.model_dir}/wav2vec2-large-xlsr-53" + ).to(self.device) + self.feature_extractor.config.output_hidden_states = True + + def get_ref_clip(self, wav: np.ndarray) -> np.ndarray: + """Get reference audio clip for speaker embedding.""" + ref_segment_length = ( + int(self.config["sample_rate"] * self.config["ref_segment_duration"]) + // self.config["latent_hop_length"] + * self.config["latent_hop_length"] + ) + wav_length = len(wav) + + if ref_segment_length > wav_length: + # Repeat and truncate to handle insufficient length + wav = np.tile(wav, ref_segment_length // wav_length + 1) + + return wav[:ref_segment_length] + + def process_audio(self, wav_path: Path) -> Tuple[np.ndarray, torch.Tensor]: + """load auido and get reference audio from wav path""" + wav = load_audio( + wav_path, + sampling_rate=self.config["sample_rate"], + volume_normalize=self.config["volume_normalize"], + ) + + wav_ref = self.get_ref_clip(wav) + + wav_ref = torch.from_numpy(wav_ref).unsqueeze(0).float() + return wav, wav_ref + + def extract_wav2vec2_features(self, wavs: torch.Tensor) -> torch.Tensor: + """extract wav2vec2 features""" + inputs = self.processor( + wavs, + sampling_rate=16000, + return_tensors="pt", + padding=True, + output_hidden_states=True, + ).input_values + feat = self.feature_extractor(inputs.to(self.feature_extractor.device)) + feats_mix = ( + feat.hidden_states[11] + feat.hidden_states[14] + feat.hidden_states[16] + ) / 3 + + return feats_mix + + def tokenize_batch(self, batch: Dict[str, Any]) -> torch.Tensor: + """tokenize the batch of audio + + Args: + batch: + wavs (List[np.ndarray]): batch of audio + ref_wavs (torch.Tensor): reference audio. shape: (batch_size, seq_len) + + Returns: + semantic_tokens: semantic tokens. shape: (batch_size, seq_len, latent_dim) + global_tokens: global tokens. shape: (batch_size, seq_len, global_dim) + """ + feats = self.extract_wav2vec2_features(batch["wav"]) + batch["feat"] = feats + semantic_tokens, global_tokens = self.model.tokenize(batch) + + return global_tokens, semantic_tokens + + def tokenize(self, audio_path: str) -> Tuple[torch.Tensor, torch.Tensor]: + """tokenize the audio""" + wav, ref_wav = self.process_audio(audio_path) + feat = self.extract_wav2vec2_features(wav) + batch = { + "wav": torch.from_numpy(wav).unsqueeze(0).float().to(self.device), + "ref_wav": ref_wav.to(self.device), + "feat": feat.to(self.device), + } + semantic_tokens, global_tokens = self.model.tokenize(batch) + + return global_tokens, semantic_tokens + + def detokenize( + self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor + ) -> np.array: + """detokenize the tokens to waveform + + Args: + global_tokens: global tokens. shape: (batch_size, global_dim) + semantic_tokens: semantic tokens. shape: (batch_size, latent_dim) + + Returns: + wav_rec: waveform. 
shape: (batch_size, seq_len) for batch or (seq_len,) for single + """ + global_tokens = global_tokens.unsqueeze(1) + wav_rec = self.model.detokenize(semantic_tokens, global_tokens) + return wav_rec.detach().squeeze().cpu().numpy() + + +# test +if __name__ == "__main__": + import soundfile as sf + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + tokenizer = BiCodecTokenizer( + model_dir="spark/pretrained_models/Spark-TTS-0.5B", + device=device, + ) + wav_path = "spark/example/prompt_audio.wav" + + global_tokens, semantic_tokens = tokenizer.tokenize(wav_path) + + wav_rec = tokenizer.detokenize(global_tokens.squeeze(0), semantic_tokens) + sf.write("spark/example/prompt_recon.wav", wav_rec, 16000) diff --git a/spark/sparktts/models/bicodec.py b/spark/sparktts/models/bicodec.py new file mode 100644 index 0000000..4848e32 --- /dev/null +++ b/spark/sparktts/models/bicodec.py @@ -0,0 +1,247 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from pathlib import Path +from typing import Dict, Any +from omegaconf import DictConfig +from safetensors.torch import load_file + +from spark.sparktts.utils.file import load_config +from spark.sparktts.modules.speaker.speaker_encoder import SpeakerEncoder +from spark.sparktts.modules.encoder_decoder.feat_encoder import Encoder +from spark.sparktts.modules.encoder_decoder.feat_decoder import Decoder +from spark.sparktts.modules.encoder_decoder.wave_generator import WaveGenerator +from spark.sparktts.modules.vq.factorized_vector_quantize import FactorizedVectorQuantize + + +class BiCodec(nn.Module): + """ + BiCodec model for speech synthesis, incorporating a speaker encoder, feature encoder/decoder, + quantizer, and wave generator. + """ + + def __init__( + self, + mel_params: Dict[str, Any], + encoder: nn.Module, + decoder: nn.Module, + quantizer: nn.Module, + speaker_encoder: nn.Module, + prenet: nn.Module, + postnet: nn.Module, + **kwargs + ) -> None: + """ + Initializes the BiCodec model with the required components. + + Args: + mel_params (dict): Parameters for the mel-spectrogram transformer. + encoder (nn.Module): Encoder module. + decoder (nn.Module): Decoder module. + quantizer (nn.Module): Quantizer module. + speaker_encoder (nn.Module): Speaker encoder module. + prenet (nn.Module): Prenet network. + postnet (nn.Module): Postnet network. + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.quantizer = quantizer + self.speaker_encoder = speaker_encoder + self.prenet = prenet + self.postnet = postnet + self.init_mel_transformer(mel_params) + + @classmethod + def load_from_checkpoint(cls, model_dir: Path, **kwargs) -> "BiCodec": + """ + Loads the model from a checkpoint. + + Args: + model_dir (Path): Path to the model directory containing checkpoint and config. + + Returns: + BiCodec: The initialized BiCodec model. 
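+
+        Example (illustrative; the path mirrors the test block at the bottom of this file):
+            model = BiCodec.load_from_checkpoint("pretrained_models/SparkTTS-0.5B/BiCodec")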
+ """ + ckpt_path = f'{model_dir}/model.safetensors' + config = load_config(f'{model_dir}/config.yaml')['audio_tokenizer'] + mel_params = config["mel_params"] + encoder = Encoder(**config["encoder"]) + quantizer = FactorizedVectorQuantize(**config["quantizer"]) + prenet = Decoder(**config["prenet"]) + postnet = Decoder(**config["postnet"]) + decoder = WaveGenerator(**config["decoder"]) + speaker_encoder = SpeakerEncoder(**config["speaker_encoder"]) + + model = cls( + mel_params=mel_params, + encoder=encoder, + decoder=decoder, + quantizer=quantizer, + speaker_encoder=speaker_encoder, + prenet=prenet, + postnet=postnet, + ) + + state_dict = load_file(ckpt_path) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + for key in missing_keys: + print(f"Missing tensor: {key}") + for key in unexpected_keys: + print(f"Unexpected tensor: {key}") + + model.eval() + model.remove_weight_norm() + + return model + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """ + Performs a forward pass through the model. + + Args: + batch (dict): A dictionary containing features, reference waveform, and target waveform. + + Returns: + dict: A dictionary containing the reconstruction, features, and other metrics. + """ + feat = batch["feat"] + mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) + + z = self.encoder(feat.transpose(1, 2)) + vq_outputs = self.quantizer(z) + + x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2)) + + conditions = d_vector + with_speaker_loss = False + + x = self.prenet(vq_outputs["z_q"], conditions) + pred_feat = self.postnet(x) + x = x + conditions.unsqueeze(-1) + wav_recon = self.decoder(x) + + return { + "vq_loss": vq_outputs["vq_loss"], + "perplexity": vq_outputs["perplexity"], + "cluster_size": vq_outputs["active_num"], + "recons": wav_recon, + "pred_feat": pred_feat, + "x_vector": x_vector, + "d_vector": d_vector, + "audios": batch["wav"].unsqueeze(1), + "with_speaker_loss": with_speaker_loss, + } + + @torch.no_grad() + def tokenize(self, batch: Dict[str, Any]): + """ + Tokenizes the input audio into semantic and global tokens. + + Args: + batch (dict): The input audio features and reference waveform. + + Returns: + tuple: Semantic tokens and global tokens. + """ + feat = batch["feat"] + mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) + + z = self.encoder(feat.transpose(1, 2)) + semantic_tokens = self.quantizer.tokenize(z) + global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) + + return semantic_tokens, global_tokens + + @torch.no_grad() + def detokenize(self, semantic_tokens, global_tokens): + """ + Detokenizes the semantic and global tokens into a waveform. + + Args: + semantic_tokens (tensor): Semantic tokens. + global_tokens (tensor): Global tokens. + + Returns: + tensor: Reconstructed waveform. + """ + z_q = self.quantizer.detokenize(semantic_tokens) + d_vector = self.speaker_encoder.detokenize(global_tokens) + x = self.prenet(z_q, d_vector) + x = x + d_vector.unsqueeze(-1) + wav_recon = self.decoder(x) + + return wav_recon + + def init_mel_transformer(self, config: Dict[str, Any]): + """ + Initializes the MelSpectrogram transformer based on the provided configuration. + + Args: + config (dict): Configuration parameters for MelSpectrogram. 
+ """ + import torchaudio.transforms as TT + + self.mel_transformer = TT.MelSpectrogram( + config["sample_rate"], + config["n_fft"], + config["win_length"], + config["hop_length"], + config["mel_fmin"], + config["mel_fmax"], + n_mels=config["num_mels"], + power=1, + norm="slaney", + mel_scale="slaney", + ) + + def remove_weight_norm(self): + """Removes weight normalization from all layers.""" + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: + pass # The module didn't have weight norm + + self.apply(_remove_weight_norm) + + +# Test the model +if __name__ == "__main__": + + config = load_config("pretrained_models/SparkTTS-0.5B/BiCodec/config.yaml") + model = BiCodec.load_from_checkpoint( + model_dir="pretrained_models/SparkTTS-0.5B/BiCodec", + ) + + # Generate random inputs for testing + duration = 0.96 + x = torch.randn(20, 1, int(duration * 16000)) + feat = torch.randn(20, int(duration * 50), 1024) + inputs = {"feat": feat, "wav": x, "ref_wav": x} + + # Forward pass + outputs = model(inputs) + semantic_tokens, global_tokens = model.tokenize(inputs) + wav_recon = model.detokenize(semantic_tokens, global_tokens) + + # Verify if the reconstruction matches + if torch.allclose(outputs["recons"].detach(), wav_recon): + print("Test successful") + else: + print("Test failed") diff --git a/spark/sparktts/modules/blocks/layers.py b/spark/sparktts/modules/blocks/layers.py new file mode 100644 index 0000000..9de506e --- /dev/null +++ b/spark/sparktts/modules/blocks/layers.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0 + + +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) diff --git a/spark/sparktts/modules/blocks/samper.py b/spark/sparktts/modules/blocks/samper.py new file mode 100644 index 0000000..e6673bf --- /dev/null +++ b/spark/sparktts/modules/blocks/samper.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SamplingBlock(nn.Module): + """Sampling block for upsampling or downsampling""" + + def __init__( + self, + dim: int, + groups: int = 1, + upsample_scale: int = 1, + downsample_scale: int = 1, + ) -> None: + """ + Args: + dim: input dimension + groups: number of groups + upsample_scale: upsampling scale + downsample_scale: downsampling scale + """ + super(SamplingBlock, self).__init__() + + self.upsample_scale = upsample_scale + self.downsample_scale = downsample_scale + + if self.upsample_scale > 1: + self.de_conv_upsampler = nn.Sequential( + nn.LeakyReLU(0.2), + nn.ConvTranspose1d( + dim, + dim, + kernel_size=upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + groups=groups, + ), + ) + + if self.downsample_scale > 1: + self.conv_downsampler = nn.Sequential( + nn.LeakyReLU(0.2), + nn.Conv1d( + dim, + dim, + kernel_size=2 * downsample_scale, + stride=downsample_scale, + padding=downsample_scale // 2 + downsample_scale % 2, + groups=groups, + ), + ) + + @staticmethod + def repeat_upsampler(x, upsample_scale): + return x.repeat_interleave(upsample_scale, dim=2) + + @staticmethod + def skip_downsampler(x, downsample_scale): + return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) + + def forward(self, x): + x = x.transpose(1, 2) + if self.upsample_scale > 1: + repeat_res = self.repeat_upsampler(x, self.upsample_scale) + deconv_res = self.de_conv_upsampler(x) + upmerge_res = repeat_res + deconv_res + else: + upmerge_res = x + repeat_res = x + + if self.downsample_scale > 1: + conv_res = self.conv_downsampler(upmerge_res) + skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) + skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) + else: + conv_res = upmerge_res + skip2_res = upmerge_res + skip1_res = repeat_res + + final_res = conv_res + skip1_res + skip2_res + + return final_res + + +# test +if __name__ == "__main__": + test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 + model = SamplingBlock(1024, 1024, upsample_scale=2) + model_down = SamplingBlock(1024, 1024, downsample_scale=2) + output = model(test_input) + output_down = model_down(test_input) + print("shape after upsample * 2", output.shape) # torch.Size([8, 1024, 100]) + print("shape after downsample * 2", output_down.shape) # torch.Size([8, 1024, 25]) + if output.shape == torch.Size([8, 1024, 100]) and output_down.shape == torch.Size( + [8, 1024, 25] + ): + print("test successful") diff --git a/spark/sparktts/modules/blocks/vocos.py b/spark/sparktts/modules/blocks/vocos.py new file mode 100644 index 0000000..31ff790 --- /dev/null +++ b/spark/sparktts/modules/blocks/vocos.py @@ -0,0 +1,373 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn + +from typing import Tuple +from torch.nn.utils import weight_norm, remove_weight_norm + +from typing import Optional + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + condition_dim: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.adanorm = condition_dim is not None + if condition_dim: + self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + condition_dim (int): Dimension of the condition. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Linear(condition_dim, embedding_dim) + self.shift = nn.Linear(condition_dim, embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding) + shift = self.shift(cond_embedding) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale.unsqueeze(1) + shift.unsqueeze(1) + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + ] + ) + + self.gamma = nn.ParameterList( + [ + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. 
+ """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + condition_dim: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = condition_dim is not None + if condition_dim: + self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + condition_dim=condition_dim, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor: + x = self.embed(x) + if self.adanorm: + assert condition is not None + x = self.norm(x.transpose(1, 2), condition) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, condition) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, + input_channels, + dim, + num_blocks, + layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm( + nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) + ) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ + ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) + for _ in range(num_blocks) + ] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x diff --git a/spark/sparktts/modules/encoder_decoder/feat_decoder.py b/spark/sparktts/modules/encoder_decoder/feat_decoder.py new file mode 100644 index 0000000..589e9a9 --- /dev/null +++ b/spark/sparktts/modules/encoder_decoder/feat_decoder.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn + +from typing import List + +from spark.sparktts.modules.blocks.vocos import VocosBackbone +from spark.sparktts.modules.blocks.samper import SamplingBlock + + +class Decoder(nn.Module): + """Decoder module with convnext and upsampling blocks + + Args: + sample_ratios (List[int]): sample ratios + example: [2, 2] means downsample by 2x and then upsample by 2x + """ + + def __init__( + self, + input_channels: int, + vocos_dim: int, + vocos_intermediate_dim: int, + vocos_num_layers: int, + out_channels: int, + condition_dim: int = None, + sample_ratios: List[int] = [1, 1], + use_tanh_at_final: bool = False, + ): + super().__init__() + + self.linear_pre = nn.Linear(input_channels, vocos_dim) + modules = [ + nn.Sequential( + SamplingBlock( + dim=vocos_dim, + groups=vocos_dim, + upsample_scale=ratio, + ), + VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=2, + condition_dim=None, + ), + ) + for ratio in sample_ratios + ] + + self.downsample = nn.Sequential(*modules) + + self.vocos_backbone = VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=vocos_num_layers, + condition_dim=condition_dim, + ) + self.linear = nn.Linear(vocos_dim, out_channels) + self.use_tanh_at_final = use_tanh_at_final + + def forward(self, x: torch.Tensor, c: torch.Tensor = None): + """encoder forward. + + Args: + x (torch.Tensor): (batch_size, input_channels, length) + + Returns: + x (torch.Tensor): (batch_size, encode_channels, length) + """ + x = self.linear_pre(x.transpose(1, 2)) + x = self.downsample(x).transpose(1, 2) + x = self.vocos_backbone(x, condition=c) + x = self.linear(x).transpose(1, 2) + if self.use_tanh_at_final: + x = torch.tanh(x) + + return x + + +# test +if __name__ == "__main__": + test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 + condition = torch.randn(8, 256) + decoder = Decoder( + input_channels=1024, + vocos_dim=384, + vocos_intermediate_dim=2048, + vocos_num_layers=12, + out_channels=256, + condition_dim=256, + sample_ratios=[2, 2], + ) + output = decoder(test_input, condition) + print(output.shape) # torch.Size([8, 256, 200]) + if output.shape == torch.Size([8, 256, 200]): + print("Decoder test passed") + else: + print("Decoder test failed") diff --git a/spark/sparktts/modules/encoder_decoder/feat_encoder.py b/spark/sparktts/modules/encoder_decoder/feat_encoder.py new file mode 100644 index 0000000..7f23ab1 --- /dev/null +++ b/spark/sparktts/modules/encoder_decoder/feat_encoder.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +import torch.nn as nn + +from typing import List + +from spark.sparktts.modules.blocks.vocos import VocosBackbone +from spark.sparktts.modules.blocks.samper import SamplingBlock + + +class Encoder(nn.Module): + """Encoder module with convnext and downsampling blocks""" + + def __init__( + self, + input_channels: int, + vocos_dim: int, + vocos_intermediate_dim: int, + vocos_num_layers: int, + out_channels: int, + sample_ratios: List[int] = [1, 1], + ): + super().__init__() + """ + Encoder module with VocosBackbone and sampling blocks. + + Args: + sample_ratios (List[int]): sample ratios + example: [2, 2] means downsample by 2x and then upsample by 2x + """ + self.encoder = VocosBackbone( + input_channels=input_channels, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=vocos_num_layers, + condition_dim=None, + ) + + modules = [ + nn.Sequential( + SamplingBlock( + dim=vocos_dim, + groups=vocos_dim, + downsample_scale=ratio, + ), + VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=2, + condition_dim=None, + ), + ) + for ratio in sample_ratios + ] + + self.downsample = nn.Sequential(*modules) + + self.project = nn.Linear(vocos_dim, out_channels) + + def forward(self, x: torch.Tensor, *args): + """ + Args: + x (torch.Tensor): (batch_size, input_channels, length) + + Returns: + x (torch.Tensor): (batch_size, encode_channels, length) + """ + x = self.encoder(x) + x = self.downsample(x) + x = self.project(x) + return x.transpose(1, 2) + + +# test +if __name__ == "__main__": + test_input = torch.randn(8, 1024, 50) # Batch size = 8, 1024 channels, length = 50 + encoder = Encoder( + input_channels=1024, + vocos_dim=384, + vocos_intermediate_dim=2048, + vocos_num_layers=12, + out_channels=256, + sample_ratios=[2, 2], + ) + + output = encoder(test_input) + print(output.shape) # torch.Size([8, 256, 12]) + if output.shape == torch.Size([8, 256, 12]): + print("test successful") diff --git a/spark/sparktts/modules/encoder_decoder/wave_generator.py b/spark/sparktts/modules/encoder_decoder/wave_generator.py new file mode 100644 index 0000000..2288e54 --- /dev/null +++ b/spark/sparktts/modules/encoder_decoder/wave_generator.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from https://github.com/descriptinc/descript-audio-codec under the Apache License 2.0 + + +import torch.nn as nn + +from spark.sparktts.modules.blocks.layers import ( + Snake1d, + WNConv1d, + ResidualUnit, + WNConvTranspose1d, + init_weights, +) + + +class DecoderBlock(nn.Module): + def __init__( + self, + input_dim: int = 16, + output_dim: int = 8, + kernel_size: int = 2, + stride: int = 1, + ): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class WaveGenerator(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + kernel_sizes, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.apply(init_weights) + + def forward(self, x): + return self.model(x) diff --git a/spark/sparktts/modules/fsq/finite_scalar_quantization.py b/spark/sparktts/modules/fsq/finite_scalar_quantization.py new file mode 100644 index 0000000..da36155 --- /dev/null +++ b/spark/sparktts/modules/fsq/finite_scalar_quantization.py @@ -0,0 +1,251 @@ +""" +Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 +Code adapted from Jax version in Appendix A.1 +""" + +from __future__ import annotations +from functools import wraps, partial +from contextlib import nullcontext +from typing import List, Tuple + +import torch +import torch.nn as nn +from torch.nn import Module +from torch import Tensor, int32 +from torch.amp import autocast + +from einops import rearrange, pack, unpack + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def maybe(fn): + @wraps(fn) + def inner(x, *args, **kwargs): + if not exists(x): + return x + return fn(x, *args, **kwargs) + + return inner + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# tensor helpers + + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +# main class + + +class FSQ(Module): + def __init__( + self, + levels: List[int], + dim: int | None = None, + num_codebooks=1, + keep_num_codebooks_dim: bool | None = None, + scale: float | None = None, + allowed_dtypes: Tuple[torch.dtype, ...] 
= (torch.float32, torch.float64), + channel_first: bool = False, + projection_has_bias: bool = True, + return_indices=True, + force_quantization_f32=True, + ): + super().__init__() + _levels = torch.tensor(levels, dtype=int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + self.channel_first = channel_first + + has_projections = self.dim != effective_codebook_dim + self.project_in = ( + nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) + if has_projections + else nn.Identity() + ) + self.project_out = ( + nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) + if has_projections + else nn.Identity() + ) + + self.has_projections = has_projections + + self.return_indices = return_indices + if return_indices: + self.codebook_size = self._levels.prod().item() + implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) + self.register_buffer( + "implicit_codebook", implicit_codebook, persistent=False + ) + + self.allowed_dtypes = allowed_dtypes + self.force_quantization_f32 = force_quantization_f32 + + def bound(self, z, eps: float = 1e-3): + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z): + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def _indices_to_codes(self, indices): + level_indices = self.indices_to_level_indices(indices) + codes = self._scale_and_shift_inverse(level_indices) + return codes + + def codes_to_indices(self, zhat): + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis).sum(dim=-1).to(int32) + + def indices_to_level_indices(self, indices): + """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered + + def indices_to_codes(self, indices): + """Inverse of `codes_to_indices`.""" + assert exists(indices) + + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + codes = self._indices_to_codes(indices) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... 
(c d)") + + codes = self.project_out(codes) + + if is_img_or_video or self.channel_first: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + def forward(self, z): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension + c - number of codebook dim + """ + + is_img_or_video = z.ndim >= 4 + need_move_channel_last = is_img_or_video or self.channel_first + + # standardize image or video into (batch, seq, dimension) + + if need_move_channel_last: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert ( + z.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + # whether to force quantization step to be full precision or not + + force_f32 = self.force_quantization_f32 + quantization_context = ( + partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext + ) + + with quantization_context(): + orig_dtype = z.dtype + + if force_f32 and orig_dtype not in self.allowed_dtypes: + z = z.float() + + codes = self.quantize(z) + + # returning indices could be optional + + indices = None + + if self.return_indices: + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + codes = codes.type(orig_dtype) + + # project out + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if need_move_channel_last: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + + indices = maybe(unpack_one)(indices, ps, "b * c") + + if not self.keep_num_codebooks_dim and self.return_indices: + indices = maybe(rearrange)(indices, "... 1 -> ...") + + # return quantized output and indices + + return out, indices diff --git a/spark/sparktts/modules/fsq/residual_fsq.py b/spark/sparktts/modules/fsq/residual_fsq.py new file mode 100644 index 0000000..9367fc4 --- /dev/null +++ b/spark/sparktts/modules/fsq/residual_fsq.py @@ -0,0 +1,355 @@ +import random +import torch +import torch.nn.functional as F +import torch.distributed as dist + +from typing import List +from torch import nn +from torch.nn import Module +from torch.amp import autocast +from einx import get_at +from einops import rearrange, reduce, pack, unpack + +from spark.sparktts.modules.fsq.finite_scalar_quantization import FSQ + + +def exists(val): + return val is not None + + +def first(l): + return l[0] + + +def default(val, d): + return val if exists(val) else d + + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + + +# distributed helpers + + +def is_distributed(): + return dist.is_initialized() and dist.get_world_size() > 1 + + +def get_maybe_sync_seed(device, max_size=10_000): + rand_int = torch.randint(0, max_size, (), device=device) + + if is_distributed(): + dist.all_reduce(rand_int) + + return rand_int.item() + + +class ResidualFSQ(Module): + """Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__( + self, + *, + levels: List[int], + num_quantizers, + dim=None, + is_channel_first=False, + quantize_dropout=False, + quantize_dropout_cutoff_index=0, + quantize_dropout_multiple_of=1, + **kwargs, + ): + super().__init__() + codebook_dim = len(levels) + dim = default(dim, codebook_dim) + + requires_projection = codebook_dim != dim + self.project_in = ( + nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + ) + self.has_projections = requires_projection + + self.is_channel_first = is_channel_first + self.num_quantizers = num_quantizers + + self.levels = levels + self.layers = nn.ModuleList([]) + + levels_tensor = torch.Tensor(levels) + + scales = [] + + for ind in range(num_quantizers): + scales.append((levels_tensor - 1) ** -ind) + + fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs) + + self.layers.append(fsq) + + assert all([not fsq.has_projections for fsq in self.layers]) + + self.codebook_size = self.layers[0].codebook_size + + self.register_buffer("scales", torch.stack(scales), persistent=False) + + self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 + + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 + + @property + def codebooks(self): + codebooks = [layer.implicit_codebook for layer in self.layers] + codebooks = torch.stack(codebooks, dim=0) + return codebooks + + def get_codes_from_indices(self, indices): + + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # may also receive indices in the shape of 'b h w q' (accept_image_fmap) + + indices, ps = pack([indices], "b * q") + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + assert ( + self.quantize_dropout > 0.0 + ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1) + + # take care of quantizer dropout + + mask = indices == -1 + indices = indices.masked_fill( + mask, 0 + ) # have it fetch a dummy code to be masked out later + + all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices) + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0) + + # scale the codes + + scales = rearrange(self.scales, "q d -> q 1 1 d") + all_codes = all_codes * scales + + # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) + + (all_codes,) = unpack(all_codes, ps, "q b * d") + + return all_codes + + def get_output_from_indices(self, indices): + codes = self.get_codes_from_indices(indices) + codes_summed = reduce(codes, "q ... -> ...", "sum") + return self.project_out(codes_summed) + + def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): + num_quant, quant_dropout_multiple_of, device = ( + self.num_quantizers, + self.quantize_dropout_multiple_of, + x.device, + ) + + # handle channel first + + if self.is_channel_first: + x = rearrange(x, "b d ... -> b ... 
d") + x, ps = pack([x], "b * d") + + # maybe project in + + x = self.project_in(x) + + quantized_out = 0.0 + residual = x + + all_indices = [] + + should_quantize_dropout = self.training and self.quantize_dropout + + # sample a layer index at which to dropout further residual quantization + # also prepare null indices + + if should_quantize_dropout: + + # check if seed is manually passed in + + if not exists(rand_quantize_dropout_fixed_seed): + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) + + rand = random.Random(rand_quantize_dropout_fixed_seed) + + rand_quantize_dropout_index = rand.randrange( + self.quantize_dropout_cutoff_index, num_quant + ) + + if quant_dropout_multiple_of != 1: + rand_quantize_dropout_index = ( + round_up_multiple( + rand_quantize_dropout_index + 1, quant_dropout_multiple_of + ) + - 1 + ) + + null_indices = torch.full( + x.shape[:2], -1.0, device=device, dtype=torch.long + ) + + # go through the layers + + with autocast("cuda", enabled=False): + for quantizer_index, (layer, scale) in enumerate( + zip(self.layers, self.scales) + ): + + if ( + should_quantize_dropout + and quantizer_index > rand_quantize_dropout_index + ): + all_indices.append(null_indices) + continue + + quantized, indices = layer(residual / scale) + + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + all_indices.append(indices) + + # project out, if needed + + quantized_out = self.project_out(quantized_out) + + # stack all indices + + all_indices = torch.stack(all_indices, dim=-1) + + # channel first out + + if self.is_channel_first: + (quantized_out,) = unpack(quantized_out, ps, "b * d") + (all_indices,) = unpack(all_indices, ps, "b * d") + + quantized_out = rearrange(quantized_out, "b ... d -> b d ...") + all_indices = rearrange(all_indices, "b ... 
d -> b d ...") + + # return + + ret = (quantized_out, all_indices) + + if not return_all_codes: + return ret + + # whether to return all codes from all codebooks across layers + + all_codes = self.get_codes_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + + return (*ret, all_codes) + + +# grouped residual fsq + + +class GroupedResidualFSQ(Module): + def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs): + super().__init__() + self.dim = dim + self.groups = groups + assert (dim % groups) == 0 + dim_per_group = dim // groups + + self.accept_image_fmap = accept_image_fmap + + self.rvqs = nn.ModuleList([]) + + for _ in range(groups): + self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs)) + + self.codebook_size = self.rvqs[0].codebook_size + + @property + def codebooks(self): + return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) + + @property + def split_dim(self): + return 1 if self.accept_image_fmap else -1 + + def get_codes_from_indices(self, indices): + codes = tuple( + rvq.get_codes_from_indices(chunk_indices) + for rvq, chunk_indices in zip(self.rvqs, indices) + ) + return torch.stack(codes) + + def get_output_from_indices(self, indices): + outputs = tuple( + rvq.get_output_from_indices(chunk_indices) + for rvq, chunk_indices in zip(self.rvqs, indices) + ) + return torch.cat(outputs, dim=self.split_dim) + + def forward(self, x, return_all_codes=False): + shape, split_dim, device = x.shape, self.split_dim, x.device + assert shape[split_dim] == self.dim + + # split the feature dimension into groups + + x = x.chunk(self.groups, dim=split_dim) + + forward_kwargs = dict( + return_all_codes=return_all_codes, + rand_quantize_dropout_fixed_seed=( + get_maybe_sync_seed(device) if self.training else None + ), + ) + + # invoke residual vq on each group + + out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x)) + out = tuple(zip(*out)) + + # otherwise, get all the zipped outputs and combine them + + quantized, all_indices, *maybe_all_codes = out + + quantized = torch.cat(quantized, dim=split_dim) + all_indices = torch.stack(all_indices) + + ret = (quantized, all_indices, *maybe_all_codes) + return ret + + +if __name__ == "__main__": + model = ResidualFSQ( + levels=[4, 4, 4, 4, 4, 4], + num_quantizers=1, + dim=30, + is_channel_first=True, + quantize_dropout=False, + ) + x = torch.randn(2, 30, 10) + quantize, embed_ind = model(x) + + emb_from_ind = model.get_output_from_indices(embed_ind.transpose(1, 2)) + + print(quantize == emb_from_ind.transpose(1, 2)) + + print("quantize shape", quantize.shape) + print("embed_ind", embed_ind) diff --git a/spark/sparktts/modules/speaker/ecapa_tdnn.py b/spark/sparktts/modules/speaker/ecapa_tdnn.py new file mode 100644 index 0000000..e08704c --- /dev/null +++ b/spark/sparktts/modules/speaker/ecapa_tdnn.py @@ -0,0 +1,267 @@ +# Copyright (c) 2021 Zhengyang Chen (chenzhengyang117@gmail.com) +# 2022 Hongji Wang (jijijiang77@gmail.com) +# 2023 Bing Han (hanbing97@sjtu.edu.cn) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" This implementation is adapted from github repo: + https://github.com/lawlict/ECAPA-TDNN. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import spark.sparktts.modules.speaker.pooling_layers as pooling_layers + + +class Res2Conv1dReluBn(nn.Module): + """ + in_channels == out_channels == channels + """ + + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): + super().__init__() + assert channels % scale == 0, "{} % {} != 0".format(channels, scale) + self.scale = scale + self.width = channels // scale + self.nums = scale if scale == 1 else scale - 1 + + self.convs = [] + self.bns = [] + for i in range(self.nums): + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) + self.bns.append(nn.BatchNorm1d(self.width)) + self.convs = nn.ModuleList(self.convs) + self.bns = nn.ModuleList(self.bns) + + def forward(self, x): + out = [] + spx = torch.split(x, self.width, 1) + sp = spx[0] + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + # Order: conv -> relu -> bn + if i >= 1: + sp = sp + spx[i] + sp = conv(sp) + sp = bn(F.relu(sp)) + out.append(sp) + if self.scale != 1: + out.append(spx[self.nums]) + out = torch.cat(out, dim=1) + + return out + + +""" Conv1d + BatchNorm1d + ReLU +""" + + +class Conv1dReluBn(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, x): + return self.bn(F.relu(self.conv(x))) + + +""" The SE connection of 1D case. +""" + + +class SE_Connect(nn.Module): + + def __init__(self, channels, se_bottleneck_dim=128): + super().__init__() + self.linear1 = nn.Linear(channels, se_bottleneck_dim) + self.linear2 = nn.Linear(se_bottleneck_dim, channels) + + def forward(self, x): + out = x.mean(dim=2) + out = F.relu(self.linear1(out)) + out = torch.sigmoid(self.linear2(out)) + out = x * out.unsqueeze(2) + + return out + + +""" SE-Res2Block of the ECAPA-TDNN architecture. 
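+Composed of a 1x1 Conv1dReluBn, a dilated Res2Conv1dReluBn, a second 1x1 Conv1dReluBn and an
+SE_Connect block, all wrapped in a residual connection.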
+""" + + +class SE_Res2Block(nn.Module): + + def __init__(self, channels, kernel_size, stride, padding, dilation, scale): + super().__init__() + self.se_res2block = nn.Sequential( + Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), + Res2Conv1dReluBn( + channels, kernel_size, stride, padding, dilation, scale=scale + ), + Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), + SE_Connect(channels), + ) + + def forward(self, x): + return x + self.se_res2block(x) + + +class ECAPA_TDNN(nn.Module): + + def __init__( + self, + channels=512, + feat_dim=80, + embed_dim=192, + pooling_func="ASTP", + global_context_att=False, + emb_bn=False, + ): + super().__init__() + + self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) + self.layer2 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 + ) + self.layer3 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 + ) + self.layer4 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 + ) + + cat_channels = channels * 3 + out_channels = 512 * 3 + self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1) + self.pool = getattr(pooling_layers, pooling_func)( + in_dim=out_channels, global_context_att=global_context_att + ) + self.pool_out_dim = self.pool.get_out_dim() + self.bn = nn.BatchNorm1d(self.pool_out_dim) + self.linear = nn.Linear(self.pool_out_dim, embed_dim) + self.emb_bn = emb_bn + if emb_bn: # better in SSL for SV + self.bn2 = nn.BatchNorm1d(embed_dim) + else: + self.bn2 = nn.Identity() + + def forward(self, x, return_latent=False): + x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) + + out1 = self.layer1(x) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + + out = torch.cat([out2, out3, out4], dim=1) + latent = F.relu(self.conv(out)) + out = self.bn(self.pool(latent)) + out = self.linear(out) + if self.emb_bn: + out = self.bn2(out) + + if return_latent: + return out, latent + return out + + +def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=1024, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=1024, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=512, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=512, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + emb_bn=emb_bn, + ) + + +if __name__ == "__main__": + x = torch.zeros(1, 200, 100) + model = ECAPA_TDNN_GLOB_c512(feat_dim=100, embed_dim=256, pooling_func="ASTP") + model.eval() + out, latent = model(x, True) + print(out.shape) + print(latent.shape) + + num_params = sum(param.numel() for param in model.parameters()) + print("{} M".format(num_params / 1e6)) + + # from thop import profile + # x_np = torch.randn(1, 200, 80) + # flops, params = profile(model, inputs=(x_np, )) + # print("FLOPs: {} G, Params: {} M".format(flops / 1e9, params / 1e6)) diff --git 
a/spark/sparktts/modules/speaker/perceiver_encoder.py b/spark/sparktts/modules/speaker/perceiver_encoder.py new file mode 100644 index 0000000..bed1bb1 --- /dev/null +++ b/spark/sparktts/modules/speaker/perceiver_encoder.py @@ -0,0 +1,360 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532 + +from collections import namedtuple +from functools import wraps + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from packaging import version +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, causal=False, use_flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.register_buffer("mask", None, persistent=False) + + self.use_flash = use_flash + assert not ( + use_flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + self.config = namedtuple( + "EfficientAttentionConfig", + ["enable_flash", "enable_math", "enable_mem_efficient"], + ) + self.cpu_config = self.config(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once( + "A100 GPU detected, using flash attention if input tensor is on cuda" + ) + self.cuda_config = self.config(True, False, False) + else: + print_once( + "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" + ) + self.cuda_config = self.config(False, True, True) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) + + if v.ndim == 3: + v = rearrange(v, "b ... 
-> b 1 ...").expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.causal, + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash: + return self.flash_attn(q, k, v, mask=mask) + + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +def Sequential(*mods): + return nn.Sequential(*filter(exists, mods)) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class RMSNorm(nn.Module): + def __init__(self, dim, scale=True, dim_cond=None): + super().__init__() + self.cond = exists(dim_cond) + self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None + + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) if scale else None + + def forward(self, x, cond=None): + gamma = default(self.gamma, 1) + out = F.normalize(x, dim=-1) * self.scale * gamma + + if not self.cond: + return out + + assert exists(cond) + gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) + gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) + return out * gamma + beta + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (kernel_size,) = self.kernel_size + (dilation,) = self.dilation + (stride,) = self.stride + + assert stride == 1 + self.causal_padding = dilation * (kernel_size - 1) + + def forward(self, x): + causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) + return super().forward(causal_padded_x) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def FeedForward(dim, mult=4, causal_conv=False): + dim_inner = int(dim * mult * 2 / 3) + + conv = None + if causal_conv: + conv = nn.Sequential( + Rearrange("b n d -> b d n"), + CausalConv1d(dim_inner, dim_inner, 3), + Rearrange("b d n -> b n d"), + ) + + return Sequential( + nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim) + ) + + +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_context=None, + causal=False, + dim_head=64, + heads=8, + dropout=0.0, + use_flash=False, + 
cross_attn_include_queries=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.cross_attn_include_queries = cross_attn_include_queries + + dim_inner = dim_head * heads + dim_context = default(dim_context, dim) + + self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) + self.to_q = nn.Linear(dim, dim_inner, bias=False) + self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) + self.to_out = nn.Linear(dim_inner, dim, bias=False) + + def forward(self, x, context=None, mask=None): + h, has_context = self.heads, exists(context) + + context = default(context, x) + + if has_context and self.cross_attn_include_queries: + context = torch.cat((x, context), dim=-2) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=2, + dim_context=None, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ): + super().__init__() + dim_context = default(dim_context, dim) + + self.proj_context = ( + nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() + ) + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + nn.init.normal_(self.latents, std=0.02) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + use_flash=use_flash_attn, + cross_attn_include_queries=True, + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = RMSNorm(dim) + + def forward(self, x, mask=None): + batch = x.shape[0] + + x = self.proj_context(x) + + latents = repeat(self.latents, "n d -> b n d", b=batch) + + for attn, ff in self.layers: + latents = attn(latents, x, mask=mask) + latents + latents = ff(latents) + latents + + return self.norm(latents) + + +if __name__ == "__main__": + model = PerceiverResampler(dim=256, dim_context=80) + x = torch.randn(8, 200, 80) + out = model(x) + print(out.shape) # [8, 32, 80] + + num_params = sum(param.numel() for param in model.parameters()) + print("{} M".format(num_params / 1e6)) diff --git a/spark/sparktts/modules/speaker/pooling_layers.py b/spark/sparktts/modules/speaker/pooling_layers.py new file mode 100644 index 0000000..fe82898 --- /dev/null +++ b/spark/sparktts/modules/speaker/pooling_layers.py @@ -0,0 +1,298 @@ +# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Pooling functions to aggregate frame-level deep features +into segment-level speaker embeddings + +High-order statistics are surprisingly effective, TSDP acts similarly as TSTP, +even though we remove the mean statistic, on Voxceleb. 
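+Implemented poolings: TAP, TSDP, TSTP, ASTP, MHASTP and MQMHASTP.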
+""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TAP(nn.Module): + """ + Temporal average pooling, only first-order mean is considered + """ + + def __init__(self, in_dim=0, **kwargs): + super(TAP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + pooling_mean = x.mean(dim=-1) + # To be compatable with 2D input + pooling_mean = pooling_mean.flatten(start_dim=1) + return pooling_mean + + def get_out_dim(self): + self.out_dim = self.in_dim + return self.out_dim + + +class TSDP(nn.Module): + """ + Temporal standard deviation pooling, only second-order std is considered + """ + + def __init__(self, in_dim=0, **kwargs): + super(TSDP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + # The last dimension is the temporal axis + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) + pooling_std = pooling_std.flatten(start_dim=1) + return pooling_std + + def get_out_dim(self): + self.out_dim = self.in_dim + return self.out_dim + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, in_dim=0, **kwargs): + super(TSTP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + def get_out_dim(self): + self.out_dim = self.in_dim * 2 + return self.out_dim + + +class ASTP(nn.Module): + """ Attentive statistics pooling: Channel- and context-dependent + statistics pooling, first used in ECAPA_TDNN. + """ + + def __init__(self, + in_dim, + bottleneck_dim=128, + global_context_att=False, + **kwargs): + super(ASTP, self).__init__() + self.in_dim = in_dim + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't + # need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, + kernel_size=1) # equals V and k in the paper + + def forward(self, x): + """ + x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) + or a 4-dimensional tensor in resnet architecture (B,C,F,T) + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(x.shape) == 4: + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) + assert len(x.shape) == 3 + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! ReLU may be hard to converge. 
+ alpha = torch.tanh( + self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + var = torch.sum(alpha * (x**2), dim=2) - mean**2 + std = torch.sqrt(var.clamp(min=1e-7)) + return torch.cat([mean, std], dim=1) + + def get_out_dim(self): + self.out_dim = 2 * self.in_dim + return self.out_dim + + +class MHASTP(torch.nn.Module): + """ Multi head attentive statistics pooling + Reference: + Self Multi-Head Attention for Speaker Recognition + https://arxiv.org/pdf/1906.09890.pdf + """ + + def __init__(self, + in_dim, + layer_num=2, + head_num=2, + d_s=1, + bottleneck_dim=64, + **kwargs): + super(MHASTP, self).__init__() + assert (in_dim % head_num + ) == 0 # make sure that head num can be divided by input_dim + self.in_dim = in_dim + self.head_num = head_num + d_model = int(in_dim / head_num) + channel_dims = [bottleneck_dim for i in range(layer_num + 1)] + if d_s > 1: + d_s = d_model + else: + d_s = 1 + self.d_s = d_s + channel_dims[0], channel_dims[-1] = d_model, d_s + heads_att_trans = [] + for i in range(self.head_num): + att_trans = nn.Sequential() + for i in range(layer_num - 1): + att_trans.add_module( + 'att_' + str(i), + nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1)) + att_trans.add_module('tanh' + str(i), nn.Tanh()) + att_trans.add_module( + 'att_' + str(layer_num - 1), + nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], + 1, 1)) + heads_att_trans.append(att_trans) + self.heads_att_trans = nn.ModuleList(heads_att_trans) + + def forward(self, input): + """ + input: a 3-dimensional tensor in xvector architecture + or a 4-dimensional tensor in resnet architecture + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(input.shape) == 4: # B x F x T + input = input.reshape(input.shape[0], + input.shape[1] * input.shape[2], + input.shape[3]) + assert len(input.shape) == 3 + bs, f_dim, t_dim = input.shape + chunks = torch.chunk(input, self.head_num, 1) + # split + chunks_out = [] + # for i in range(self.head_num): + # att_score = self.heads_att_trans[i](chunks[i]) + for i, layer in enumerate(self.heads_att_trans): + att_score = layer(chunks[i]) + alpha = F.softmax(att_score, dim=-1) + mean = torch.sum(alpha * chunks[i], dim=2) + var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 + std = torch.sqrt(var.clamp(min=1e-7)) + chunks_out.append(torch.cat((mean, std), dim=1)) + out = torch.cat(chunks_out, dim=1) + return out + + def get_out_dim(self): + self.out_dim = 2 * self.in_dim + return self.out_dim + + +class MQMHASTP(torch.nn.Module): + """ An attentive pooling + Reference: + multi query multi head attentive statistics pooling + https://arxiv.org/pdf/2110.05042.pdf + Args: + in_dim: the feature dimension of input + layer_num: the number of layer in the pooling layer + query_num: the number of querys + head_num: the number of heads + bottleneck_dim: the bottleneck dimension + + SA (H = 1, Q = 1, n = 2, d_s = 1) ref: + https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf + MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: + https://arxiv.org/pdf/1906.09890.pdf + AS (H = 1, Q > 1, n = 2, d_s = 1) ref: + https://arxiv.org/pdf/1803.10963.pdf + VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: + http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf + """ + + def __init__(self, + in_dim, + layer_num=2, + query_num=2, + head_num=8, + d_s=2, + bottleneck_dim=64, + **kwargs): + super(MQMHASTP, self).__init__() + self.n_query = nn.ModuleList([ + 
MHASTP(in_dim, + layer_num=layer_num, + head_num=head_num, + d_s=d_s, + bottleneck_dim=bottleneck_dim) for i in range(query_num) + ]) + self.query_num = query_num + self.in_dim = in_dim + + def forward(self, input): + """ + input: a 3-dimensional tensor in xvector architecture + or a 4-dimensional tensor in resnet architecture + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(input.shape) == 4: # B x F x T + input = input.reshape(input.shape[0], + input.shape[1] * input.shape[2], + input.shape[3]) + assert len(input.shape) == 3 + res = [] + for i, layer in enumerate(self.n_query): + res.append(layer(input)) + out = torch.cat(res, dim=-1) + return out + + def get_out_dim(self): + self.out_dim = self.in_dim * 2 * self.query_num + return self.out_dim + + +if __name__ == '__main__': + data = torch.randn(16, 512, 10, 35) + # model = StatisticsPooling() + model = MQMHASTP(512 * 10) + model = MHASTP(512 * 10) + model = MQMHASTP(512 * 10, context=False) + print(model) + + out = model(data) + print(out.shape) + print(model.get_out_dim()) \ No newline at end of file diff --git a/spark/sparktts/modules/speaker/speaker_encoder.py b/spark/sparktts/modules/speaker/speaker_encoder.py new file mode 100644 index 0000000..e39cae0 --- /dev/null +++ b/spark/sparktts/modules/speaker/speaker_encoder.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
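+
+# Speaker encoder: ECAPA-TDNN front-end, Perceiver resampler, and residual FSQ quantizer,
+# producing an x-vector and a quantized d-vector / speaker tokens.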
+ +import torch +import torch.nn as nn + +from typing import List, Tuple +from spark.sparktts.modules.fsq.residual_fsq import ResidualFSQ +from spark.sparktts.modules.speaker.ecapa_tdnn import ECAPA_TDNN_GLOB_c512 +from spark.sparktts.modules.speaker.perceiver_encoder import PerceiverResampler + +""" +x-vector + d-vector +""" + + +class SpeakerEncoder(nn.Module): + """ + + Args: + input_dim (int): acoustic feature dimension + out_dim (int): output dimension of x-vector and d-vector + latent_dim (int): latent dimension before quantization + token_num (int): sequence length of speaker tokens + fsq_levels (List[int]): number of levels for each quantizer + fsq_num_quantizers (int): number of quantizers + + Return: + speaker_embs: (B, T2, out_dim) + """ + + def __init__( + self, + input_dim: int = 100, + out_dim: int = 512, + latent_dim: int = 128, + token_num: int = 32, + fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], + fsq_num_quantizers: int = 1, + ): + super(SpeakerEncoder, self).__init__() + + self.speaker_encoder = ECAPA_TDNN_GLOB_c512( + feat_dim=input_dim, embed_dim=out_dim + ) + self.perceiver_sampler = PerceiverResampler( + dim=latent_dim, dim_context=512 * 3, num_latents=token_num + ) + self.quantizer = ResidualFSQ( + levels=fsq_levels, + num_quantizers=fsq_num_quantizers, + dim=latent_dim, + is_channel_first=True, + quantize_dropout=False, + ) + + self.project = nn.Linear(latent_dim * token_num, out_dim) + + def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: + zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2)) + return zq.transpose(1, 2) + + def get_indices(self, mels: torch.Tensor) -> torch.Tensor: + mels = mels.transpose(1, 2) + x = self.perceiver_sampler(mels).transpose(1, 2) + zq, indices = self.quantizer(x) + return indices + + def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + mels: (B, D_mel, T1) + + Return: + x_vector: (B, out_dim) + d_vector: (B, out_dim) + """ + # mels = mels.transpose(1,2) + + x_vector, features = self.speaker_encoder(mels, True) + x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) + zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim) + x = zq.reshape(zq.shape[0], -1) + d_vector = self.project(x) + + return x_vector, d_vector + + def tokenize(self, mels: torch.Tensor) -> torch.Tensor: + """tokenize the input mel spectrogram""" + _, features = self.speaker_encoder(mels, True) + x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) + zq, indices = self.quantizer(x) + return indices + + def detokenize(self, indices: torch.Tensor) -> torch.Tensor: + """detokenize the input indices to d-vector""" + zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2) + x = zq.reshape(zq.shape[0], -1) + d_vector = self.project(x) + return d_vector + +if __name__ == "__main__": + model = SpeakerEncoder( + input_dim=100, + latent_dim=128, + token_num=32, + fsq_levels=[4, 4, 4, 4, 4, 4], + fsq_num_quantizers=1, + ) + mel = torch.randn(8, 200, 100) + x_vector, d_vector = model(mel) + print("x-vector shape", x_vector.shape) + print("d-vector shape", d_vector.shape) + + indices = model.tokenize(mel) + print("indices shape", indices.shape) + d_vector_post = model.detokenize(indices) + print("d-vector shape", d_vector_post.shape) + if d_vector_post.all() == d_vector.all(): + print("d-vector post and d-vector are the same") + else: + print("d-vector post and d-vector are different") + num_params = sum(param.numel() for param in 
model.parameters()) + print("{} M".format(num_params / 1e6)) \ No newline at end of file diff --git a/spark/sparktts/modules/vq/factorized_vector_quantize.py b/spark/sparktts/modules/vq/factorized_vector_quantize.py new file mode 100644 index 0000000..f820bc7 --- /dev/null +++ b/spark/sparktts/modules/vq/factorized_vector_quantize.py @@ -0,0 +1,187 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch + + +from typing import Any, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +class FactorizedVectorQuantize(nn.Module): + def __init__( + self, + input_dim: int, + codebook_size: int, + codebook_dim: int, + commitment: float, + codebook_loss_weight: float = 1.0, + decay: float = 0.99, + threshold_ema_dead_code: float = 2, + momentum: float = 0.99, + **kwargs, + ): + super().__init__() + self.input_dim = input_dim + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + self.codebook_loss_weight = codebook_loss_weight + self.decay = decay + self.threshold_ema_dead_code = threshold_ema_dead_code + self.momentum = momentum + + if input_dim != self.codebook_dim: + self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) + self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) + + else: + self.in_project = nn.Identity() + self.out_project = nn.Identity() + + self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) + self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) + + def forward(self, z: torch.Tensor) -> Dict[str, Any]: + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim + z_e = self.in_project(z) + z_q, indices, dists = self.decode_latents(z_e) + + # statistic the usage of codes + embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) + avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + 
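+        # codebook-usage statistics: count codes hit in this batch; during training, an EMA of
+        # per-code counts (cluster_size) is updated and codes above threshold_ema_dead_code count as active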
+ active_num = (embed_onehot.sum(0).sum(0) > 0).sum() + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay) + active_num = sum(self.cluster_size > self.threshold_ema_dead_code) + + if self.training: + commit_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + + codebook_loss = ( + F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + * self.codebook_loss_weight + ) + + else: + commit_loss = torch.zeros(0, device=z.device) + codebook_loss = torch.zeros(0, device=z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_project(z_q) + + vq_loss = (commit_loss + codebook_loss).mean() + + return { + "z_q": z_q, + "indices": indices, + "dists": dists, + "vq_loss": vq_loss, + "perplexity": perplexity, + "active_num": active_num.float(), + } + + def vq2emb(self, vq, out_proj=True): + emb = self.embed_code(vq) + if out_proj: + emb = self.out_project(emb) + return emb + + def tokenize(self, z: torch.Tensor) -> torch.Tensor: + """tokenize the input tensor""" + z_e = self.in_project(z) + _, indices, _ = self.decode_latents(z_e) + return indices + + def detokenize(self, indices): + """detokenize the input indices""" + z_q = self.decode_code(indices) + z_q = self.out_project(z_q) + return z_q + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight + + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance between encodings and codebook, + # with L2 normalization, the distance is equal to cosine distance + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + return z_q, indices, dist diff --git a/spark/sparktts/utils/__init__.py b/spark/sparktts/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/spark/sparktts/utils/audio.py b/spark/sparktts/utils/audio.py new file mode 100644 index 0000000..105cd9c --- /dev/null +++ b/spark/sparktts/utils/audio.py @@ -0,0 +1,271 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Description: + This script contains a collection of functions designed to handle various + audio processing. 
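+    This includes volume normalization, loading and resampling, random segment selection,
+    high-pass filtering, STFT, speech-boundary detection, silence trimming, and Hz-to-mel
+    conversion.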
+""" + +import random +import soxr +import soundfile +import torch +import torchaudio +import numpy as np + +from pathlib import Path +from typing import Tuple +from numpy.lib.stride_tricks import sliding_window_view + + +def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: + """ + Normalize the volume of an audio signal. + + Parameters: + audio (numpy array): Input audio signal array. + coeff (float): Target coefficient for normalization, default is 0.2. + + Returns: + numpy array: The volume-normalized audio signal. + """ + # Sort the absolute values of the audio signal + temp = np.sort(np.abs(audio)) + + # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1 + if temp[-1] < 0.1: + scaling_factor = max( + temp[-1], 1e-3 + ) # Prevent division by zero with a small constant + audio = audio / scaling_factor * 0.1 + + # Filter out values less than 0.01 from temp + temp = temp[temp > 0.01] + L = temp.shape[0] # Length of the filtered array + + # If there are fewer than or equal to 10 significant values, return the audio without further processing + if L <= 10: + return audio + + # Compute the average of the top 10% to 1% of values in temp + volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) + + # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10 + audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) + + # Ensure the maximum absolute value in the audio does not exceed 1 + max_value = np.max(np.abs(audio)) + if max_value > 1: + audio = audio / max_value + + return audio + + +def load_audio( + adfile: Path, + sampling_rate: int = None, + length: int = None, + volume_normalize: bool = False, + segment_duration: int = None, +) -> np.ndarray: + r"""Load audio file with target sampling rate and lsength + + Args: + adfile (Path): path to audio file. + sampling_rate (int, optional): target sampling rate. Defaults to None. + length (int, optional): target audio length. Defaults to None. + volume_normalize (bool, optional): whether perform volume normalization. Defaults to False. + segment_duration (int): random select a segment with duration of {segment_duration}s. + Defualt to None which means the whole audio will be used. 
+ + Returns: + audio (np.ndarray): audio + """ + + audio, sr = soundfile.read(adfile) + if len(audio.shape) > 1: + audio = audio[:, 0] + + if sampling_rate is not None and sr != sampling_rate: + audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") + sr = sampling_rate + + if segment_duration is not None: + seg_length = int(sr * segment_duration) + audio = random_select_audio_segment(audio, seg_length) + + # Audio volume normalize + if volume_normalize: + audio = audio_volume_normalize(audio) + # check the audio length + if length is not None: + assert abs(audio.shape[0] - length) < 1000 + if audio.shape[0] > length: + audio = audio[:length] + else: + audio = np.pad(audio, (0, int(length - audio.shape[0]))) + return audio + + +def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: + """get an audio segment given the length + + Args: + audio (np.ndarray): + length (int): audio length = sampling_rate * duration + """ + if audio.shape[0] < length: + audio = np.pad(audio, (0, int(length - audio.shape[0]))) + start_index = random.randint(0, audio.shape[0] - length) + end_index = int(start_index + length) + + return audio[start_index:end_index] + + +def audio_highpass_filter(audio, sample_rate, highpass_cutoff_freq): + """apply highpass fileter to audio + + Args: + audio (np.ndarray): + sample_rate (ind): + highpass_cutoff_freq (int): + """ + + audio = torchaudio.functional.highpass_biquad( + torch.from_numpy(audio), sample_rate, cutoff_freq=highpass_cutoff_freq + ) + return audio.numpy() + + +def stft( + x: torch.Tensor, + fft_size: int, + hop_size: int, + win_length: int, + window: str, + use_complex: bool = False, +) -> torch.Tensor: + """Perform STFT and convert to magnitude spectrogram. + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + """ + + x_stft = torch.stft( + x, fft_size, hop_size, win_length, window.to(x.device), return_complex=True + ) + + # clamp is needed to avoid nan or inf + if not use_complex: + return torch.sqrt( + torch.clamp(x_stft.real**2 + x_stft.imag**2, min=1e-7, max=1e3) + ).transpose(2, 1) + else: + res = torch.cat([x_stft.real.unsqueeze(1), x_stft.imag.unsqueeze(1)], dim=1) + res = res.transpose(2, 3) # [B, 2, T, F] + return res + + +def detect_speech_boundaries( + wav: np.ndarray, + sample_rate: int, + window_duration: float = 0.1, + energy_threshold: float = 0.01, + margin_factor: int = 2 +) -> Tuple[int, int]: + """Detect the start and end points of speech in an audio signal using RMS energy. 
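+    The waveform is scanned with sliding RMS-energy windows; the first and last windows whose
+    energy exceeds the threshold bound the speech region, extended by an extra margin.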
+ + Args: + wav: Input audio signal array with values in [-1, 1] + sample_rate: Audio sample rate in Hz + window_duration: Duration of detection window in seconds + energy_threshold: RMS energy threshold for speech detection + margin_factor: Factor to determine extra margin around detected boundaries + + Returns: + tuple: (start_index, end_index) of speech segment + + Raises: + ValueError: If the audio contains only silence + """ + window_size = int(window_duration * sample_rate) + margin = margin_factor * window_size + step_size = window_size // 10 + + # Create sliding windows using stride tricks to avoid loops + windows = sliding_window_view(wav, window_size)[::step_size] + + # Calculate RMS energy for each window + energy = np.sqrt(np.mean(windows ** 2, axis=1)) + speech_mask = energy >= energy_threshold + + if not np.any(speech_mask): + raise ValueError("No speech detected in audio (only silence)") + + start = max(0, np.argmax(speech_mask) * step_size - margin) + end = min(len(wav), (len(speech_mask) - 1 - np.argmax(speech_mask[::-1])) * step_size + margin) + + return start, end + + +def remove_silence_on_both_ends( + wav: np.ndarray, + sample_rate: int, + window_duration: float = 0.1, + volume_threshold: float = 0.01 +) -> np.ndarray: + """Remove silence from both ends of an audio signal. + + Args: + wav: Input audio signal array + sample_rate: Audio sample rate in Hz + window_duration: Duration of detection window in seconds + volume_threshold: Amplitude threshold for silence detection + + Returns: + np.ndarray: Audio signal with silence removed from both ends + + Raises: + ValueError: If the audio contains only silence + """ + start, end = detect_speech_boundaries( + wav, + sample_rate, + window_duration, + volume_threshold + ) + return wav[start:end] + + + +def hertz_to_mel(pitch: float) -> float: + """ + Converts a frequency from the Hertz scale to the Mel scale. + + Parameters: + - pitch: float or ndarray + Frequency in Hertz. + + Returns: + - mel: float or ndarray + Frequency in Mel scale. + """ + mel = 2595 * np.log10(1 + pitch / 700) + return mel \ No newline at end of file diff --git a/spark/sparktts/utils/file.py b/spark/sparktts/utils/file.py new file mode 100644 index 0000000..bcc7c7c --- /dev/null +++ b/spark/sparktts/utils/file.py @@ -0,0 +1,221 @@ +# Copyright (c) 2025 SparkAudio +# 2025 Xinsheng Wang (w.xinshawn@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Description: + This script contains a collection of functions designed to handle various + file reading and writing operations. It provides utilities to read from files, + write data to files, and perform file manipulation tasks. +""" + + +import os +import json +import json +import csv + +from tqdm import tqdm +from typing import List, Dict, Any, Set, Union +from pathlib import Path +from omegaconf import OmegaConf, DictConfig + + +def resolve_symbolic_link(symbolic_link_path: Path) -> Path: + """ + Resolves the absolute path of a symbolic link. 
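+    The link target is read with os.readlink() and joined with the directory containing the link.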
+ + Args: + symbolic_link_path (Path): The path to the symbolic link. + + Returns: + Path: The absolute path that the symbolic link points to. + """ + + link_directory = os.path.dirname(symbolic_link_path) + target_path_relative = os.readlink(symbolic_link_path) + return os.path.join(link_directory, target_path_relative) + + +def write_jsonl(metadata: List[dict], file_path: Path) -> None: + """Writes a list of dictionaries to a JSONL file. + + Args: + metadata : List[dict] + A list of dictionaries, each representing a piece of meta. + file_path : Path + The file path to save the JSONL file + + This function writes each dictionary in the list to a new line in the specified file. + """ + with open(file_path, "w", encoding="utf-8") as f: + for meta in tqdm(metadata, desc="writing jsonl"): + # Convert dictionary to JSON string and write it to the file with a newline + json_str = json.dumps(meta, ensure_ascii=False) + "\n" + f.write(json_str) + print(f"jsonl saved to {file_path}") + + +def read_jsonl(file_path: Path) -> List[dict]: + """ + Reads a JSONL file and returns a list of dictionaries. + + Args: + file_path : Path + The path to the JSONL file to be read. + + Returns: + List[dict] + A list of dictionaries parsed from each line of the JSONL file. + """ + metadata = [] + # Open the file for reading + with open(file_path, "r", encoding="utf-8") as f: + # Split the file into lines + lines = f.read().splitlines() + # Process each line + for line in lines: + # Convert JSON string back to dictionary and append to list + meta = json.loads(line) + metadata.append(meta) + # Return the list of metadata + return metadata + +def read_json_as_jsonl(file_path: Path) -> List[dict]: + metadata = [] + with open(file_path, 'r', encoding='utf-8') as infile: + data = json.load(infile) + for k in sorted(data.keys()): + meta = {'index': k} + meta.update(data[k]) + metadata.append(meta) + return metadata + + + +def decode_unicode_strings(meta: Dict[str, Any]) -> Dict[str, Any]: + processed_meta = {} + for k, v in meta.items(): + if isinstance(v, str): + processed_meta[k] = v.encode("utf-8").decode("unicode_escape") + else: + processed_meta[k] = v + return processed_meta + + +def load_config(config_path: Path) -> DictConfig: + """Loads a configuration file and optionally merges it with a base configuration. + + Args: + config_path (Path): Path to the configuration file. + """ + # Load the initial configuration from the given path + config = OmegaConf.load(config_path) + + # Check if there is a base configuration specified and merge if necessary + if config.get("base_config", None) is not None: + base_config = OmegaConf.load(config["base_config"]) + config = OmegaConf.merge(base_config, config) + + return config + + + +def jsonl_to_csv(jsonl_file_path: str, csv_file_path: str) -> None: + """ + Converts a JSONL file to a CSV file. + + This function reads a JSONL file, determines all unique keys present in the file, + and writes the data to a CSV file with columns for all these keys. 
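+
+    Args:
+        jsonl_file_path (str): Path to the input JSONL file.
+        csv_file_path (str): Path of the CSV file to create.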
+ """ + + all_keys = set() + data_rows = [] + + # Read the JSONL file once to extract keys and collect data + with open(jsonl_file_path, 'r') as file: + for line in file: + data = json.loads(line.strip()) + data_rows.append(data) + all_keys.update(data.keys()) + + # Convert the set of keys to a sorted list for consistent column order + sorted_keys = sorted(all_keys) + + # Write the data to a CSV file + with open(csv_file_path, 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=sorted_keys) + + # Write the header row + writer.writeheader() + + # Write each row of data + for data in data_rows: + writer.writerow(data) + + print(f"CSV file has been created at {csv_file_path}") + + +def save_metadata(data, filename, headers=None): + """ + Save metadata to a file. + + Args: + data (list of dict): Metadata to be saved. + filename (str): Name of the file to save the metadata. + headers (list of str): The order of column names to be saved; defaults to the keys from the first dictionary in data if not provided. + """ + # Set headers to keys from the first dictionary in data if not explicitly provided + if headers is None: + headers = list(data[0].keys()) + + with open(filename, "w", encoding="utf-8") as file: + # Write the headers to the file + file.write("|".join(headers) + "\n") + for entry in data: + # Retrieve values in the order of headers, replacing any '|' characters with a space to prevent formatting errors + formatted_values = [str(entry.get(key, "")).replace("|", " ") for key in headers] + # Write the formatted values to the file + file.write("|".join(formatted_values) + "\n") + + +def read_metadata(filename, headers=None): + """ + Read metadata from a file. + + Args: + filename (str): The file from which to read the metadata. + + Returns: + list of dict: The metadata read from the file. + list of str: The headers used in the file. + """ + with open(filename, "r", encoding="utf-8") as file: + lines = file.readlines() + + data = [] + # Set headers from the first line of the file if not provided + if headers is None: + headers = lines[0].strip().split("|") + lines = lines[1:] + + for line in lines: + line = line.strip() + # Skip empty lines + if not line: + continue + # Split the line by '|' and pair with headers to form a dictionary + entry_data = dict(zip(headers, line.split("|"))) + data.append(entry_data) + + return data, headers diff --git a/spark/sparktts/utils/parse_options.sh b/spark/sparktts/utils/parse_options.sh new file mode 100644 index 0000000..025b5e3 --- /dev/null +++ b/spark/sparktts/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." 
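+# (e.g. passing "--stage 2" sets the shell variable "stage" to "2"; the variable must
+# already have a default value defined in the calling script.)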
+# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +# for ((argpos=1; argpos<$#; argpos++)); do +# if [ "${!argpos}" == "--config" ]; then +# argpos_plus1=$((argpos+1)) +# config=${!argpos_plus1} +# [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 +# . $config # source the config file. +# fi +# done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. 
\ No newline at end of file diff --git a/spark/sparktts/utils/token_parser.py b/spark/sparktts/utils/token_parser.py new file mode 100644 index 0000000..cc43782 --- /dev/null +++ b/spark/sparktts/utils/token_parser.py @@ -0,0 +1,187 @@ +TASK_TOKEN_MAP = { + "vc": "<|task_vc|>", + "tts": "<|task_tts|>", + "asr": "<|task_asr|>", + "s2s": "<|task_s2s|>", + "t2s": "<|task_t2s|>", + "understand": "<|task_understand|>", + "caption": "<|task_cap|>", + "controllable_tts": "<|task_controllable_tts|>", + "prompt_tts": "<|task_prompt_tts|>", + "speech_edit": "<|task_edit|>", +} + +LEVELS_MAP = { + "very_low": 0, + "low": 1, + "moderate": 2, + "high": 3, + "very_high": 4, +} + +LEVELS_MAP_UI = { + 1: 'very_low', + 2: 'low', + 3: 'moderate', + 4: 'high', + 5: 'very_high' +} + +GENDER_MAP = { + "female": 0, + "male": 1, +} + +AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4} + +EMO_MAP = { + "UNKNOWN": 0, + "NEUTRAL": 1, + "ANGRY": 2, + "HAPPY": 3, + "SAD": 4, + "FEARFUL": 5, + "DISGUSTED": 6, + "SURPRISED": 7, + "SARCASTIC": 8, + "EXCITED": 9, + "SLEEPY": 10, + "CONFUSED": 11, + "EMPHASIS": 12, + "LAUGHING": 13, + "SINGING": 14, + "WORRIED": 15, + "WHISPER": 16, + "ANXIOUS": 17, + "NO-AGREEMENT": 18, + "APOLOGETIC": 19, + "CONCERNED": 20, + "ENUNCIATED": 21, + "ASSERTIVE": 22, + "ENCOURAGING": 23, + "CONTEMPT": 24, +} + + +class TokenParser: + """Turn label to special token""" + + def __init__(self): + pass + + """Parse the attributes of a person.""" + + def __init__(self): + pass + + @staticmethod + def age(age: str) -> str: + """Turn age token.""" + age_id = AGE_MAP[age] + return f"<|age_{age_id}|>" + + @staticmethod + def gender(gender: str) -> str: + """Turn gender token.""" + gender_id = GENDER_MAP[gender] + return f"<|gender_{gender_id}|>" + + @staticmethod + def mel_value(mel: int): + """Turn special token of mel scale pitch.""" + mel = max(0, int(mel)) + mel = min(1000, int(mel)) + return f"<|pitch_value_{mel}|>" + + @staticmethod + def mel_level(level: str): + """Turn special token of mel level.""" + level_tag = LEVELS_MAP[level] + return f"<|pitch_label_{level_tag}|>" + + @staticmethod + def pitch_var_value(pitch_std: int): + """Turn special token of pitch_std value.""" + assert isinstance(pitch_std, int) + pitch_std = max(0, int(pitch_std)) + pitch_std = min(10, int(pitch_std)) + return f"<|pitch_var_value_{pitch_std}|>" + + @staticmethod + def pitch_var_level(level: str): + """Turn special token of pitch std level.""" + level_tag = LEVELS_MAP[level] + return f"<|pitch_var_label_{level_tag}|>" + + @staticmethod + def loudness_value(loudness: int): + """Turn special toak of loudness value [0, 30]""" + assert loudness >= 0 + loudness = max(0, int(loudness)) + loudness = min(30, int(loudness)) + return f"<|loudness_value_{loudness}|>" + + @staticmethod + def loudness_level(level: str): + """Turn special token of loudness level.""" + level_tag = LEVELS_MAP[level] + return f"<|loudness_label_{level_tag}|>" + + @staticmethod + def speed_value(speed: int): + """Turn special token of speed value.""" + speed = max(0, int(speed)) + speed = min(10, int(speed)) + return f"<|speed_value_{speed}|>" + + @staticmethod + def speed_level(level: str): + """Turn special token of speed level.""" + level_tag = LEVELS_MAP[level] + return f"<|speed_label_{level_tag}|>" + + @staticmethod + def task(task: str) -> str: + """Turn special token of task.""" + assert task in TASK_TOKEN_MAP.keys() + + return TASK_TOKEN_MAP[task] + + @staticmethod + def emotion(emotion: str): + emo_id = 
EMO_MAP[emotion] + + return f"<|emotion_{emo_id}|>" + + +# test +if __name__ == "__main__": + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained( + "/aifs4su/xinshengwang/code/StyleCraft/tokenizer/stylecraft-bicodec-pitch-loudness-speed-emotion-tokenizer" + ) + + tasks = ["tts", "tts", "understand", "controllable_tts", "prompt_tts"] + ages = ["Child", "Teenager", "Youth-Adult", "Middle-aged", "Elderly"] + genders = ["female", "female", "female", "male", "male"] + mels = [100, 200, 300, 400, 500] + mel_levels = ["very_low", "low", "moderate", "high", "very_high"] + loudnesses = [1, 10, 23, 19, 30] + loudness_levels = ["very_low", "low", "moderate", "high", "very_high"] + emotions = ["UNKNOWN", "NEUTRAL", "ANGRY", "HAPPY", "SAD"] + + for i in range(5): + task = TokenParser.task(tasks[i]) + age = TokenParser.age(ages[i]) + gender = TokenParser.gender(genders[i]) + mel = TokenParser.mel_value(mels[i]) + mel_level = TokenParser.mel_level(mel_levels[i]) + loudness = TokenParser.loudness_value(loudnesses[i]) + loudness_level = TokenParser.loudness_level(loudness_levels[i]) + emotion = TokenParser.emotion(emotions[i]) + inputs = [task, age, gender, mel, mel_level, loudness, loudness_level, emotion] + inputs = "".join(inputs) + ids = tokenizer.encode(inputs, add_special_tokens=False) + print(ids) + print("decode", tokenizer.decode(ids)) diff --git a/tools/download_assets.py b/tools/download_assets.py index 1bbfbaa..f413f26 100644 --- a/tools/download_assets.py +++ b/tools/download_assets.py @@ -2,15 +2,10 @@ import os import subprocess import sys -import shutil -import stat import requests from pathlib import Path -########################### -# Part 1: Clone Spark TTS # -########################### - +# Part 1: Download Spark assets def run_command(command, error_message): try: subprocess.run(command, check=True) @@ -19,7 +14,6 @@ def run_command(command, error_message): sys.exit(1) def clone_spark_tts(): - # Create the directory spark/pretrained_models if it doesn't exist. spark_pretrained_dir = os.path.join("spark", "pretrained_models") os.makedirs(spark_pretrained_dir, exist_ok=True) @@ -44,10 +38,7 @@ def clone_spark_tts(): else: print(f"Directory '{clone_dir}' already exists. Skipping clone.") -############################# -# Part 2: Download RVC Assets # -############################# - +# Part 2: Download RVC Assets def dl_model(link, model_name, dir_name): with requests.get(f"{link}{model_name}") as r: r.raise_for_status() @@ -58,7 +49,6 @@ def dl_model(link, model_name, dir_name): def download_rvc_models(): RVC_DOWNLOAD_LINK = "https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/" - # Set BASE_DIR to the project root. If this script is in ./tools, we go one level up. BASE_DIR = Path(__file__).resolve().parent.parent def check_and_dl(link, model_name, dest_dir): @@ -121,10 +111,6 @@ def download_rvc_models(): print("All models downloaded!") -########################## -# Main: Run both parts # -########################## - def main(): clone_spark_tts() download_rvc_models()