mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 12:25:06 +08:00
added guard for torch.load to prevent loading pickles with unknown content
This commit is contained in:
parent
bba2ac8324
commit
875ddfeecf
@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import modules.safe
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
models_path = os.path.join(script_path, "models")
|
models_path = os.path.join(script_path, "models")
|
||||||
|
89
modules/safe.py
Normal file
89
modules/safe.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# this code is adapted from the script contributed by anon from /h/
|
||||||
|
|
||||||
|
import io
|
||||||
|
import pickle
|
||||||
|
import collections
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy
|
||||||
|
import _codecs
|
||||||
|
import zipfile
|
||||||
|
|
||||||
|
|
||||||
|
def encode(*args):
|
||||||
|
out = _codecs.encode(*args)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictedUnpickler(pickle.Unpickler):
|
||||||
|
def persistent_load(self, saved_id):
|
||||||
|
assert saved_id[0] == 'storage'
|
||||||
|
return torch.storage._TypedStorage()
|
||||||
|
|
||||||
|
def find_class(self, module, name):
|
||||||
|
if module == 'collections' and name == 'OrderedDict':
|
||||||
|
return getattr(collections, name)
|
||||||
|
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
|
||||||
|
return getattr(torch._utils, name)
|
||||||
|
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']:
|
||||||
|
return getattr(torch, name)
|
||||||
|
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||||
|
return getattr(torch.nn.modules.container, name)
|
||||||
|
if module == 'numpy.core.multiarray' and name == 'scalar':
|
||||||
|
return numpy.core.multiarray.scalar
|
||||||
|
if module == 'numpy' and name == 'dtype':
|
||||||
|
return numpy.dtype
|
||||||
|
if module == '_codecs' and name == 'encode':
|
||||||
|
return encode
|
||||||
|
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
|
||||||
|
import pytorch_lightning.callbacks
|
||||||
|
return pytorch_lightning.callbacks.model_checkpoint
|
||||||
|
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||||
|
import pytorch_lightning.callbacks.model_checkpoint
|
||||||
|
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||||
|
if module == "__builtin__" and name == 'set':
|
||||||
|
return set
|
||||||
|
|
||||||
|
# Forbid everything else.
|
||||||
|
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")
|
||||||
|
|
||||||
|
|
||||||
|
def check_pt(filename):
|
||||||
|
try:
|
||||||
|
|
||||||
|
# new pytorch format is a zip file
|
||||||
|
with zipfile.ZipFile(filename) as z:
|
||||||
|
with z.open('archive/data.pkl') as file:
|
||||||
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
unpickler.load()
|
||||||
|
|
||||||
|
except zipfile.BadZipfile:
|
||||||
|
|
||||||
|
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
unpickler = RestrictedUnpickler(file)
|
||||||
|
for i in range(5):
|
||||||
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
|
def load(filename, *args, **kwargs):
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not shared.cmd_opts.disable_safe_unpickle:
|
||||||
|
check_pt(filename)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||||
|
print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return unsafe_torch_load(filename, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
unsafe_torch_load = torch.load
|
||||||
|
torch.load = load
|
@ -65,6 +65,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
|||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
|
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
Loading…
Reference in New Issue
Block a user