From 8eef9d8e782aa0655241e43f67059aa7bef3bdca Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 25 Dec 2022 09:03:56 +0300
Subject: [PATCH] a way to add an exception to unpickler without explicitly
 calling load_with_extra

---
 modules/safe.py | 39 ++++++++++++++++++++++++++++++++++++++-
 1 file changed, 38 insertions(+), 1 deletion(-)

diff --git a/modules/safe.py b/modules/safe.py
index 479c8b86d..ec23a53c4 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):
 
 
 def load(filename, *args, **kwargs):
-    return load_with_extra(filename, *args, **kwargs)
+    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
 
 
 def load_with_extra(filename, extra_handler=None, *args, **kwargs):
@@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
     return unsafe_torch_load(filename, *args, **kwargs)
 
 
+class Extra:
+    """
+    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+    (because it's not your code making the torch.load call). The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+    if module == 'torch' and name in ['float64', 'float16']:
+        return getattr(torch, name)
+
+    return None
+
+with safe.Extra(handler):
+    x = torch.load('model.pt')
+```
+    """
+
+    def __init__(self, handler):
+        self.handler = handler
+
+    def __enter__(self):
+        global global_extra_handler
+
+        assert global_extra_handler is None, 'already inside an Extra() block'
+        global_extra_handler = self.handler
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global global_extra_handler
+
+        global_extra_handler = None
+
+
 unsafe_torch_load = torch.load
 torch.load = load
+global_extra_handler = None
+