From 56917dbeca8d268ce9d23af90b0ee66fc8761988 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 15 Aug 2023 10:40:59 +0800 Subject: [PATCH] Format code (#1011) Co-authored-by: github-actions[bot] --- config.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/config.py b/config.py index 3748eac..d978577 100644 --- a/config.py +++ b/config.py @@ -36,7 +36,7 @@ class Config: self.iscolab, self.noparallel, self.noautoopen, - self.dml + self.dml, ) = self.arg_parse() self.instead = "" self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config() @@ -71,7 +71,7 @@ class Config: cmd_opts.colab, cmd_opts.noparallel, cmd_opts.noautoopen, - cmd_opts.dml + cmd_opts.dml, ) # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. @@ -149,26 +149,38 @@ class Config: if self.dml: print("use DirectML instead") try: - os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-cuda") + os.rename( + "runtime\Lib\site-packages\onnxruntime", + "runtime\Lib\site-packages\onnxruntime-cuda", + ) except: pass try: - os.rename("runtime\Lib\site-packages\onnxruntime-dml","runtime\Lib\site-packages\onnxruntime") + os.rename( + "runtime\Lib\site-packages\onnxruntime-dml", + "runtime\Lib\site-packages\onnxruntime", + ) except: - pass import torch_directml + self.device = torch_directml.device(torch_directml.default_device()) self.is_half = False else: if self.instead: print(f"use {self.instead} instead") try: - os.rename("runtime\Lib\site-packages\onnxruntime","runtime\Lib\site-packages\onnxruntime-dml") + os.rename( + "runtime\Lib\site-packages\onnxruntime", + "runtime\Lib\site-packages\onnxruntime-dml", + ) except: pass try: - os.rename("runtime\Lib\site-packages\onnxruntime-cuda","runtime\Lib\site-packages\onnxruntime") + os.rename( + "runtime\Lib\site-packages\onnxruntime-cuda", + "runtime\Lib\site-packages\onnxruntime", + ) except: pass return x_pad, x_query, x_center, x_max