2025-03-19 13:40:08 +02:00

55 lines
2.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from infer.lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM
if __name__ == "__main__":
    # Export script: load a trained checkpoint, rebuild the ONNX-oriented
    # synthesizer from its stored config, and trace it to an ONNX file using
    # randomly generated dummy inputs.

    # Whether the model is intended for MoeVoiceStudio (formerly MoeSS).
    # NOTE(review): this flag is not read anywhere in this script.
    MoeVS = True

    ModelPath = "Shiroha/shiroha.pth"  # path to the model checkpoint
    ExportedPath = "model.onnx"  # output path
    hidden_channels = 256  # hidden_channels; in preparation for 768Vec

    # NOTE(security): torch.load deserializes with pickle (unless weights_only
    # is in effect), so only point ModelPath at checkpoints you trust.
    cpt = torch.load(ModelPath, map_location="cpu")
    # n_spk: overwrite the third-from-last config entry with the first
    # dimension of the "emb_g.weight" tensor stored in the checkpoint.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    print(*cpt["config"])

    # Dummy sequence length shared by every tracing input below. The axes
    # sized by it are the ones declared dynamic in `dynamic_axes`, so the
    # concrete value only matters for tracing.
    test_length = 200

    test_phone = torch.rand(1, test_length, hidden_channels)  # hidden unit
    # hidden unit length (original note: seems to be of little use)
    test_phone_lengths = torch.tensor([test_length]).long()
    # pitch (original note: fundamental frequency, in Hz)
    test_pitch = torch.randint(size=(1, test_length), low=5, high=255)
    test_pitchf = torch.rand(1, test_length)  # NSF fundamental frequency
    test_ds = torch.LongTensor([0])  # speaker ID
    test_rnd = torch.rand(1, 192, test_length)  # noise (adds a random factor)

    # The device used at export time does not affect how the model is used.
    device = "cpu"

    # Export in fp32: supporting fp16 from C++ requires manually rearranging
    # memory, so fp16 is not used for now.
    net_g = SynthesizerTrnMsNSFsidM(*cpt["config"], is_half=False)
    # strict=False: keys that do not match between checkpoint and module are
    # tolerated rather than raising.
    net_g.load_state_dict(cpt["weight"], strict=False)

    input_names = ["phone", "phone_lengths", "pitch", "pitchf", "ds", "rnd"]
    output_names = [
        "audio",
    ]

    # Multi-speaker mixed-track export (disabled):
    # net_g.construct_spkmixmap(n_speaker)
    torch.onnx.export(
        net_g,
        (
            test_phone.to(device),
            test_phone_lengths.to(device),
            test_pitch.to(device),
            test_pitchf.to(device),
            test_ds.to(device),
            test_rnd.to(device),
        ),
        ExportedPath,
        # Mark the sequence-length axis of each input as variable so the
        # exported graph is not tied to `test_length`.
        dynamic_axes={
            "phone": [1],
            "pitch": [1],
            "pitchf": [1],
            "rnd": [2],
        },
        do_constant_folding=False,
        opset_version=16,
        verbose=False,
        input_names=input_names,
        output_names=output_names,
    )