"""Helpers for saving, inspecting, extracting and merging model weight files.

Every public function returns the string "Success." when it completes (or, for
``show_info``, a description of the file) and the formatted traceback text when
an exception is raised, so the result can be shown to the user as-is.
"""
import os
import pdb  # not used in this module; kept because other code may rely on it
import traceback
from collections import OrderedDict

import torch


def _config_for_sr(sr):
    """Return a fresh config list for sample-rate tag *sr*, or None if unknown.

    A new list is built on every call so callers can never mutate a shared
    table by accident.
    """
    if sr == "40k":
        return [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512,
                [16, 16, 4, 4], 109, 256, 40000]
    if sr == "48k":
        return [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 6, 2, 2, 2], 512,
                [16, 16, 4, 4, 4], 109, 256, 48000]
    if sr == "32k":
        return [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11],
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512,
                [16, 16, 4, 4, 4], 109, 256, 32000]
    return None


def _load_checkpoint(path):
    """Load the checkpoint stored at *path* onto the CPU."""
    # NOTE(review): torch.load unpickles the file; only open files you trust.
    return torch.load(path, map_location="cpu")


def _save_to_weights(obj, filename):
    """Write *obj* to ``weights/<filename>``, creating the folder if needed."""
    os.makedirs("weights", exist_ok=True)
    torch.save(obj, "weights/%s" % filename)


def _small_model(state, sr, info, f0):
    """Assemble the small-model dict from a full state mapping.

    Entries whose key contains "enc_q" are dropped and the remaining tensors
    are converted to half precision. "config" is only set when *sr* is one of
    "40k", "48k" or "32k".
    """
    opt = OrderedDict()
    opt["weight"] = {
        key: value.half() for key, value in state.items() if "enc_q" not in key
    }
    config = _config_for_sr(sr)
    if config is not None:
        opt["config"] = config
    opt["info"] = info
    opt["sr"] = sr
    opt["f0"] = f0
    return opt


def _weights_for_merge(ckpt):
    """Return the key -> tensor mapping that ``merge`` should blend.

    A checkpoint with a "model" entry contributes that entry minus the keys
    containing "enc_q"; any other checkpoint contributes its "weight" entry.
    """
    if "model" in ckpt:
        return {
            key: value
            for key, value in ckpt["model"].items()
            if "enc_q" not in key
        }
    return ckpt["weight"]


def savee(ckpt, sr, if_f0, name, epoch):
    """Save the state mapping *ckpt* as ``weights/<name>.pth``.

    Args:
        ckpt: mapping of parameter name to tensor.
        sr: sample-rate tag; "40k", "48k" and "32k" select a stored config.
        if_f0: value written unchanged under the "f0" key.
        name: output file name without the ".pth" suffix.
        epoch: formatted into the "info" entry as "<epoch>epoch".

    Returns:
        "Success." or the traceback text of the failure.
    """
    try:
        opt = _small_model(ckpt, sr, "%sepoch" % epoch, if_f0)
        _save_to_weights(opt, "%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()


def show_info(path):
    """Return a description of the weight file at *path*.

    The text lists the "info", "sr" and "f0" entries ("None" when an entry is
    absent). On failure the traceback text is returned instead.
    """
    try:
        a = _load_checkpoint(path)
        return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
        )
    except Exception:
        return traceback.format_exc()


def extract_small_model(path, name, sr, if_f0, info):
    """Extract a small model from the checkpoint at *path*.

    Args:
        path: checkpoint file; if it has a "model" entry, that entry is used
            as the state mapping, otherwise the loaded object itself is.
        name: output file name without the ".pth" suffix (under ``weights/``).
        sr: sample-rate tag; "40k", "48k" and "32k" select a stored config.
        if_f0: converted with ``int`` and stored under "f0".
        info: description text; an empty string becomes "Extracted model.".

    Returns:
        "Success." or the traceback text of the failure.
    """
    try:
        ckpt = _load_checkpoint(path)
        if "model" in ckpt:
            ckpt = ckpt["model"]
        if info == "":
            info = "Extracted model."
        opt = _small_model(ckpt, sr, info, int(if_f0))
        _save_to_weights(opt, "%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()


def change_info(path, info, name):
    """Rewrite the "info" entry of the file at *path*.

    The result is saved as ``weights/<name>``; an empty *name* reuses the base
    name of *path*. Note that no ".pth" suffix is appended here.

    Returns:
        "Success." or the traceback text of the failure.
    """
    try:
        ckpt = _load_checkpoint(path)
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        _save_to_weights(ckpt, name)
        return "Success."
    except Exception:
        return traceback.format_exc()


def merge(path1, path2, alpha1, sr, f0, info, name):
    """Blend two models and save the result as ``weights/<name>.pth``.

    Each tensor of the result is ``alpha1 * w1 + (1 - alpha1) * w2``, computed
    in float32 and stored as half precision.

    Args:
        path1: first checkpoint (small model or one with a "model" entry).
        path2: second checkpoint, same accepted layouts.
        alpha1: weight of the first model in the blend.
        sr: sample-rate tag stored under "sr"; also selects the config when
            the first checkpoint carries none.
        f0: the string "是" stores 1 under "f0", anything else stores 0.
        info: description text stored under "info".
        name: output file name without the ".pth" suffix.

    Returns:
        "Success.", a fixed failure message when the two models do not have
        the same parameter names, or the traceback text of a failure.
    """
    try:
        ckpt1 = _load_checkpoint(path1)
        ckpt2 = _load_checkpoint(path2)
        # Checkpoints with a "model" entry have no stored config, so fall
        # back to the table for *sr* instead of failing with KeyError.
        if "config" in ckpt1:
            cfg = ckpt1["config"]
        else:
            cfg = _config_for_sr(sr)
        weights1 = _weights_for_merge(ckpt1)
        weights2 = _weights_for_merge(ckpt2)
        if sorted(weights1) != sorted(weights2):
            return "Fail to merge the models. The model architectures are not the same."
        opt = OrderedDict()
        opt["weight"] = {}
        for key, first in weights1.items():
            second = weights2[key]
            if key == "emb_g.weight" and first.shape != second.shape:
                # Tables of different length: blend only the rows both have.
                rows = min(first.shape[0], second.shape[0])
                first = first[:rows]
                second = second[:rows]
            blended = alpha1 * (first.float()) + (1 - alpha1) * (second.float())
            opt["weight"][key] = blended.half()
        if cfg is not None:
            opt["config"] = cfg
        opt["sr"] = sr
        opt["f0"] = 1 if f0 == "是" else 0
        opt["info"] = info
        _save_to_weights(opt, "%s.pth" % name)
        return "Success."
    except Exception:
        return traceback.format_exc()