diff --git a/train/process_ckpt.py b/train/process_ckpt.py index 3840345..0be10b7 100644 --- a/train/process_ckpt.py +++ b/train/process_ckpt.py @@ -69,12 +69,7 @@ def merge(path1,path2,alpha1,sr,f0,info,name): return opt ckpt1 = torch.load(path1, map_location="cpu") ckpt2 = torch.load(path2, map_location="cpu") - opt["config"] = ckpt1["config"] - ''' - if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] - elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] - elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] - ''' + cfg = ckpt1["config"] if("model"in ckpt1): ckpt1=extract(ckpt1) else: ckpt1=ckpt1["weight"] if("model"in ckpt2): ckpt2=extract(ckpt2) @@ -91,6 +86,12 @@ def merge(path1,path2,alpha1,sr,f0,info,name): opt["weight"][key] = (alpha1*(ckpt1[key].float())+(1-alpha1)*(ckpt2[key].float())).half() # except: # pdb.set_trace() + opt["config"] = cfg + ''' + if(sr=="40k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 10, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 40000] + elif(sr=="48k"):opt["config"] = [1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10,6,2,2,2], 512, [16, 16, 4, 4], 109, 256, 48000] + elif(sr=="32k"):opt["config"] = [513, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5], [1, 3, 5], [1, 3, 5]], [10, 4, 2, 2, 2], 512, [16, 16, 4, 4,4], 109, 256, 32000] + ''' opt["sr"]=sr opt["f0"]=1 if f0=="是"else 0 opt["info"]=info