Update config.py

This commit is contained in:
RVC-Boss 2023-05-02 12:07:03 +00:00 committed by GitHub
parent 69ea94609b
commit 8370356d95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 116 additions and 109 deletions

225
config.py
View File

@ -1,109 +1,116 @@
import argparse import argparse
import glob import glob
import sys import sys
import torch import torch
from multiprocessing import cpu_count from multiprocessing import cpu_count
class Config: class Config:
def __init__(self): def __init__(self):
self.device = "cuda:0" self.device = "cuda:0"
self.is_half = True self.is_half = True
self.n_cpu = 0 self.n_cpu = 0
self.gpu_name = None self.gpu_name = None
self.gpu_mem = None self.gpu_mem = None
( (
self.python_cmd, self.python_cmd,
self.listen_port, self.listen_port,
self.iscolab, self.iscolab,
self.noparallel, self.noparallel,
self.noautoopen, self.noautoopen,
) = self.arg_parse() ) = self.arg_parse()
self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
def arg_parse(self) -> tuple: def arg_parse(self) -> tuple:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=7865, help="Listen port") parser.add_argument("--port", type=int, default=7865, help="Listen port")
parser.add_argument( parser.add_argument(
"--pycmd", type=str, default="python", help="Python command" "--pycmd", type=str, default="python", help="Python command"
) )
parser.add_argument("--colab", action="store_true", help="Launch in colab") parser.add_argument("--colab", action="store_true", help="Launch in colab")
parser.add_argument( parser.add_argument(
"--noparallel", action="store_true", help="Disable parallel processing" "--noparallel", action="store_true", help="Disable parallel processing"
) )
parser.add_argument( parser.add_argument(
"--noautoopen", "--noautoopen",
action="store_true", action="store_true",
help="Do not open in browser automatically", help="Do not open in browser automatically",
) )
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865 cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
return ( return (
cmd_opts.pycmd, cmd_opts.pycmd,
cmd_opts.port, cmd_opts.port,
cmd_opts.colab, cmd_opts.colab,
cmd_opts.noparallel, cmd_opts.noparallel,
cmd_opts.noautoopen, cmd_opts.noautoopen,
) )
def device_config(self) -> tuple: def device_config(self) -> tuple:
if torch.cuda.is_available(): if torch.cuda.is_available():
i_device = int(self.device.split(":")[-1]) i_device = int(self.device.split(":")[-1])
self.gpu_name = torch.cuda.get_device_name(i_device) self.gpu_name = torch.cuda.get_device_name(i_device)
if ( if (
"16" in self.gpu_name ("16" in gpu_name and "V100"not in gpu_name.upper())
or "P40" in self.gpu_name.upper() or "P40" in self.gpu_name.upper()
or "1070" in self.gpu_name or "1070" in self.gpu_name
or "1080" in self.gpu_name or "1080" in self.gpu_name
): ):
print("16系显卡强制单精度") print("16系显卡强制单精度")
self.is_half = False self.is_half = False
for config_file in ["32k.json", "40k.json", "48k.json"]: for config_file in ["32k.json", "40k.json", "48k.json"]:
with open(f"configs/{config_file}", "r+") as f: with open(f"configs/{config_file}", "r") as f:
strr = f.read().replace("true", "false") strr = f.read().replace("true", "false")
f.write(strr) with open(f"configs/{config_file}", "w") as f:
self.gpu_mem = int( f.write(strr)
torch.cuda.get_device_properties(i_device).total_memory with open("trainset_preprocess_pipeline_print.py", "r") as f:
/ 1024 strr = f.read().replace("3.7", "3.0")
/ 1024 with open("trainset_preprocess_pipeline_print.py", "w") as f:
/ 1024 f.write(strr)
+ 0.4 else:
) self.gpu_name = None
if self.gpu_mem <= 4: self.gpu_mem = int(
with open("trainset_preprocess_pipeline_print.py", "r+") as f: torch.cuda.get_device_properties(i_device).total_memory
strr = f.read().replace("3.7", "3.0") / 1024
f.write(strr) / 1024
else: / 1024
self.gpu_name = None + 0.4
elif torch.backends.mps.is_available(): )
print("没有发现支持的N卡, 使用MPS进行推理") if self.gpu_mem <= 4:
self.device = "mps" with open("trainset_preprocess_pipeline_print.py", "r") as f:
else: strr = f.read().replace("3.7", "3.0")
print("没有发现支持的N卡, 使用CPU进行推理") with open("trainset_preprocess_pipeline_print.py", "w") as f:
self.device = "cpu" f.write(strr)
elif torch.backends.mps.is_available():
if self.n_cpu == 0: print("没有发现支持的N卡, 使用MPS进行推理")
self.n_cpu = cpu_count() self.device = "mps"
else:
if self.is_half: print("没有发现支持的N卡, 使用CPU进行推理")
# 6G显存配置 self.device = "cpu"
x_pad = 3 self.is_half = True
x_query = 10
x_center = 60 if self.n_cpu == 0:
x_max = 65 self.n_cpu = cpu_count()
else:
# 5G显存配置 if self.is_half:
x_pad = 1 # 6G显存配置
x_query = 6 x_pad = 3
x_center = 38 x_query = 10
x_max = 41 x_center = 60
x_max = 65
if self.gpu_name != None and self.gpu_mem <= 4: else:
x_pad = 1 # 5G显存配置
x_query = 5 x_pad = 1
x_center = 30 x_query = 6
x_max = 32 x_center = 38
x_max = 41
return x_pad, x_query, x_center, x_max
if self.gpu_mem != None and self.gpu_mem <= 4:
x_pad = 1
x_query = 5
x_center = 30
x_max = 32
return x_pad, x_query, x_center, x_max