stable-diffusion-webui/modules/ui_tempdir.py
AUTOMATIC1111 b63dda3f45 linter
2024-03-02 08:43:46 +03:00

195 lines
7.0 KiB
Python

import os
import tempfile
from collections import namedtuple
from pathlib import Path
import gradio.components
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
# Single-field record type exposing the saved file path as ``.name``.
Savedfile = namedtuple("Savedfile", "name")
def register_tmp_file(gradio_app, filename):
    """
    Record ``filename`` in the gradio app's temp-file bookkeeping.

    Two bookkeeping schemes are supported, depending on which attribute the app has:

    - ``temp_file_sets`` (gradio 3.15+): the normalized absolute path is added to the
      first set. ``gr.utils.abspath`` is used when available (gradio 4.19+), otherwise
      ``os.path.abspath``.
    - ``temp_dirs`` (gradio 3.9): the absolute path of the file's parent directory is added.

    Both schemes are applied if the app has both attributes. The sets are rebound to new
    set objects rather than mutated in place.
    """
    if hasattr(gradio_app, 'temp_file_sets'):  # gradio 3.15
        # gr.utils.abspath only exists from gradio 4.19 on
        to_absolute = gr.utils.abspath if hasattr(gr.utils, 'abspath') else os.path.abspath
        filename = to_absolute(filename)
        gradio_app.temp_file_sets[0] = gradio_app.temp_file_sets[0] | {filename}

    if hasattr(gradio_app, 'temp_dirs'):  # gradio 3.9
        parent_dir = os.path.abspath(os.path.dirname(filename))
        gradio_app.temp_dirs = gradio_app.temp_dirs | {parent_dir}
def check_tmp_file(gradio_app, filename):
    """
    Return whether ``filename`` is covered by the gradio app's temp-file bookkeeping.

    - With ``temp_file_sets``: True if the normalized absolute path is in any of the sets.
    - Otherwise, with ``temp_dirs``: True if any registered directory (resolved) is a
      proper ancestor of the resolved ``filename``.
    - With neither attribute: False.
    """
    if hasattr(gradio_app, 'temp_file_sets'):
        if hasattr(gr.utils, 'abspath'):  # gradio 4.19
            normalized = gr.utils.abspath(filename)
        else:
            normalized = os.path.abspath(filename)
        return any(normalized in registered for registered in gradio_app.temp_file_sets)

    if not hasattr(gradio_app, 'temp_dirs'):
        return False

    return any(
        Path(registered_dir).resolve() in Path(filename).resolve().parents
        for registered_dir in gradio_app.temp_dirs
    )
def save_pil_to_file(pil_image, cache_dir=None, format="png"):
    """
    Write ``pil_image`` to a new temporary PNG file, keeping its text metadata.

    This function replaces gradio's ``save_pil_to_cache`` (see ``install_ui_tempdir_override``).

    If the image has an ``already_saved_as`` attribute that points to an existing file,
    nothing is written. That path is registered with the gradio app twice: once as-is and
    once with a ``?<mtime>`` suffix. The suffixed form is returned.

    Args:
        pil_image: PIL image to save.
        cache_dir: Directory to write into when ``shared.opts.temp_dir`` is empty. If both
            are empty, the system default temp directory is used.
        format: Accepted for signature compatibility with gradio and otherwise ignored. The
            image is always written as PNG so that its text metadata can be stored.

    Returns:
        Path of the newly written file, or ``"<already_saved_as>?<mtime>"``.
    """
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        register_tmp_file(shared.demo, already_saved_as)
        # The mtime suffix changes whenever the file is rewritten (presumably for cache
        # busting). move_files_to_cache strips everything after the last '?' again.
        filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
        register_tmp_file(shared.demo, filename_with_mtime)
        return filename_with_mtime

    # The user-configured temp dir takes precedence over gradio's cache dir.
    target_dir = shared.opts.temp_dir or cache_dir
    if target_dir:
        # Ensure the directory exists whichever source it came from.
        os.makedirs(target_dir, exist_ok=True)

    # Only str -> str entries of pil_image.info are carried over as PNG text chunks.
    metadata = PngImagePlugin.PngInfo()
    use_metadata = False
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    # delete=False: the file must outlive this handle, because its path is handed back to
    # gradio. The with-block closes the handle deterministically instead of leaking it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=target_dir) as file_obj:
        pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))

    return file_obj.name
def move_files_to_cache(data, block, postprocess=False, add_urls=False, check_in_upload_folder=False):
    """Move any files in `data` to cache and (optionally) add the URL prefixes (/file=...) needed to access the cached files.

    Also handles the case where the file is on an external Gradio app (/proxy=...).

    Runs after postprocess and before preprocess.

    Copied from gradio's processing_utils.py. Webui-specific changes are marked "EDITED":
    - everything after the last '?' is stripped from each file path;
    - files for which check_tmp_file(shared.demo, path) is true are used in place instead
      of being moved into the block cache.

    Args:
        data: The input or output data for a component. Can be a dictionary or a dataclass.
        block: The component whose data is being processed.
        postprocess: Whether it's running from postprocessing.
        add_urls: Whether to add URLs to the payload.
        check_in_upload_folder: If True, local (non-URL) paths that are not already registered
            temp files must lie inside gradio's upload folder, otherwise ValueError is raised.
            Accepted files are then moved to the cache as usual.

    Returns:
        The result of client_utils.traverse over `data`, where each file object has been
        passed through _move_to_cache and replaced by a FileData.model_dump() dict.

    Raises:
        ValueError: if a file is outside the upload folder (with check_in_upload_folder),
            or if no cache path could be determined for a resource.
    """
    from gradio import FileData
    from gradio.data_classes import GradioRootModel
    from gradio.data_classes import GradioModel
    from gradio_client import utils as client_utils
    from gradio.utils import get_upload_folder, is_in_or_equal

    def _move_to_cache(d: dict):
        # Process one file object: normalize its path, cache it if needed, optionally attach a URL.
        payload = FileData(**d)

        # EDITED: drop the last '?' and everything after it, e.g. the "?<mtime>" suffix that
        # save_pil_to_file appends to already-saved images.
        # NOTE(review): this also truncates http(s) URLs that carry a query string, and local
        # filenames that legitimately contain '?'. Confirm this is acceptable.
        payload.path = payload.path.rsplit('?', 1)[0]

        # If the gradio app developer is returning a URL from
        # postprocess, it means the component can display a URL
        # without it being served from the gradio server
        # This makes it so that the URL is not downloaded and speeds up event processing
        if payload.url and postprocess:
            payload.path = payload.url
        elif not block.proxy_url:
            # EDITED: files already registered with the gradio app are served from where they
            # are, instead of being copied into the block's cache.
            if check_tmp_file(shared.demo, payload.path):
                temp_file_path = payload.path
            else:
                # If the file is on a remote server, do not move it to cache.
                if check_in_upload_folder and not client_utils.is_http_url_like(
                    payload.path
                ):
                    path = os.path.abspath(payload.path)
                    if not is_in_or_equal(path, get_upload_folder()):
                        raise ValueError(
                            f"File {path} is not in the upload folder and cannot be accessed."
                        )
                temp_file_path = block.move_resource_to_block_cache(payload.path)

            if temp_file_path is None:
                raise ValueError("Did not determine a file path for the resource.")
            payload.path = temp_file_path

        if add_urls:
            url_prefix = "/stream/" if payload.is_stream else "/file="
            if block.proxy_url:
                # File belongs to an external gradio app: route it through /proxy=.
                proxy_url = block.proxy_url.rstrip("/")
                url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
            elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
                f"{url_prefix}"
            ):
                # Already a remote URL or already prefixed: use it unchanged.
                url = payload.path
            else:
                url = f"{url_prefix}{payload.path}"
            payload.url = url

        return payload.model_dump()

    # Gradio model instances are dumped to plain data before traversal.
    if isinstance(data, (GradioRootModel, GradioModel)):
        data = data.model_dump()

    # Apply _move_to_cache to every element that client_utils.is_file_obj accepts.
    return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
def install_ui_tempdir_override():
    """
    Monkey-patch gradio.processing_utils with this module's file handlers.

    - save_pil_to_cache is replaced by save_pil_to_file, which also writes PNG text metadata.
    - move_files_to_cache is replaced by this module's move_files_to_cache, which keeps
      already-registered temp files in place instead of copying them into gradio's cache.
    """
    processing_utils = gradio.processing_utils
    processing_utils.save_pil_to_cache = save_pil_to_file
    processing_utils.move_files_to_cache = move_files_to_cache
def on_tmpdir_changed():
    """
    Create the configured temp directory and register it with the gradio app.

    Does nothing when ``shared.opts.temp_dir`` is the empty string or ``shared.demo`` is None.
    Presumably wired as the change handler of the temp_dir setting; confirm with the caller.
    """
    temp_dir = shared.opts.temp_dir
    if temp_dir != "" and shared.demo is not None:
        os.makedirs(temp_dir, exist_ok=True)
        # A placeholder path inside the directory is registered. With the temp_dirs scheme,
        # register_tmp_file adds the placeholder's parent, i.e. temp_dir itself.
        register_tmp_file(shared.demo, os.path.join(temp_dir, "x"))
def cleanup_tmpdr():
    """
    Delete every ``.png`` file found anywhere under the configured temp directory.

    Does nothing when ``shared.opts.temp_dir`` is the empty string or is not an existing
    directory. The extension match is case-sensitive. Files with other extensions, and the
    directories themselves, are left in place.
    """
    temp_dir = shared.opts.temp_dir
    if temp_dir != "" and os.path.isdir(temp_dir):
        for dirpath, _subdirs, filenames in os.walk(temp_dir, topdown=False):
            png_paths = [
                os.path.join(dirpath, fname)
                for fname in filenames
                if os.path.splitext(fname)[1] == ".png"
            ]
            for png_path in png_paths:
                os.remove(png_path)
def is_gradio_temp_path(path):
    """
    Check if the path is inside a temp dir used by gradio.

    A path qualifies when it lies inside any of these directories:
    - the webui's configured temp dir (``shared.opts.temp_dir``), if set;
    - ``$GRADIO_TEMP_DIR``, if that environment variable is set;
    - ``<system temp dir>/gradio``.

    Both the path and each candidate directory are resolved before comparing, so ``..``
    segments cannot escape a directory, and relative and absolute spellings of the same
    location compare equal.
    """
    resolved = Path(path).resolve()

    def _inside(directory):
        # True if the resolved path equals or lies below the resolved directory.
        return resolved.is_relative_to(Path(directory).resolve())

    if shared.opts.temp_dir and _inside(shared.opts.temp_dir):
        return True

    if (gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR")) and _inside(gradio_temp_dir):
        return True

    return _inside(Path(tempfile.gettempdir()) / "gradio")