diff --git a/infer/modules/vc/modules.py b/infer/modules/vc/modules.py index 96db6a7..2e68fc3 100644 --- a/infer/modules/vc/modules.py +++ b/infer/modules/vc/modules.py @@ -20,6 +20,12 @@ from infer.modules.vc.utils import * from fairseq.data.dictionary import Dictionary import torch +def load_hubert_with_safe_globals(config): + safe_globals = {"fairseq.data.dictionary.Dictionary": Dictionary} + # Wrap the loading call in the safe_globals context manager. + with torch.serialization.safe_globals(safe_globals): + return load_hubert(config) + class VC: def __init__(self, config): self.n_spk = None @@ -170,8 +176,7 @@ class VC: times = [0, 0, 0] if self.hubert_model is None: - with torch.serialization.safe_globals({"fairseq.data.dictionary.Dictionary": Dictionary}): - self.hubert_model = load_hubert(self.config) + self.hubert_model = load_hubert_with_safe_globals(self.config) if file_index: file_index = (