mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-06 20:01:37 +08:00
fix(train): save small model fail
This commit is contained in:
parent
5df99f2f73
commit
481f14dd74
@ -1 +1 @@
|
|||||||
from .config import singleton_variable, Config
|
from .config import singleton_variable, Config, CPUConfig
|
||||||
|
@ -262,3 +262,51 @@ class Config:
|
|||||||
% (self.is_half, self.device)
|
% (self.is_half, self.device)
|
||||||
)
|
)
|
||||||
return x_pad, x_query, x_center, x_max
|
return x_pad, x_query, x_center, x_max
|
||||||
|
|
||||||
|
@singleton_variable
|
||||||
|
class CPUConfig:
|
||||||
|
def __init__(self):
|
||||||
|
self.device = "cpu"
|
||||||
|
self.is_half = False
|
||||||
|
self.use_jit = False
|
||||||
|
self.n_cpu = 0
|
||||||
|
self.gpu_name = None
|
||||||
|
self.json_config = self.load_config_json()
|
||||||
|
self.gpu_mem = None
|
||||||
|
self.instead = "cpu"
|
||||||
|
self.preprocess_per = 3.7
|
||||||
|
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_config_json() -> dict:
|
||||||
|
d = {}
|
||||||
|
for config_file in version_config_list:
|
||||||
|
p = f"configs/inuse/{config_file}"
|
||||||
|
if not os.path.exists(p):
|
||||||
|
shutil.copy(f"configs/{config_file}", p)
|
||||||
|
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||||
|
d[config_file] = json.load(f)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def use_fp32_config(self):
|
||||||
|
for config_file in version_config_list:
|
||||||
|
self.json_config[config_file]["train"]["fp16_run"] = False
|
||||||
|
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||||
|
strr = f.read().replace("true", "false")
|
||||||
|
with open(f"configs/inuse/{config_file}", "w") as f:
|
||||||
|
f.write(strr)
|
||||||
|
self.preprocess_per = 3.0
|
||||||
|
|
||||||
|
def device_config(self):
|
||||||
|
self.use_fp32_config()
|
||||||
|
|
||||||
|
if self.n_cpu == 0:
|
||||||
|
self.n_cpu = cpu_count()
|
||||||
|
|
||||||
|
# 5G显存配置
|
||||||
|
x_pad = 1
|
||||||
|
x_query = 6
|
||||||
|
x_center = 38
|
||||||
|
x_max = 41
|
||||||
|
|
||||||
|
return x_pad, x_query, x_center, x_max
|
||||||
|
@ -822,8 +822,8 @@ with gr.Blocks(title="RVC WebUI") as app:
|
|||||||
label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"),
|
label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"),
|
||||||
value=0,
|
value=0,
|
||||||
)
|
)
|
||||||
input_audio0 = gr.File(
|
input_audio0 = gr.Audio(
|
||||||
label=i18n("待处理音频文件"), file_types=["audio"]
|
label=i18n("待处理音频文件"), type="filepath"
|
||||||
)
|
)
|
||||||
file_index2 = gr.Dropdown(
|
file_index2 = gr.Dropdown(
|
||||||
label=i18n("自动检测index路径,下拉式选择(dropdown)"),
|
label=i18n("自动检测index路径,下拉式选择(dropdown)"),
|
||||||
|
@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
|
|||||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||||
logger = logging
|
logger = logging
|
||||||
|
|
||||||
|
"""
|
||||||
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||||
assert os.path.isfile(checkpoint_path)
|
assert os.path.isfile(checkpoint_path)
|
||||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||||
@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
|||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
||||||
return model, optimizer, learning_rate, iteration
|
return model, optimizer, learning_rate, iteration
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
# def load_checkpoint(checkpoint_path, model, optimizer=None):
|
|
||||||
# assert os.path.isfile(checkpoint_path)
|
|
||||||
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
|
||||||
# iteration = checkpoint_dict['iteration']
|
|
||||||
# learning_rate = checkpoint_dict['learning_rate']
|
|
||||||
# if optimizer is not None:
|
|
||||||
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
|
||||||
# # print(1111)
|
|
||||||
# saved_state_dict = checkpoint_dict['model']
|
|
||||||
# # print(1111)
|
|
||||||
#
|
|
||||||
# if hasattr(model, 'module'):
|
|
||||||
# state_dict = model.module.state_dict()
|
|
||||||
# else:
|
|
||||||
# state_dict = model.state_dict()
|
|
||||||
# new_state_dict= {}
|
|
||||||
# for k, v in state_dict.items():
|
|
||||||
# try:
|
|
||||||
# new_state_dict[k] = saved_state_dict[k]
|
|
||||||
# except:
|
|
||||||
# logger.info("%s is not in the checkpoint" % k)
|
|
||||||
# new_state_dict[k] = v
|
|
||||||
# if hasattr(model, 'module'):
|
|
||||||
# model.module.load_state_dict(new_state_dict)
|
|
||||||
# else:
|
|
||||||
# model.load_state_dict(new_state_dict)
|
|
||||||
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
|
|
||||||
# checkpoint_path, iteration))
|
|
||||||
# return model, optimizer, learning_rate, iteration
|
|
||||||
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||||
assert os.path.isfile(checkpoint_path)
|
assert os.path.isfile(checkpoint_path)
|
||||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||||
@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
|
|||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
|
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
|
||||||
logger.info(
|
logger.info(
|
||||||
"Saving model and optimizer state at epoch {} to {}".format(
|
"Saving model and optimizer state at epoch {} to {}".format(
|
||||||
@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
|
|||||||
},
|
},
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
)
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
def summarize(
|
def summarize(
|
||||||
writer,
|
writer,
|
||||||
|
@ -3,6 +3,7 @@ import sys
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(os.path.join(now_dir))
|
sys.path.append(os.path.join(now_dir))
|
||||||
|
@ -5,19 +5,12 @@ import pathlib
|
|||||||
from scipy.fft import fft
|
from scipy.fft import fft
|
||||||
from pybase16384 import encode_to_string, decode_from_string
|
from pybase16384 import encode_to_string, decode_from_string
|
||||||
|
|
||||||
if __name__ == "__main__":
|
from infer.lib.audio import load_audio
|
||||||
import os, sys
|
from configs import CPUConfig, singleton_variable
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
|
||||||
sys.path.append(now_dir)
|
|
||||||
|
|
||||||
from configs import Config, singleton_variable
|
|
||||||
|
|
||||||
from .pipeline import Pipeline
|
from .pipeline import Pipeline
|
||||||
from .utils import load_hubert
|
from .utils import load_hubert
|
||||||
|
|
||||||
from infer.lib.audio import load_audio
|
|
||||||
|
|
||||||
|
|
||||||
class TorchSeedContext:
|
class TorchSeedContext:
|
||||||
def __init__(self, seed):
|
def __init__(self, seed):
|
||||||
@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
|||||||
audio_max = np.abs(audio).max() / 0.95
|
audio_max = np.abs(audio).max() / 0.95
|
||||||
if audio_max > 1:
|
if audio_max > 1:
|
||||||
np.divide(audio, audio_max, audio)
|
np.divide(audio, audio_max, audio)
|
||||||
|
hbt = load_hubert(config.device, config.is_half)
|
||||||
audio_opt = pipeline.pipeline(
|
audio_opt = pipeline.pipeline(
|
||||||
load_hubert(config.device, config.is_half),
|
hbt,
|
||||||
net_g,
|
net_g,
|
||||||
0,
|
0,
|
||||||
audio,
|
audio,
|
||||||
@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
|||||||
version,
|
version,
|
||||||
0.33,
|
0.33,
|
||||||
)
|
)
|
||||||
|
del hbt
|
||||||
opt_len = len(audio_opt)
|
opt_len = len(audio_opt)
|
||||||
diff = 48000 - opt_len
|
diff = 48000 - opt_len
|
||||||
n = diff // 2
|
n = diff // 2
|
||||||
@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
|
|||||||
SynthesizerTrnMs768NSFsid_nono,
|
SynthesizerTrnMs768NSFsid_nono,
|
||||||
)
|
)
|
||||||
|
|
||||||
config = Config()
|
config = CPUConfig()
|
||||||
|
|
||||||
with TorchSeedContext(114514):
|
with TorchSeedContext(114514):
|
||||||
tgt_sr = cpt["config"][-1]
|
tgt_sr = cpt["config"][-1]
|
||||||
if_f0 = cpt.get("f0", 1)
|
if_f0 = cpt.get("f0", 1)
|
||||||
@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
|
|||||||
|
|
||||||
|
|
||||||
def hash_id(h: str) -> str:
|
def hash_id(h: str) -> str:
|
||||||
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
|
d = decode_from_string(h)
|
||||||
|
if len(d) != half_hash_len * 2:
|
||||||
|
return "invalid hash length"
|
||||||
|
return encode_to_string(
|
||||||
|
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
|
||||||
|
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])
|
||||||
|
@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
import faiss
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import parselmouth
|
import parselmouth
|
||||||
@ -330,7 +331,6 @@ class Pipeline(object):
|
|||||||
and os.path.exists(file_index)
|
and os.path.exists(file_index)
|
||||||
and index_rate != 0
|
and index_rate != 0
|
||||||
):
|
):
|
||||||
if "faiss" not in sys.modules: import faiss
|
|
||||||
try:
|
try:
|
||||||
index = faiss.read_index(file_index)
|
index = faiss.read_index(file_index)
|
||||||
big_npy = index.reconstruct_n(0, index.ntotal)
|
big_npy = index.reconstruct_n(0, index.ntotal)
|
||||||
|
@ -2,8 +2,6 @@ import os
|
|||||||
|
|
||||||
from fairseq import checkpoint_utils
|
from fairseq import checkpoint_utils
|
||||||
|
|
||||||
from configs import singleton_variable
|
|
||||||
|
|
||||||
|
|
||||||
def get_index_path_from_model(sid):
|
def get_index_path_from_model(sid):
|
||||||
return next(
|
return next(
|
||||||
@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@singleton_variable
|
|
||||||
def load_hubert(device, is_half):
|
def load_hubert(device, is_half):
|
||||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||||
["assets/hubert/hubert_base.pt"],
|
["assets/hubert/hubert_base.pt"],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user