Update process_ckpt.py

This commit is contained in:
RVC-Boss 2023-05-17 15:39:24 +00:00 committed by GitHub
parent e5374b2041
commit 6fb1f8c1b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 255 additions and 252 deletions

View File

@@ -1,252 +1,255 @@
import os
import pdb
import sys
import traceback

import torch

# Make sibling modules (e.g. i18n) importable when this file is executed
# from the repository root.
now_dir = os.getcwd()
sys.path.append(now_dir)

from collections import OrderedDict
from i18n import I18nAuto

# Shared translator instance used for localized string comparisons below.
i18n = I18nAuto()
def savee(ckpt, sr, if_f0, name, epoch, version):
    """Save a slimmed-down inference checkpoint to weights/<name>.pth.

    Drops every "enc_q" entry (posterior encoder, training-only), casts the
    remaining tensors to fp16, attaches the per-sample-rate model config and
    metadata, and writes the result under the "weights/" directory relative
    to the current working directory.

    Args:
        ckpt: state-dict mapping parameter names to tensors.
        sr: sample-rate tag, one of "40k" / "48k" / "32k" (anything else
            saves without a "config" entry, as the original code did).
        if_f0: pitch-guidance flag, stored verbatim under "f0".
        name: output file stem (saved as weights/<name>.pth).
        epoch: epoch count, stored as "<epoch>epoch" under "info".
        version: model version tag, stored under "version".

    Returns:
        "Success." on success, otherwise the formatted traceback string.
    """
    # Hyperparameter list per sample rate — presumably the synthesizer
    # constructor arguments; ordering matters, values copied verbatim.
    configs = {
        "40k": [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000],
        "48k": [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 6, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 48000],
        "32k": [513, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 4, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 32000],
    }
    try:
        opt = OrderedDict()
        # fp16 halves the file size; enc_q is not needed for inference.
        opt["weight"] = {
            key: value.half()
            for key, value in ckpt.items()
            if "enc_q" not in key
        }
        if sr in configs:
            opt["config"] = configs[sr]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Narrowed from a bare except so Ctrl-C / SystemExit still propagate.
        return traceback.format_exc()
def show_info(path):
    """Load the checkpoint at *path* and return a human-readable summary.

    The summary lists the "info", "sr", "f0" and "version" metadata fields
    (falling back to the string "None" for missing keys). On any failure the
    formatted traceback string is returned instead.
    """
    try:
        meta = torch.load(path, map_location="cpu")
        fields = tuple(
            meta.get(key, "None") for key in ("info", "sr", "f0", "version")
        )
        return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % fields
    except:
        return traceback.format_exc()
def extract_small_model(path, name, sr, if_f0, info, version):
    """Extract a slim inference model from a (possibly full) checkpoint.

    Loads the checkpoint at *path* (unwrapping a nested "model" state-dict
    if present), drops training-only "enc_q" entries, casts tensors to fp16,
    attaches the per-sample-rate config plus metadata, and saves the result
    as weights/<name>.pth relative to the current working directory.

    Args:
        path: checkpoint file to load.
        name: output file stem (saved as weights/<name>.pth).
        sr: sample-rate tag, one of "40k" / "48k" / "32k" (anything else
            saves without a "config" entry, as the original code did).
        if_f0: pitch-guidance flag; coerced with int() before storing.
        info: free-form description; empty string becomes "Extracted model.".
        version: model version tag.

    Returns:
        "Success." on success, otherwise the formatted traceback string.
    """
    # Hyperparameter list per sample rate — presumably the synthesizer
    # constructor arguments; ordering matters, values copied verbatim.
    configs = {
        "40k": [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000],
        "48k": [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 6, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 48000],
        "32k": [513, 32, 192, 192, 768, 2, 6, 3, 0, "1",
                [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                [10, 4, 2, 2, 2], 512, [16, 16, 4, 4, 4], 109, 256, 32000],
    }
    try:
        ckpt = torch.load(path, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        # fp16 halves the file size; enc_q is not needed for inference.
        opt["weight"] = {
            key: value.half()
            for key, value in ckpt.items()
            if "enc_q" not in key
        }
        if sr in configs:
            opt["config"] = configs[sr]
        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["version"] = version
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        torch.save(opt, "weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Narrowed from a bare except so Ctrl-C / SystemExit still propagate.
        return traceback.format_exc()
def change_info(path, info, name):
    """Rewrite the "info" metadata of a checkpoint and re-save it.

    Loads *path*, replaces its "info" field with *info*, and saves the whole
    checkpoint as weights/<name> (falling back to the source file's basename
    when *name* is the empty string). Returns "Success." on success or the
    formatted traceback string on any failure.
    """
    try:
        checkpoint = torch.load(path, map_location="cpu")
        checkpoint["info"] = info
        target = os.path.basename(path) if name == "" else name
        torch.save(checkpoint, "weights/%s" % target)
        return "Success."
    except:
        return traceback.format_exc()
def merge(path1, path2, alpha1, sr, f0, info, name, version):
    """Linearly blend two voice models and save the result.

    Weights are combined as alpha1 * w1 + (1 - alpha1) * w2 (computed in
    fp32, stored as fp16) and written to weights/<name>.pth together with
    the first model's "config" and the supplied metadata. Returns
    "Success." on success, an architecture-mismatch message when the key
    sets differ, or the formatted traceback string on any other failure.
    """
    try:

        def strip_enc_q(ckpt):
            # Reduce a full training checkpoint to its inference weights,
            # dropping training-only "enc_q" entries.
            # NOTE(review): this returns {"weight": {...}}, so for raw
            # training checkpoints the blend loop below iterates a single
            # "weight" key — looks like a latent upstream quirk; confirm.
            state = ckpt["model"]
            slim = OrderedDict()
            slim["weight"] = {
                key: value for key, value in state.items() if "enc_q" not in key
            }
            return slim

        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        cfg = ckpt1["config"]
        ckpt1 = strip_enc_q(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = strip_enc_q(ckpt2) if "model" in ckpt2 else ckpt2["weight"]
        if sorted(ckpt1.keys()) != sorted(ckpt2.keys()):
            return "Fail to merge the models. The model architectures are not the same."
        merged = OrderedDict()
        merged["weight"] = {}
        for key in ckpt1.keys():
            w1, w2 = ckpt1[key], ckpt2[key]
            if key == "emb_g.weight" and w1.shape != w2.shape:
                # Speaker-embedding tables may have different row counts;
                # blend only the overlapping rows.
                rows = min(w1.shape[0], w2.shape[0])
                w1, w2 = w1[:rows], w2[:rows]
            merged["weight"][key] = (
                alpha1 * w1.float() + (1 - alpha1) * w2.float()
            ).half()
        merged["config"] = cfg
        merged["sr"] = sr
        # NOTE(review): f0 is compared against i18n("") here — upstream
        # variants compare against a localized "yes" string; confirm the
        # intended literal.
        merged["f0"] = 1 if f0 == i18n("") else 0
        merged["version"] = version
        merged["info"] = info
        torch.save(merged, "weights/%s.pth" % name)
        return "Success."
    except:
        return traceback.format_exc()