# Mirror of https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git (synced 2025-04-05 04:08:58 +08:00; 188 lines, 6.3 KiB, Python)
# Copyright (c) 2025 SparkAudio
#               2025 Xinsheng Wang (w.xinshawn@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
# Heavily based on https://github.com/lucidrains/vector-quantize-pytorch
|
|
|
|
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange
|
|
from torch.nn.utils import weight_norm
|
|
|
|
|
|
def WNConv1d(*args, **kwargs):
    """Create an ``nn.Conv1d`` and return it wrapped with weight normalization.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
|
|
|
|
|
def ema_inplace(moving_avg, new, decay):
    """Blend ``new`` into ``moving_avg`` in place as an exponential moving average.

    After the call ``moving_avg == decay * moving_avg + (1 - decay) * new``.
    The update is applied through ``.data``, so autograd does not record it.
    """
    target = moving_avg.data
    target.mul_(decay)
    target.add_(new, alpha=1 - decay)
|
|
|
|
|
|
class FactorizedVectorQuantize(nn.Module):
    """Vector quantizer with an optional factorized (projected) codebook space.

    Inputs of shape ``(B, input_dim, T)`` are projected to ``codebook_dim`` with
    a weight-normalized 1x1 convolution (identity when the dims match), each
    time step is matched to the codebook entry with the highest cosine
    similarity, and the result is projected back to ``input_dim``.

    Args:
        input_dim: Channel dimension of the input/output latents.
        codebook_size: Number of codebook entries.
        codebook_dim: Dimension of each codebook entry.
        commitment: Weight of the commitment loss (training mode only).
        codebook_loss_weight: Weight of the codebook loss (training mode only).
        decay: EMA decay used to track per-code usage in ``cluster_size``.
        threshold_ema_dead_code: During training, a code counts as active when
            its EMA usage in ``cluster_size`` exceeds this value.
        momentum: Stored on the module; not used by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        commitment: float,
        codebook_loss_weight: float = 1.0,
        decay: float = 0.99,
        threshold_ema_dead_code: float = 2,
        momentum: float = 0.99,
        **kwargs,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.codebook_loss_weight = codebook_loss_weight
        self.decay = decay
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.momentum = momentum  # not used by this class

        # Factorized codes: project into the codebook dimension (and back) only
        # when it differs from the input dimension.
        if input_dim != self.codebook_dim:
            self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)
        # EMA of how often each code is selected; updated in training mode.
        self.register_buffer("cluster_size", torch.zeros(self.codebook_size))

    def forward(self, z: torch.Tensor) -> Dict[str, Any]:
        """Quantize ``z`` with the codebook.

        Args:
            z: Tensor of shape ``(B, input_dim, T)``.

        Returns:
            A dict with:
              - ``"z_q"``: ``(B, input_dim, T)`` quantized latents; gradients
                flow to ``z`` through a straight-through estimator.
              - ``"indices"``: ``(B, T)`` codebook indices.
              - ``"dists"``: ``(B * T, codebook_size)`` squared distances between
                the L2-normalized latents and codebook entries.
              - ``"vq_loss"``: scalar. In training mode, the commitment loss plus
                the codebook loss, averaged over the batch; in eval mode, ``0``.
              - ``"perplexity"``: scalar perplexity of code usage in this batch.
              - ``"active_num"``: number of active codes as a float tensor
                (EMA-based in training mode, codes used in this batch otherwise).
        """
        # Project into the codebook space (identity when the dims match).
        z_e = self.in_project(z)
        z_q, indices, dists = self.decode_latents(z_e)

        # Code-usage statistics, accumulated in at least float32: in fp16/bf16
        # the per-code counts lose integer precision, and in fp16 ``1e-10``
        # underflows to 0, turning the perplexity into NaN via 0 * log(0).
        stat_dtype = torch.promote_types(z_e.dtype, torch.float32)
        embed_onehot = F.one_hot(indices, self.codebook_size).to(stat_dtype)
        code_counts = embed_onehot.sum(dim=(0, 1))
        avg_probs = embed_onehot.reshape(-1, self.codebook_size).mean(dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        if self.training:
            # Track usage with an EMA; a code is "active" once its EMA usage
            # exceeds the dead-code threshold. A tensor reduction is used
            # instead of Python's builtin ``sum`` to avoid a per-element loop.
            ema_inplace(self.cluster_size, code_counts, self.decay)
            active_num = (self.cluster_size > self.threshold_ema_dead_code).sum()

            # Commitment loss pulls encoder outputs toward their codes.
            commit_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            # Codebook loss pulls the selected codes toward encoder outputs.
            codebook_loss = (
                F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
                * self.codebook_loss_weight
            )
            vq_loss = (commit_loss + codebook_loss).mean()
        else:
            active_num = (code_counts > 0).sum()
            # No losses outside training. Return a finite zero: averaging an
            # empty tensor would yield NaN.
            vq_loss = torch.zeros((), device=z.device)

        # No-op in the forward pass; straight-through gradient estimator in the
        # backward pass.
        z_q = z_e + (z_q - z_e).detach()
        z_q = self.out_project(z_q)

        return {
            "z_q": z_q,
            "indices": indices,
            "dists": dists,
            "vq_loss": vq_loss,
            "perplexity": perplexity,
            "active_num": active_num.float(),
        }

    def vq2emb(self, vq, out_proj=True):
        """Look up codebook embeddings for indices ``vq`` (channel-last layout).

        For ``vq`` of shape ``(B, T)``, returns ``(B, T, codebook_dim)``, or
        ``(B, T, input_dim)`` when ``out_proj`` is True.
        """
        emb = self.embed_code(vq)
        if out_proj:
            # ``out_project`` is a channel-first Conv1d when the dims differ, so
            # move channels next to time for the projection, then move them back.
            emb = self.out_project(emb.transpose(-1, -2)).transpose(-1, -2)
        return emb

    def tokenize(self, z: torch.Tensor) -> torch.Tensor:
        """Return the ``(B, T)`` codebook indices for ``z`` of shape ``(B, input_dim, T)``."""
        z_e = self.in_project(z)
        _, indices, _ = self.decode_latents(z_e)
        return indices

    def detokenize(self, indices):
        """Map ``(B, T)`` indices back to ``(B, input_dim, T)`` latents."""
        z_q = self.decode_code(indices)
        z_q = self.out_project(z_q)
        return z_q

    def get_emb(self):
        """Return the raw (un-normalized) codebook weight matrix."""
        return self.codebook.weight

    def embed_code(self, embed_id):
        """Look up codebook rows for ``embed_id``; the result is channel-last."""
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        """Look up codebook rows for ``(B, T)`` ids and return ``(B, codebook_dim, T)``."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Find the nearest codebook entry for every time step of ``latents``.

        Args:
            latents: Tensor of shape ``(B, codebook_dim, T)``.

        Returns:
            ``(z_q, indices, dist)``, where:
              - ``z_q`` is ``(B, codebook_dim, T)`` and built from the raw
                (un-normalized) codebook vectors;
              - ``indices`` is ``(B, T)``;
              - ``dist`` is ``(B * T, codebook_size)``.
        """
        batch, dim, steps = latents.shape
        # (B, D, T) -> (B * T, D)
        encodings = latents.transpose(1, 2).reshape(batch * steps, dim)
        codebook = self.codebook.weight

        # L2-normalize both sides so the nearest entry is the one with the
        # highest cosine similarity.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared Euclidean distance; for unit vectors this is 2 - 2 * cos.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        # (B * T,) -> (B, T)
        indices = (-dist).max(1)[1].reshape(batch, steps)
        z_q = self.decode_code(indices)

        return z_q, indices, dist
|