# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import os
import logging
from typing import List, Dict

import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from spark.sparktts.models.bicodec import BiCodec

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class TritonPythonModel:
    """Triton Python model for vocoder.

    This model takes global and semantic tokens as input and generates
    audio waveforms using the BiCodec vocoder.
    """

    def initialize(self, args):
        """Initialize the model.

        Args:
            args: Dictionary containing model configuration; the
                'model_config' entry is a JSON string whose 'parameters'
                section must provide a "model_dir" string value.
        """
        # Parse model parameters from the Triton model config JSON.
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        # Initialize device and vocoder.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Lazy %-style args so formatting is skipped when INFO is disabled.
        logger.info("Initializing vocoder from %s on %s", model_dir, self.device)
        self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
        # Only detokenization is needed here; drop the encoder/postnet
        # submodules to free their memory before moving to the device.
        del self.vocoder.encoder, self.vocoder.postnet
        self.vocoder.eval().to(self.device)  # Set model to evaluation mode
        logger.info("Vocoder initialized successfully")

    def execute(self, requests):
        """Execute inference on the batched requests.

        Args:
            requests: List of pb_utils.InferenceRequest objects, each
                carrying "global_tokens" and "semantic_tokens" inputs
                (batch dimension first — TODO confirm against model config).

        Returns:
            List of pb_utils.InferenceResponse objects, one per request,
            each containing a "waveform" output tensor.
        """
        global_tokens_list, semantic_tokens_list = [], []

        # Gather the inputs from every request so the vocoder runs a single
        # batched forward pass instead of one pass per request.
        for request in requests:
            global_tokens_tensor = pb_utils.get_input_tensor_by_name(
                request, "global_tokens").as_numpy()
            semantic_tokens_tensor = pb_utils.get_input_tensor_by_name(
                request, "semantic_tokens").as_numpy()
            global_tokens_list.append(torch.from_numpy(global_tokens_tensor).to(self.device))
            semantic_tokens_list.append(torch.from_numpy(semantic_tokens_tensor).to(self.device))

        # Record how many batch items each request contributed so the
        # generated waveforms can be routed back to the right request below.
        batch_sizes = [tokens.shape[0] for tokens in global_tokens_list]

        # Concatenate tokens for batch processing.
        global_tokens = torch.cat(global_tokens_list, dim=0)
        semantic_tokens = torch.cat(semantic_tokens_list, dim=0)

        # Generate waveforms; no autograd state is needed for inference.
        with torch.no_grad():
            wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))

        # Slice the batched output back into per-request responses. A running
        # offset (rather than the request index) stays correct even when a
        # request carries more than one batch item; single-item requests get
        # the same squeezed waveform tensor as before.
        responses = []
        offset = 0
        for size in batch_sizes:
            wav = wavs[offset] if size == 1 else wavs[offset:offset + size]
            offset += size
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wav))
            responses.append(pb_utils.InferenceResponse(output_tensors=[wav_tensor]))
        return responses