mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-07 06:02:53 +08:00
a way to add an exception to unpickler without explicitly calling load_with_extra
This commit is contained in:
parent
c5bdba2089
commit
8eef9d8e78
@ -103,7 +103,7 @@ def check_pt(filename, extra_handler):
|
|||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, *args, **kwargs)
|
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
@ -151,5 +151,42 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
|||||||
return unsafe_torch_load(filename, *args, **kwargs)
|
return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Extra:
|
||||||
|
"""
|
||||||
|
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
|
||||||
|
(because it's not your code making the torch.load call). The intended use is like this:
|
||||||
|
|
||||||
|
```
|
||||||
|
import torch
|
||||||
|
from modules import safe
|
||||||
|
|
||||||
|
def handler(module, name):
|
||||||
|
if module == 'torch' and name in ['float64', 'float16']:
|
||||||
|
return getattr(torch, name)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
with safe.Extra(handler):
|
||||||
|
x = torch.load('model.pt')
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, handler):
|
||||||
|
self.handler = handler
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
global global_extra_handler
|
||||||
|
|
||||||
|
assert global_extra_handler is None, 'already inside an Extra() block'
|
||||||
|
global_extra_handler = self.handler
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
global global_extra_handler
|
||||||
|
|
||||||
|
global_extra_handler = None
|
||||||
|
|
||||||
|
|
||||||
unsafe_torch_load = torch.load
|
unsafe_torch_load = torch.load
|
||||||
torch.load = load
|
torch.load = load
|
||||||
|
global_extra_handler = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user