From 46c0e9b2fe52a875af92840baef94e65b4a8f5c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 24 Jun 2023 16:21:31 +0800 Subject: [PATCH] fix extract feature in MPS device --- extract_feature_print.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extract_feature_print.py b/extract_feature_print.py index 9b9fdc2..3ab2504 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -1,5 +1,7 @@ import os, sys, traceback +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' + # device=sys.argv[1] n_part = int(sys.argv[2]) i_part = int(sys.argv[3])