# Source: mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
# Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import logging
import os
from typing import Dict, List

import torch
from torch.utils.dlpack import to_dlpack

import triton_python_backend_utils as pb_utils

from spark.sparktts.models.bicodec import BiCodec

# Module-level logger with timestamped output for the Triton backend process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
class TritonPythonModel:
    """Triton Python-backend model for the BiCodec vocoder.

    Gathers global and semantic token tensors from a batch of requests and
    generates audio waveforms with ``BiCodec.detokenize``.
    """

    def initialize(self, args):
        """Load the vocoder once when Triton starts the model instance.

        Args:
            args: Triton-provided dict; ``args['model_config']`` is a JSON
                string whose ``parameters`` section must contain ``model_dir``
                (path holding the ``BiCodec`` checkpoint directory).
        """
        # Triton wraps each config parameter as {"string_value": ...}; flatten
        # to a plain str -> str mapping before use.
        parameters = json.loads(args['model_config'])['parameters']
        model_params = {key: value["string_value"] for key, value in parameters.items()}
        model_dir = model_params["model_dir"]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Lazy %-style args: no string formatting cost when INFO is disabled.
        logger.info("Initializing vocoder from %s on %s", model_dir, self.device)

        self.vocoder = BiCodec.load_from_checkpoint(f"{model_dir}/BiCodec")
        # Only detokenization is used here; drop the encoder/postnet submodules
        # so their weights are freed.
        del self.vocoder.encoder, self.vocoder.postnet
        self.vocoder.eval().to(self.device)  # eval mode: no dropout/BN updates

        logger.info("Vocoder initialized successfully")

    def execute(self, requests):
        """Run vocoder inference on a batch of requests.

        Args:
            requests: List of Triton inference requests, each carrying
                "global_tokens" and "semantic_tokens" input tensors.

        Returns:
            A list with one ``pb_utils.InferenceResponse`` per request,
            each containing a "waveform" output tensor.
        """
        global_tokens_list, semantic_tokens_list = [], []

        # Collect the inputs from every request so the vocoder runs a single
        # forward pass over the whole batch instead of one pass per request.
        for request in requests:
            global_np = pb_utils.get_input_tensor_by_name(request, "global_tokens").as_numpy()
            semantic_np = pb_utils.get_input_tensor_by_name(request, "semantic_tokens").as_numpy()
            global_tokens_list.append(torch.from_numpy(global_np).to(self.device))
            semantic_tokens_list.append(torch.from_numpy(semantic_np).to(self.device))

        # NOTE(review): concatenating on dim 0 and indexing wavs[i] below
        # assumes each request contributes exactly one batch row — confirm
        # against the model config (max_batch_size / dynamic batching).
        global_tokens = torch.cat(global_tokens_list, dim=0)
        semantic_tokens = torch.cat(semantic_tokens_list, dim=0)

        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            wavs = self.vocoder.detokenize(semantic_tokens, global_tokens.unsqueeze(1))

        # Hand each waveform back to Triton zero-copy via DLPack.
        responses = []
        for i in range(len(requests)):
            wav_tensor = pb_utils.Tensor.from_dlpack("waveform", to_dlpack(wavs[i]))
            responses.append(pb_utils.InferenceResponse(output_tensors=[wav_tensor]))

        return responses