From f637bb8788fcfc303e43d05fc39c9fa089985eac Mon Sep 17 00:00:00 2001
From: Ftps <63702646+Tps-F@users.noreply.github.com>
Date: Sun, 13 Aug 2023 12:45:20 +0900
Subject: [PATCH] Cleanup config.py (#992)

* Update config.py

* miss
---
 config.py | 51 ++++++++++++++++++++++++++++++++-------------------
 1 file changed, 32 insertions(+), 19 deletions(-)

diff --git a/config.py b/config.py
index 035713a..1eee375 100644
--- a/config.py
+++ b/config.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 import sys
 import torch
@@ -35,11 +36,13 @@ class Config:
             self.iscolab,
             self.noparallel,
             self.noautoopen,
+            self.dml
         ) = self.arg_parse()
-        self.instead=""
+        self.instead = ""
         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
 
-    def arg_parse(self) -> tuple:
+    @staticmethod
+    def arg_parse() -> tuple:
         exe = sys.executable or "python"
         parser = argparse.ArgumentParser()
         parser.add_argument("--port", type=int, default=7865, help="Listen port")
@@ -61,13 +64,14 @@ class Config:
         cmd_opts = parser.parse_args()
 
         cmd_opts.port = cmd_opts.port if 0 <= cmd_opts.port <= 65535 else 7865
-        self.dml=cmd_opts.dml
+
         return (
             cmd_opts.pycmd,
             cmd_opts.port,
             cmd_opts.colab,
             cmd_opts.noparallel,
             cmd_opts.noautoopen,
+            cmd_opts.dml
         )
 
     # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
@@ -112,12 +116,12 @@ class Config:
                     f.write(strr)
         elif self.has_mps():
             print("No supported Nvidia GPU found")
-            self.device = self.instead="mps"
+            self.device = self.instead = "mps"
             self.is_half = False
             use_fp32_config()
         else:
             print("No supported Nvidia GPU found")
-            self.device = self.instead="cpu"
+            self.device = self.instead = "cpu"
             self.is_half = False
             use_fp32_config()
 
@@ -137,25 +141,34 @@ class Config:
             x_center = 38
             x_max = 41
 
-        if self.gpu_mem != None and self.gpu_mem <= 4:
+        if self.gpu_mem is not None and self.gpu_mem <= 4:
             x_pad = 1
             x_query = 5
             x_center = 30
             x_max = 32
-        if(self.dml==True):
+        if self.dml:
             print("use DirectML instead")
-            try:os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
-            except:pass
-            try:os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
-            except:pass
+            try:
+                os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
+            except:
+                pass
+            try:
+                os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
+            except:
+
+                pass
             import torch_directml
-            self.device= torch_directml.device(torch_directml.default_device())
-            self.is_half=False
+            self.device = torch_directml.device(torch_directml.default_device())
+            self.is_half = False
         else:
-            if(self.instead):
-                print("use %s instead"%self.instead)
-            try:os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
-            except:pass
-            try:os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
-            except:pass
+            if self.instead:
+                print(f"use {self.instead} instead")
+            try:
+                os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda")
+            except:
+                pass
+            try:
+                os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime")
+            except:
+                pass
         return x_pad, x_query, x_center, x_max
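
Note: a minimal sketch of the same onnxruntime directory swap and DirectML device selection that the patch performs, assuming the runtime\Lib\site-packages layout and the torch_directml package the patch itself relies on. It restates the logic with raw strings, so the backslashes in the Windows paths cannot be misread as escape sequences, and with contextlib.suppress in place of the bare except/pass pairs; this is an illustration, not the committed code.

import contextlib
import os

def enable_dml():
    # Park the CUDA build of onnxruntime, then promote the DirectML build.
    # OSError is suppressed so a missing directory (e.g. a non-Windows
    # checkout, or a swap that already happened) is not fatal.
    with contextlib.suppress(OSError):
        os.rename(
            r"runtime\Lib\site-packages\onnxruntime",
            r"runtime\Lib\site-packages\onnxruntime-cuda",
        )
    with contextlib.suppress(OSError):
        os.rename(
            r"runtime\Lib\site-packages\onnxruntime-dml",
            r"runtime\Lib\site-packages\onnxruntime",
        )
    # Imported lazily, as in the patch, so torch_directml is only required
    # when DirectML is actually requested.
    import torch_directml

    return torch_directml.device(torch_directml.default_device())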