fix(train): save small model fail

This commit is contained in:
源文雨 2024-06-04 04:07:19 +09:00
parent 5df99f2f73
commit 481f14dd74
8 changed files with 71 additions and 53 deletions

View File

@ -1 +1 @@
from .config import singleton_variable, Config
from .config import singleton_variable, Config, CPUConfig

View File

@ -262,3 +262,51 @@ class Config:
% (self.is_half, self.device)
)
return x_pad, x_query, x_center, x_max
@singleton_variable
class CPUConfig:
def __init__(self):
self.device = "cpu"
self.is_half = False
self.use_jit = False
self.n_cpu = 0
self.gpu_name = None
self.json_config = self.load_config_json()
self.gpu_mem = None
self.instead = "cpu"
self.preprocess_per = 3.7
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
@staticmethod
def load_config_json() -> dict:
d = {}
for config_file in version_config_list:
p = f"configs/inuse/{config_file}"
if not os.path.exists(p):
shutil.copy(f"configs/{config_file}", p)
with open(f"configs/inuse/{config_file}", "r") as f:
d[config_file] = json.load(f)
return d
def use_fp32_config(self):
for config_file in version_config_list:
self.json_config[config_file]["train"]["fp16_run"] = False
with open(f"configs/inuse/{config_file}", "r") as f:
strr = f.read().replace("true", "false")
with open(f"configs/inuse/{config_file}", "w") as f:
f.write(strr)
self.preprocess_per = 3.0
def device_config(self):
self.use_fp32_config()
if self.n_cpu == 0:
self.n_cpu = cpu_count()
# 5G显存配置
x_pad = 1
x_query = 6
x_center = 38
x_max = 41
return x_pad, x_query, x_center, x_max

View File

@ -822,8 +822,8 @@ with gr.Blocks(title="RVC WebUI") as app:
label=i18n("变调(整数, 半音数量, 升八度12降八度-12)"),
value=0,
)
input_audio0 = gr.File(
label=i18n("待处理音频文件"), file_types=["audio"]
input_audio0 = gr.Audio(
label=i18n("待处理音频文件"), type="filepath"
)
file_index2 = gr.Dropdown(
label=i18n("自动检测index路径,下拉式选择(dropdown)"),

View File

@ -14,7 +14,7 @@ MATPLOTLIB_FLAG = False
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging
"""
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@ -64,37 +64,8 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
# traceback.print_exc()
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
return model, optimizer, learning_rate, iteration
"""
# def load_checkpoint(checkpoint_path, model, optimizer=None):
# assert os.path.isfile(checkpoint_path)
# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
# iteration = checkpoint_dict['iteration']
# learning_rate = checkpoint_dict['learning_rate']
# if optimizer is not None:
# optimizer.load_state_dict(checkpoint_dict['optimizer'])
# # print(1111)
# saved_state_dict = checkpoint_dict['model']
# # print(1111)
#
# if hasattr(model, 'module'):
# state_dict = model.module.state_dict()
# else:
# state_dict = model.state_dict()
# new_state_dict= {}
# for k, v in state_dict.items():
# try:
# new_state_dict[k] = saved_state_dict[k]
# except:
# logger.info("%s is not in the checkpoint" % k)
# new_state_dict[k] = v
# if hasattr(model, 'module'):
# model.module.load_state_dict(new_state_dict)
# else:
# model.load_state_dict(new_state_dict)
# logger.info("Loaded checkpoint '{}' (epoch {})" .format(
# checkpoint_path, iteration))
# return model, optimizer, learning_rate, iteration
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
assert os.path.isfile(checkpoint_path)
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
@ -159,7 +130,7 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
checkpoint_path,
)
"""
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
logger.info(
"Saving model and optimizer state at epoch {} to {}".format(
@ -184,7 +155,7 @@ def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoin
},
checkpoint_path,
)
"""
def summarize(
writer,

View File

@ -3,6 +3,7 @@ import sys
import logging
logger = logging.getLogger(__name__)
logging.getLogger("numba").setLevel(logging.WARNING)
now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

View File

@ -5,19 +5,12 @@ import pathlib
from scipy.fft import fft
from pybase16384 import encode_to_string, decode_from_string
if __name__ == "__main__":
import os, sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from configs import Config, singleton_variable
from infer.lib.audio import load_audio
from configs import CPUConfig, singleton_variable
from .pipeline import Pipeline
from .utils import load_hubert
from infer.lib.audio import load_audio
class TorchSeedContext:
def __init__(self, seed):
@ -102,8 +95,9 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
audio_max = np.abs(audio).max() / 0.95
if audio_max > 1:
np.divide(audio, audio_max, audio)
hbt = load_hubert(config.device, config.is_half)
audio_opt = pipeline.pipeline(
load_hubert(config.device, config.is_half),
hbt,
net_g,
0,
audio,
@ -120,6 +114,7 @@ def model_hash(config, tgt_sr, net_g, if_f0, version):
version,
0.33,
)
del hbt
opt_len = len(audio_opt)
diff = 48000 - opt_len
n = diff // 2
@ -141,7 +136,8 @@ def model_hash_ckpt(cpt):
SynthesizerTrnMs768NSFsid_nono,
)
config = Config()
config = CPUConfig()
with TorchSeedContext(114514):
tgt_sr = cpt["config"][-1]
if_f0 = cpt.get("f0", 1)
@ -167,7 +163,7 @@ def model_hash_ckpt(cpt):
h = model_hash(config, tgt_sr, net_g, if_f0, version)
del net_g
del net_g
return h
@ -217,4 +213,9 @@ def hash_similarity(h1: str, h2: str) -> float:
def hash_id(h: str) -> str:
return encode_to_string(hashlib.md5(decode_from_string(h)).digest())[:-1]
d = decode_from_string(h)
if len(d) != half_hash_len * 2:
return "invalid hash length"
return encode_to_string(
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])

View File

@ -8,6 +8,7 @@ logger = logging.getLogger(__name__)
from functools import lru_cache
from time import time
import faiss
import librosa
import numpy as np
import parselmouth
@ -330,7 +331,6 @@ class Pipeline(object):
and os.path.exists(file_index)
and index_rate != 0
):
if "faiss" not in sys.modules: import faiss
try:
index = faiss.read_index(file_index)
big_npy = index.reconstruct_n(0, index.ntotal)

View File

@ -2,8 +2,6 @@ import os
from fairseq import checkpoint_utils
from configs import singleton_variable
def get_index_path_from_model(sid):
return next(
@ -22,7 +20,6 @@ def get_index_path_from_model(sid):
)
@singleton_variable
def load_hubert(device, is_half):
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
["assets/hubert/hubert_base.pt"],