diff --git a/extract_feature_print.py b/extract_feature_print.py index 9b9fdc2..3ab2504 100644 --- a/extract_feature_print.py +++ b/extract_feature_print.py @@ -1,5 +1,7 @@ import os, sys, traceback +os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' + # device=sys.argv[1] n_part = int(sys.argv[2]) i_part = int(sys.argv[3])