diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index d8c311423..59d6ecc61 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -1,3 +1,4 @@ +import functools import os.path import urllib.parse from pathlib import Path @@ -15,10 +16,16 @@ from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() -allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] -if shared.opts.samples_format not in allowed_preview_extensions: - allowed_preview_extensions.append(shared.opts.samples_format) -allowed_preview_extensions_dot = ['.' + extension for extension in allowed_preview_extensions] +default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] + + +@functools.cache +def allowed_preview_extensions_with_extra(extra_extensions=None): + return set(default_allowed_preview_extensions) | set(extra_extensions or []) + + +def allowed_preview_extensions(): + return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) def register_page(page): @@ -38,9 +45,9 @@ def fetch_file(filename: str = ""): if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs): raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") - ext = os.path.splitext(filename)[1].lower() - if ext not in allowed_preview_extensions_dot: - raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.") + ext = os.path.splitext(filename)[1].lower()[1:] + if ext not in allowed_preview_extensions(): + raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.") # would profit from returning 304 return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) @@ -278,7 +285,7 @@ class ExtraNetworksPage: Find a preview PNG for a given path (without extension) and call link_preview on it. """ - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions], []) + potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) for file in potential_files: if os.path.isfile(file):