diff --git a/infer-web.py b/infer-web.py index ac107ff..5b6efc0 100644 --- a/infer-web.py +++ b/infer-web.py @@ -34,7 +34,7 @@ from lib.infer_pack.models import ( from lib.infer_pack.models_onnx import SynthesizerTrnMsNSFsidM from infer_uvr5 import _audio_pre_, _audio_pre_new from my_utils import load_audio -from train.process_ckpt import change_info, extract_small_model, merge, show_info +from lib.train.process_ckpt import change_info, extract_small_model, merge, show_info from vc_infer_pipeline import VC from sklearn.cluster import MiniBatchKMeans diff --git a/lib/train/data_utils.py b/lib/train/data_utils.py index af044d0..3437e24 100644 --- a/lib/train/data_utils.py +++ b/lib/train/data_utils.py @@ -3,8 +3,8 @@ import numpy as np import torch import torch.utils.data -from mel_processing import spectrogram_torch -from utils import load_wav_to_torch, load_filepaths_and_text +from lib.train.mel_processing import spectrogram_torch +from lib.train.utils import load_wav_to_torch, load_filepaths_and_text class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset): diff --git a/lib/train/losses.py b/lib/train/losses.py index 4d71f86..aa7bd81 100644 --- a/lib/train/losses.py +++ b/lib/train/losses.py @@ -1,5 +1,4 @@ import torch -from torch.nn import functional as F def feature_loss(fmap_r, fmap_g): diff --git a/lib/train/process_ckpt.py b/lib/train/process_ckpt.py index 8f9c3d7..324d5a5 100644 --- a/lib/train/process_ckpt.py +++ b/lib/train/process_ckpt.py @@ -1,4 +1,4 @@ -import torch, traceback, os, pdb, sys +import torch, traceback, os, sys now_dir = os.getcwd() sys.path.append(now_dir) diff --git a/lib/train/utils.py b/lib/train/utils.py index 8884e43..9c0fb5c 100644 --- a/lib/train/utils.py +++ b/lib/train/utils.py @@ -44,9 +44,10 @@ def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): model.module.load_state_dict(new_state_dict, strict=False) else: model.load_state_dict(new_state_dict, strict=False) + return model go(combd, "combd") - go(sbd, "sbd") + model = go(sbd, "sbd") ############# logger.info("Loaded model weights")