mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2024-12-29 19:15:04 +08:00
chore(format): run black on dev (#1638)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
997a956f4f
commit
5449f84f06
59
gui_v1.py
59
gui_v1.py
@ -38,7 +38,11 @@ def phase_vocoder(a, b, fade_out, fade_in):
|
|||||||
deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
|
deltaphase = deltaphase - 2 * np.pi * torch.floor(deltaphase / 2 / np.pi + 0.5)
|
||||||
w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
|
w = 2 * np.pi * torch.arange(n // 2 + 1).to(a) + deltaphase
|
||||||
t = torch.arange(n).unsqueeze(-1).to(a) / n
|
t = torch.arange(n).unsqueeze(-1).to(a) / n
|
||||||
result = a * (fade_out ** 2) + b * (fade_in ** 2) + torch.sum(absab * torch.cos(w * t + phia), -1) * window / n
|
result = (
|
||||||
|
a * (fade_out**2)
|
||||||
|
+ b * (fade_in**2)
|
||||||
|
+ torch.sum(absab * torch.cos(w * t + phia), -1) * window / n
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -592,11 +596,11 @@ if __name__ == "__main__":
|
|||||||
self.gui_config.pth_path = values["pth_path"]
|
self.gui_config.pth_path = values["pth_path"]
|
||||||
self.gui_config.index_path = values["index_path"]
|
self.gui_config.index_path = values["index_path"]
|
||||||
self.gui_config.sr_type = ["sr_model", "sr_device"][
|
self.gui_config.sr_type = ["sr_model", "sr_device"][
|
||||||
[
|
[
|
||||||
values["sr_model"],
|
values["sr_model"],
|
||||||
values["sr_device"],
|
values["sr_device"],
|
||||||
].index(True)
|
].index(True)
|
||||||
]
|
]
|
||||||
self.gui_config.threhold = values["threhold"]
|
self.gui_config.threhold = values["threhold"]
|
||||||
self.gui_config.pitch = values["pitch"]
|
self.gui_config.pitch = values["pitch"]
|
||||||
self.gui_config.block_time = values["block_time"]
|
self.gui_config.block_time = values["block_time"]
|
||||||
@ -633,7 +637,11 @@ if __name__ == "__main__":
|
|||||||
self.config,
|
self.config,
|
||||||
self.rvc if hasattr(self, "rvc") else None,
|
self.rvc if hasattr(self, "rvc") else None,
|
||||||
)
|
)
|
||||||
self.gui_config.samplerate = self.rvc.tgt_sr if self.gui_config.sr_type == "sr_model" else self.get_device_samplerate()
|
self.gui_config.samplerate = (
|
||||||
|
self.rvc.tgt_sr
|
||||||
|
if self.gui_config.sr_type == "sr_model"
|
||||||
|
else self.get_device_samplerate()
|
||||||
|
)
|
||||||
self.zc = self.gui_config.samplerate // 100
|
self.zc = self.gui_config.samplerate // 100
|
||||||
self.block_frame = (
|
self.block_frame = (
|
||||||
int(
|
int(
|
||||||
@ -690,7 +698,9 @@ if __name__ == "__main__":
|
|||||||
2 * self.zc, device=self.config.device, dtype=torch.float32
|
2 * self.zc, device=self.config.device, dtype=torch.float32
|
||||||
)
|
)
|
||||||
self.skip_head = self.extra_frame // self.zc
|
self.skip_head = self.extra_frame // self.zc
|
||||||
self.return_length = (self.block_frame + self.sola_buffer_frame + self.sola_search_frame) // self.zc
|
self.return_length = (
|
||||||
|
self.block_frame + self.sola_buffer_frame + self.sola_search_frame
|
||||||
|
) // self.zc
|
||||||
self.fade_in_window: torch.Tensor = (
|
self.fade_in_window: torch.Tensor = (
|
||||||
torch.sin(
|
torch.sin(
|
||||||
0.5
|
0.5
|
||||||
@ -824,7 +834,11 @@ if __name__ == "__main__":
|
|||||||
# volume envelop mixing
|
# volume envelop mixing
|
||||||
if self.gui_config.rms_mix_rate < 1 and self.function == "vc":
|
if self.gui_config.rms_mix_rate < 1 and self.function == "vc":
|
||||||
rms1 = librosa.feature.rms(
|
rms1 = librosa.feature.rms(
|
||||||
y=self.input_wav_res[160 * self.skip_head : 160 * (self.skip_head + self.return_length)]
|
y=self.input_wav_res[
|
||||||
|
160
|
||||||
|
* self.skip_head : 160
|
||||||
|
* (self.skip_head + self.return_length)
|
||||||
|
]
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy(),
|
.numpy(),
|
||||||
frame_length=640,
|
frame_length=640,
|
||||||
@ -871,21 +885,24 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
sola_offset = torch.argmax(cor_nom[0, 0] / cor_den[0, 0])
|
||||||
printt("sola_offset = %d", int(sola_offset))
|
printt("sola_offset = %d", int(sola_offset))
|
||||||
infer_wav = infer_wav[sola_offset :]
|
infer_wav = infer_wav[sola_offset:]
|
||||||
if "privateuseone" in str(self.config.device) or not self.gui_config.use_pv:
|
if "privateuseone" in str(self.config.device) or not self.gui_config.use_pv:
|
||||||
infer_wav[: self.sola_buffer_frame] *= self.fade_in_window
|
infer_wav[: self.sola_buffer_frame] *= self.fade_in_window
|
||||||
infer_wav[: self.sola_buffer_frame] += self.sola_buffer * self.fade_out_window
|
infer_wav[: self.sola_buffer_frame] += (
|
||||||
|
self.sola_buffer * self.fade_out_window
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
infer_wav[: self.sola_buffer_frame] = phase_vocoder(
|
infer_wav[: self.sola_buffer_frame] = phase_vocoder(
|
||||||
self.sola_buffer,
|
self.sola_buffer,
|
||||||
infer_wav[: self.sola_buffer_frame],
|
infer_wav[: self.sola_buffer_frame],
|
||||||
self.fade_out_window,
|
self.fade_out_window,
|
||||||
self.fade_in_window)
|
self.fade_in_window,
|
||||||
self.sola_buffer[:] = infer_wav[self.block_frame : self.block_frame + self.sola_buffer_frame]
|
|
||||||
if sys.platform == "darwin":
|
|
||||||
outdata[:] = (
|
|
||||||
infer_wav[: self.block_frame].cpu().numpy()[:, np.newaxis]
|
|
||||||
)
|
)
|
||||||
|
self.sola_buffer[:] = infer_wav[
|
||||||
|
self.block_frame : self.block_frame + self.sola_buffer_frame
|
||||||
|
]
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
outdata[:] = infer_wav[: self.block_frame].cpu().numpy()[:, np.newaxis]
|
||||||
else:
|
else:
|
||||||
outdata[:] = (
|
outdata[:] = (
|
||||||
infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy()
|
infer_wav[: self.block_frame].repeat(2, 1).t().cpu().numpy()
|
||||||
@ -949,6 +966,8 @@ if __name__ == "__main__":
|
|||||||
printt("Output device: %s:%s", str(sd.default.device[1]), output_device)
|
printt("Output device: %s:%s", str(sd.default.device[1]), output_device)
|
||||||
|
|
||||||
def get_device_samplerate(self):
|
def get_device_samplerate(self):
|
||||||
return int(sd.query_devices(device=sd.default.device[0])['default_samplerate'])
|
return int(
|
||||||
|
sd.query_devices(device=sd.default.device[0])["default_samplerate"]
|
||||||
|
)
|
||||||
|
|
||||||
gui = GUI()
|
gui = GUI()
|
||||||
|
@ -795,9 +795,9 @@ class SynthesizerTrnMs256NSFsid(nn.Module):
|
|||||||
assert isinstance(return_length, torch.Tensor)
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
length = int(return_length.item())
|
length = int(return_length.item())
|
||||||
z_p = z_p[:, :, head: head + length]
|
z_p = z_p[:, :, head : head + length]
|
||||||
x_mask = x_mask[:, :, head: head + length]
|
x_mask = x_mask[:, :, head : head + length]
|
||||||
nsff0 = nsff0[:, head: head + length]
|
nsff0 = nsff0[:, head : head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, nsff0, g=g)
|
o = self.dec(z * x_mask, nsff0, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -957,9 +957,9 @@ class SynthesizerTrnMs768NSFsid(nn.Module):
|
|||||||
assert isinstance(return_length, torch.Tensor)
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
length = int(return_length.item())
|
length = int(return_length.item())
|
||||||
z_p = z_p[:, :, head: head + length]
|
z_p = z_p[:, :, head : head + length]
|
||||||
x_mask = x_mask[:, :, head: head + length]
|
x_mask = x_mask[:, :, head : head + length]
|
||||||
nsff0 = nsff0[:, head: head + length]
|
nsff0 = nsff0[:, head : head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, nsff0, g=g)
|
o = self.dec(z * x_mask, nsff0, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -1108,8 +1108,8 @@ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
|
|||||||
assert isinstance(return_length, torch.Tensor)
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
length = int(return_length.item())
|
length = int(return_length.item())
|
||||||
z_p = z_p[:, :, head: head + length]
|
z_p = z_p[:, :, head : head + length]
|
||||||
x_mask = x_mask[:, :, head: head + length]
|
x_mask = x_mask[:, :, head : head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, g=g)
|
o = self.dec(z * x_mask, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
@ -1258,8 +1258,8 @@ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
|
|||||||
assert isinstance(return_length, torch.Tensor)
|
assert isinstance(return_length, torch.Tensor)
|
||||||
head = int(skip_head.item())
|
head = int(skip_head.item())
|
||||||
length = int(return_length.item())
|
length = int(return_length.item())
|
||||||
z_p = z_p[:, :, head: head + length]
|
z_p = z_p[:, :, head : head + length]
|
||||||
x_mask = x_mask[:, :, head: head + length]
|
x_mask = x_mask[:, :, head : head + length]
|
||||||
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
z = self.flow(z_p, x_mask, g=g, reverse=True)
|
||||||
o = self.dec(z * x_mask, g=g)
|
o = self.dec(z * x_mask, g=g)
|
||||||
return o, x_mask, (z, z_p, m_p, logs_p)
|
return o, x_mask, (z, z_p, m_p, logs_p)
|
||||||
|
@ -38,6 +38,7 @@ def spectral_de_normalize_torch(magnitudes):
|
|||||||
mel_basis = {}
|
mel_basis = {}
|
||||||
hann_window = {}
|
hann_window = {}
|
||||||
|
|
||||||
|
|
||||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||||
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram.
|
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram.
|
||||||
|
|
||||||
@ -87,6 +88,7 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
|
|||||||
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
|
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
|
|
||||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||||
# MelBasis - Cache if needed
|
# MelBasis - Cache if needed
|
||||||
global mel_basis
|
global mel_basis
|
||||||
|
@ -46,22 +46,23 @@ def printt(strr, *args):
|
|||||||
# config.is_half=False########强制cpu测试
|
# config.is_half=False########强制cpu测试
|
||||||
class RVC:
|
class RVC:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
key,
|
key,
|
||||||
pth_path,
|
pth_path,
|
||||||
index_path,
|
index_path,
|
||||||
index_rate,
|
index_rate,
|
||||||
n_cpu,
|
n_cpu,
|
||||||
inp_q,
|
inp_q,
|
||||||
opt_q,
|
opt_q,
|
||||||
config: Config,
|
config: Config,
|
||||||
last_rvc=None,
|
last_rvc=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
初始化
|
初始化
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if config.dml == True:
|
if config.dml == True:
|
||||||
|
|
||||||
def forward_dml(ctx, x, scale):
|
def forward_dml(ctx, x, scale):
|
||||||
ctx.scale = scale
|
ctx.scale = scale
|
||||||
res = x.clone().detach()
|
res = x.clone().detach()
|
||||||
@ -201,7 +202,7 @@ class RVC:
|
|||||||
f0bak = f0.copy()
|
f0bak = f0.copy()
|
||||||
f0_mel = 1127 * np.log(1 + f0 / 700)
|
f0_mel = 1127 * np.log(1 + f0 / 700)
|
||||||
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
|
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
|
||||||
self.f0_mel_max - self.f0_mel_min
|
self.f0_mel_max - self.f0_mel_min
|
||||||
) + 1
|
) + 1
|
||||||
f0_mel[f0_mel <= 1] = 1
|
f0_mel[f0_mel <= 1] = 1
|
||||||
f0_mel[f0_mel > 255] = 255
|
f0_mel[f0_mel > 255] = 255
|
||||||
@ -258,7 +259,7 @@ class RVC:
|
|||||||
self.inp_q.put((idx, x[:tail], res_f0, n_cpu, ts))
|
self.inp_q.put((idx, x[:tail], res_f0, n_cpu, ts))
|
||||||
else:
|
else:
|
||||||
self.inp_q.put(
|
self.inp_q.put(
|
||||||
(idx, x[part_length * idx - 320: tail], res_f0, n_cpu, ts)
|
(idx, x[part_length * idx - 320 : tail], res_f0, n_cpu, ts)
|
||||||
)
|
)
|
||||||
while 1:
|
while 1:
|
||||||
res_ts = self.opt_q.get()
|
res_ts = self.opt_q.get()
|
||||||
@ -273,7 +274,7 @@ class RVC:
|
|||||||
else:
|
else:
|
||||||
f0 = f0[2:]
|
f0 = f0[2:]
|
||||||
f0bak[
|
f0bak[
|
||||||
part_length * idx // 160: part_length * idx // 160 + f0.shape[0]
|
part_length * idx // 160 : part_length * idx // 160 + f0.shape[0]
|
||||||
] = f0
|
] = f0
|
||||||
f0bak = signal.medfilt(f0bak, 3)
|
f0bak = signal.medfilt(f0bak, 3)
|
||||||
f0bak *= pow(2, f0_up_key / 12)
|
f0bak *= pow(2, f0_up_key / 12)
|
||||||
@ -320,6 +321,7 @@ class RVC:
|
|||||||
def get_f0_fcpe(self, x, f0_up_key):
|
def get_f0_fcpe(self, x, f0_up_key):
|
||||||
if hasattr(self, "model_fcpe") == False:
|
if hasattr(self, "model_fcpe") == False:
|
||||||
from torchfcpe import spawn_bundled_infer_model
|
from torchfcpe import spawn_bundled_infer_model
|
||||||
|
|
||||||
printt("Loading fcpe model")
|
printt("Loading fcpe model")
|
||||||
if "privateuseone" in str(self.device):
|
if "privateuseone" in str(self.device):
|
||||||
self.device_fcpe = "cpu"
|
self.device_fcpe = "cpu"
|
||||||
@ -329,7 +331,7 @@ class RVC:
|
|||||||
f0 = self.model_fcpe.infer(
|
f0 = self.model_fcpe.infer(
|
||||||
x.to(self.device_fcpe).unsqueeze(0).float(),
|
x.to(self.device_fcpe).unsqueeze(0).float(),
|
||||||
sr=16000,
|
sr=16000,
|
||||||
decoder_mode='local_argmax',
|
decoder_mode="local_argmax",
|
||||||
threshold=0.006,
|
threshold=0.006,
|
||||||
)
|
)
|
||||||
f0 *= pow(2, f0_up_key / 12)
|
f0 *= pow(2, f0_up_key / 12)
|
||||||
@ -337,12 +339,12 @@ class RVC:
|
|||||||
return self.get_f0_post(f0)
|
return self.get_f0_post(f0)
|
||||||
|
|
||||||
def infer(
|
def infer(
|
||||||
self,
|
self,
|
||||||
input_wav: torch.Tensor,
|
input_wav: torch.Tensor,
|
||||||
block_frame_16k,
|
block_frame_16k,
|
||||||
skip_head,
|
skip_head,
|
||||||
return_length,
|
return_length,
|
||||||
f0method,
|
f0method,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -364,16 +366,16 @@ class RVC:
|
|||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
try:
|
try:
|
||||||
if hasattr(self, "index") and self.index_rate != 0:
|
if hasattr(self, "index") and self.index_rate != 0:
|
||||||
npy = feats[0][skip_head // 2:].cpu().numpy().astype("float32")
|
npy = feats[0][skip_head // 2 :].cpu().numpy().astype("float32")
|
||||||
score, ix = self.index.search(npy, k=8)
|
score, ix = self.index.search(npy, k=8)
|
||||||
weight = np.square(1 / score)
|
weight = np.square(1 / score)
|
||||||
weight /= weight.sum(axis=1, keepdims=True)
|
weight /= weight.sum(axis=1, keepdims=True)
|
||||||
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
||||||
if self.config.is_half:
|
if self.config.is_half:
|
||||||
npy = npy.astype("float16")
|
npy = npy.astype("float16")
|
||||||
feats[0][skip_head // 2:] = (
|
feats[0][skip_head // 2 :] = (
|
||||||
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
|
torch.from_numpy(npy).unsqueeze(0).to(self.device) * self.index_rate
|
||||||
+ (1 - self.index_rate) * feats[0][skip_head // 2:]
|
+ (1 - self.index_rate) * feats[0][skip_head // 2 :]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
printt("Index search FAILED or disabled")
|
printt("Index search FAILED or disabled")
|
||||||
@ -384,21 +386,29 @@ class RVC:
|
|||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
f0_extractor_frame = block_frame_16k + 800
|
f0_extractor_frame = block_frame_16k + 800
|
||||||
if f0method == "rmvpe":
|
if f0method == "rmvpe":
|
||||||
f0_extractor_frame = (
|
f0_extractor_frame = 5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160
|
||||||
5120 * ((f0_extractor_frame - 1) // 5120 + 1) - 160
|
pitch, pitchf = self.get_f0(
|
||||||
)
|
input_wav[-f0_extractor_frame:], self.f0_up_key, self.n_cpu, f0method
|
||||||
pitch, pitchf = self.get_f0(input_wav[-f0_extractor_frame: ], self.f0_up_key, self.n_cpu, f0method)
|
)
|
||||||
start_frame = block_frame_16k // 160
|
start_frame = block_frame_16k // 160
|
||||||
end_frame = len(self.cache_pitch) - (pitch.shape[0] - 4) + start_frame
|
end_frame = len(self.cache_pitch) - (pitch.shape[0] - 4) + start_frame
|
||||||
self.cache_pitch[:] = np.append(self.cache_pitch[start_frame: end_frame], pitch[3:-1])
|
self.cache_pitch[:] = np.append(
|
||||||
|
self.cache_pitch[start_frame:end_frame], pitch[3:-1]
|
||||||
|
)
|
||||||
self.cache_pitchf[:] = np.append(
|
self.cache_pitchf[:] = np.append(
|
||||||
self.cache_pitchf[start_frame: end_frame], pitchf[3:-1]
|
self.cache_pitchf[start_frame:end_frame], pitchf[3:-1]
|
||||||
)
|
)
|
||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
p_len = input_wav.shape[0] // 160
|
p_len = input_wav.shape[0] // 160
|
||||||
if self.if_f0 == 1:
|
if self.if_f0 == 1:
|
||||||
cache_pitch = torch.LongTensor(self.cache_pitch[-p_len: ]).to(self.device).unsqueeze(0)
|
cache_pitch = (
|
||||||
cache_pitchf = torch.FloatTensor(self.cache_pitchf[-p_len: ]).to(self.device).unsqueeze(0)
|
torch.LongTensor(self.cache_pitch[-p_len:]).to(self.device).unsqueeze(0)
|
||||||
|
)
|
||||||
|
cache_pitchf = (
|
||||||
|
torch.FloatTensor(self.cache_pitchf[-p_len:])
|
||||||
|
.to(self.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
)
|
||||||
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
|
||||||
feats = feats[:, :p_len, :]
|
feats = feats[:, :p_len, :]
|
||||||
p_len = torch.LongTensor([p_len]).to(self.device)
|
p_len = torch.LongTensor([p_len]).to(self.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user