mirror of
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI.git
synced 2025-05-06 20:01:37 +08:00
fix(train): save small model fail
This commit is contained in:
parent
5df99f2f73
commit
481f14dd74
@ -1 +1 @@
|
||||
from .config import singleton_variable, Config
|
||||
from .config import singleton_variable, Config, CPUConfig
|
||||
|
@ -262,3 +262,51 @@ class Config:
|
||||
% (self.is_half, self.device)
|
||||
)
|
||||
return x_pad, x_query, x_center, x_max
|
||||
|
||||
@singleton_variable
|
||||
class CPUConfig:
|
||||
def __init__(self):
|
||||
self.device = "cpu"
|
||||
self.is_half = False
|
||||
self.use_jit = False
|
||||
self.n_cpu = 0
|
||||
self.gpu_name = None
|
||||
self.json_config = self.load_config_json()
|
||||
self.gpu_mem = None
|
||||
self.instead = "cpu"
|
||||
self.preprocess_per = 3.7
|
||||
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
|
||||
|
||||
@staticmethod
|
||||
def load_config_json() -> dict:
|
||||
d = {}
|
||||
for config_file in version_config_list:
|
||||
p = f"configs/inuse/{config_file}"
|
||||
if not os.path.exists(p):
|
||||
shutil.copy(f"configs/{config_file}", p)
|
||||
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||
d[config_file] = json.load(f)
|
||||
return d
|
||||
|
||||
def use_fp32_config(self):
|
||||
for config_file in version_config_list:
|
||||
self.json_config[config_file]["train"]["fp16_run"] = False
|
||||
with open(f"configs/inuse/{config_file}", "r") as f:
|
||||
strr = f.read().replace("true", "false")
|
||||
with open(f"configs/inuse/{config_file}", "w") as f:
|
||||
f.write(strr)
|
||||
self.preprocess_per = 3.0
|
||||
|
||||
def device_config(self):
|
||||
self.use_fp32_config()
|
||||
|
||||
if self.n_cpu == 0:
|
||||
self.n_cpu = cpu_count()
|
||||
|
||||
# 5G显存配置
|
||||
x_pad = 1
|
||||
x_query = 6
|
||||
x_center = 38
|
||||
x_max = 41
|
||||
|
||||
return x_pad, x_query, x_center, x_max
|
||||
|
@ -822,8 +822,8 @@ with gr.Blocks(title="RVC WebUI") as app:
|
||||
label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"),
|
||||
value=0,
|
||||
)
|
||||
input_audio0 = gr.File(
|
||||
label=i18n("待处理音频文件"), file_types=["audio"]
|
||||
input_audio0 = gr.Audio(
|
||||
label=i18n("待处理音频文件"), type="filepath"
|
||||
)
|
||||
file_index2 = gr.Dropdown(
|
||||
label=i18n("自动检测index路径,下拉式选择(dropdown)"),
|
||||
|
@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
logger = logging
|
||||
|
||||
|
||||
"""
|
||||
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
|
||||
# traceback.print_exc()
|
||||
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
|
||||
return model, optimizer, learning_rate, iteration
|
||||
"""
|
||||
|
||||
|
||||
# def load_checkpoint(checkpoint_path, model, optimizer=None):
|
||||
# assert os.path.isfile(checkpoint_path)
|
||||
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
|
||||
# iteration = checkpoint_dict['iteration']
|
||||
# learning_rate = checkpoint_dict['learning_rate']
|
||||
# if optimizer is not None:
|
||||
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
|
||||
# # print(1111)
|
||||
# saved_state_dict = checkpoint_dict['model']
|
||||
# # print(1111)
|
||||
#
|
||||
# if hasattr(model, 'module'):
|
||||
# state_dict = model.module.state_dict()
|
||||
# else:
|
||||
# state_dict = model.state_dict()
|
||||
# new_state_dict= {}
|
||||
# for k, v in state_dict.items():
|
||||
# try:
|
||||
# new_state_dict[k] = saved_state_dict[k]
|
||||
# except:
|
||||
# logger.info("%s is not in the checkpoint" % k)
|
||||
# new_state_dict[k] = v
|
||||
# if hasattr(model, 'module'):
|
||||
# model.module.load_state_dict(new_state_dict)
|
||||
# else:
|
||||
# model.load_state_dict(new_state_dict)
|
||||
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
|
||||
# checkpoint_path, iteration))
|
||||
# return model, optimizer, learning_rate, iteration
|
||||
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
||||
assert os.path.isfile(checkpoint_path)
|
||||
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
||||
@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
|
||||
logger.info(
|
||||
"Saving model and optimizer state at epoch {} to {}".format(
|
||||
@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
|
||||
},
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def summarize(
|
||||
writer,
|
||||
|
@ -3,6 +3,7 @@ import sys
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(os.path.join(now_dir))
|
||||
|
@ -5,19 +5,12 @@ import pathlib
|
||||
from scipy.fft import fft
|
||||
from pybase16384 import encode_to_string, decode_from_string
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os, sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
from configs import Config, singleton_variable
|
||||
from infer.lib.audio import load_audio
|
||||
from configs import CPUConfig, singleton_variable
|
||||
|
||||
from .pipeline import Pipeline
|
||||
from .utils import load_hubert
|
||||
|
||||
from infer.lib.audio import load_audio
|
||||
|
||||
|
||||
class TorchSeedContext:
|
||||
def __init__(self, seed):
|
||||
@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
audio_max = np.abs(audio).max() / 0.95
|
||||
if audio_max > 1:
|
||||
np.divide(audio, audio_max, audio)
|
||||
hbt = load_hubert(config.device, config.is_half)
|
||||
audio_opt = pipeline.pipeline(
|
||||
load_hubert(config.device, config.is_half),
|
||||
hbt,
|
||||
net_g,
|
||||
0,
|
||||
audio,
|
||||
@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
|
||||
version,
|
||||
0.33,
|
||||
)
|
||||
del hbt
|
||||
opt_len = len(audio_opt)
|
||||
diff = 48000 - opt_len
|
||||
n = diff // 2
|
||||
@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
|
||||
SynthesizerTrnMs768NSFsid_nono,
|
||||
)
|
||||
|
||||
config = Config()
|
||||
config = CPUConfig()
|
||||
|
||||
with TorchSeedContext(114514):
|
||||
tgt_sr = cpt["config"][-1]
|
||||
if_f0 = cpt.get("f0", 1)
|
||||
@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
|
||||
|
||||
h = model_hash(config, tgt_sr, net_g, if_f0, version)
|
||||
|
||||
del net_g
|
||||
del net_g
|
||||
|
||||
return h
|
||||
|
||||
@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
|
||||
|
||||
|
||||
def hash_id(h: str) -> str:
|
||||
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
|
||||
d = decode_from_string(h)
|
||||
if len(d) != half_hash_len * 2:
|
||||
return "invalid hash length"
|
||||
return encode_to_string(
|
||||
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
|
||||
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])
|
||||
|
@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
|
||||
from functools import lru_cache
|
||||
from time import time
|
||||
|
||||
import faiss
|
||||
import librosa
|
||||
import numpy as np
|
||||
import parselmouth
|
||||
@ -330,7 +331,6 @@ class Pipeline(object):
|
||||
and os.path.exists(file_index)
|
||||
and index_rate != 0
|
||||
):
|
||||
if "faiss" not in sys.modules: import faiss
|
||||
try:
|
||||
index = faiss.read_index(file_index)
|
||||
big_npy = index.reconstruct_n(0, index.ntotal)
|
||||
|
@ -2,8 +2,6 @@ import os
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
|
||||
from configs import singleton_variable
|
||||
|
||||
|
||||
def get_index_path_from_model(sid):
|
||||
return next(
|
||||
@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
|
||||
)
|
||||
|
||||
|
||||
@singleton_variable
|
||||
def load_hubert(device, is_half):
|
||||
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
||||
["assets/hubert/hubert_base.pt"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user