diff --git a/config.py b/config.py index 3748eac..d978577 100644 --- a/config.py +++ b/config.py @@ -36,7 +36,7 @@ class Config: self.iscolab, self.noparallel, self.noautoopen, - self.dml + self.dml, ) = self.arg_parse() self.instead = "" self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() @@ -71,7 +71,7 @@ class Config: cmd_opts.colab, cmd_opts.noparallel, cmd_opts.noautoopen, - cmd_opts.dml + cmd_opts.dml, ) # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. @@ -149,26 +149,38 @@ class Config: if self.dml: print("use DirectML instead") try: - os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda") + os.rename( + "runtime\Lib\site-packages\onnxruntime", + "runtime\Lib\site-packages\onnxruntime-cuda", + ) except: pass try: - os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime") + os.rename( + "runtime\Lib\site-packages\onnxruntime-dml", + "runtime\Lib\site-packages\onnxruntime", + ) except: - pass import torch_directml + self.device = torch_directml.device(torch_directml.default_device()) self.is_half = False else: if self.instead: print(f"use {self.instead} instead") try: - os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-dml") + os.rename( + "runtime\Lib\site-packages\onnxruntime", + "runtime\Lib\site-packages\onnxruntime-dml", + ) except: pass try: - os.rename("runtime\Lib\site-packages\onnxruntime-cuda","runtime\Lib\site-packages\onnxruntime") + os.rename( + "runtime\Lib\site-packages\onnxruntime-cuda", + "runtime\Lib\site-packages\onnxruntime", + ) except: pass return x_pad, x_query, x_center, x_max