From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001
From: Nuullll <vfirst218@gmail.com>
Date: Sat, 6 Jan 2024 19:09:56 +0800
Subject: [PATCH] [IPEX] Fix torch.Generator hijack

---
 modules/xpu_specific.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 4e11125b2..1137891a6 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
     return torch.reshape(result, (*N, L, Ev))
 
 
+def is_xpu_device(device: str | torch.device | None = None):
+    if device is None:
+        return False
+    if isinstance(device, str):
+        return device.startswith("xpu")
+    return device.type == "xpu"
+
+
 if has_xpu:
-    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
-    CondFunc('torch.Generator',
-        lambda orig_func, device=None: torch.xpu.Generator(device),
-        lambda orig_func, device=None: device is not None and device.type == "xpu")
+    try:
+        # torch.Generator supports "xpu" device since 2.1
+        torch.Generator("xpu")
+    except Exception:
+        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1)
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: torch.xpu.Generator(device),
+            lambda orig_func, device=None: is_xpu_device(device))
 
     # W/A for some OPs that could not handle different input dtypes
     CondFunc('torch.nn.functional.layer_norm',