stable-diffusion-webui/modules/ui_tempdir.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

200 lines
7.4 KiB
Python
Raw Normal View History

import os
import tempfile
from collections import namedtuple
from pathlib import Path
import gradio.components
2024-03-02 13:25:10 +08:00
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
# Lightweight record holding just the ``name`` (path) of a saved file.
Savedfile = namedtuple("Savedfile", "name")


def register_tmp_file(gradio_app, filename):
    """Register *filename* with *gradio_app* so gradio is allowed to serve it.

    Depending on which internal attribute the app exposes:

    - ``temp_file_sets`` (gradio 3.15): the absolute path of the file is added
      to the first set.
    - ``temp_dirs`` (gradio 3.9): the absolute parent directory of the file is
      added.

    Both are updated when both exist. Each set is rebuilt with ``|`` rather
    than mutated in place.
    """
    if hasattr(gradio_app, 'temp_file_sets'):  # gradio 3.15
        # gradio 4.19 provides gr.utils.abspath; prefer it when present,
        # presumably so the stored entry matches gradio's own normalization.
        to_absolute = gr.utils.abspath if hasattr(gr.utils, 'abspath') else os.path.abspath
        filename = to_absolute(filename)
        first_set = gradio_app.temp_file_sets[0]
        gradio_app.temp_file_sets[0] = first_set | {filename}

    if hasattr(gradio_app, 'temp_dirs'):  # gradio 3.9
        parent_dir = os.path.abspath(os.path.dirname(filename))
        gradio_app.temp_dirs = gradio_app.temp_dirs | {parent_dir}


def check_tmp_file(gradio_app, filename):
    """Return whether *filename* was registered with *gradio_app* as servable.

    If the app has ``temp_file_sets``, only those sets are consulted (exact
    match on the absolute path). Otherwise, if it has ``temp_dirs``, the file
    counts when one of those directories is an ancestor of it, after
    resolving both. Apps with neither attribute never match.
    """
    if hasattr(gradio_app, 'temp_file_sets'):
        # gradio 4.19 provides gr.utils.abspath; fall back to os.path.abspath.
        to_absolute = gr.utils.abspath if hasattr(gr.utils, 'abspath') else os.path.abspath
        candidate = to_absolute(filename)
        for registered in gradio_app.temp_file_sets:
            if candidate in registered:
                return True
        return False

    if not hasattr(gradio_app, 'temp_dirs'):
        return False

    return any(
        Path(allowed_dir).resolve() in Path(filename).resolve().parents
        for allowed_dir in gradio_app.temp_dirs
    )


def save_pil_to_file(pil_image, cache_dir=None, format="png"):
    """Save *pil_image* as a PNG temp file, keeping its string metadata, and return the path.

    ``install_ui_tempdir_override`` installs this in place of gradio's
    ``processing_utils.save_pil_to_cache``.

    If the image has an ``already_saved_as`` attribute naming an existing file,
    nothing is written. That file is registered with ``shared.demo``, and
    ``"<path>?<mtime>"`` is registered and returned. The suffix is presumably a
    cache-buster so browsers re-fetch the file after it is overwritten.

    Args:
        pil_image: PIL image to save.
        cache_dir: Directory to write into when ``shared.opts.temp_dir`` is
            empty. If neither is set, the system temp directory is used.
        format: Accepted for compatibility with gradio's signature and ignored.
            The image is always written as PNG so the text metadata survives.

    Returns:
        Path of the written file, or the ``?<mtime>``-suffixed path of the
        already-saved file.
    """
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        register_tmp_file(shared.demo, already_saved_as)
        filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
        register_tmp_file(shared.demo, filename_with_mtime)
        return filename_with_mtime

    # An empty string is normalized to None so NamedTemporaryFile uses the
    # system temp dir rather than the current working directory.
    target_dir = shared.opts.temp_dir or cache_dir or None
    if target_dir is not None:
        os.makedirs(target_dir, exist_ok=True)

    use_metadata = False
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        # Only string key/value pairs are written as PNG text chunks.
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    # delete=False keeps the file after the handle is closed. The `with`
    # block releases the OS handle, which was previously never closed.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=target_dir) as file_obj:
        pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
    return file_obj.name


async def async_move_files_to_cache(data, block, postprocess=False, check_in_upload_folder=False, keep_in_cache=False):
    """Move any files in `data` to cache and (optionally) add the URL prefixes (/file=...) needed to access the cached file.
    Also handles the case where the file is on an external Gradio app (/proxy=...).

    Runs after .postprocess() and before .preprocess().

    Copied from gradio's processing_utils.py and edited (see "EDITED" comments).
    Files that ``check_tmp_file`` reports as already registered for
    ``shared.demo`` keep their path instead of being copied into gradio's cache.

    Args:
        data: The input or output data for a component. Can be a dictionary or a dataclass
        block: The component whose data is being processed
        postprocess: Whether it is running from postprocessing
        check_in_upload_folder: If True, a local (non-http) file that is not a
            registered temp file must lie inside gradio's upload folder, or
            ValueError is raised. The file is still moved to cache afterwards.
        keep_in_cache: If True, the resulting path is added to
            ``block.keep_in_cache`` (per gradio: not deleted from cache when the
            server is shut down).

    Returns:
        The result of ``client_utils.async_traverse`` over `data`. Every file
        object is replaced by its updated ``FileData.model_dump()``.

    Raises:
        ValueError: if a file fails the upload-folder check, or if gradio
            returns no cache path for a resource.
    """
    from gradio import FileData
    from gradio.data_classes import GradioRootModel
    from gradio.data_classes import GradioModel
    from gradio_client import utils as client_utils
    from gradio.utils import get_upload_folder, is_in_or_equal, is_static_file

    async def _move_to_cache(d: dict):
        # Normalize one serialized FileData dict and return its updated dump.
        payload = FileData(**d)

        # EDITED: save_pil_to_file hands out "<path>?<mtime>". Drop that suffix
        # to get the real file path. URLs are left untouched because their
        # query string is part of the address.
        if not client_utils.is_http_url_like(payload.path):
            payload.path = payload.path.rsplit('?', 1)[0]

        # If the gradio app developer is returning a URL from
        # postprocess, it means the component can display a URL
        # without it being served from the gradio server
        # This makes it so that the URL is not downloaded and speeds up event processing
        if payload.url and postprocess and client_utils.is_http_url_like(payload.url):
            payload.path = payload.url
        elif is_static_file(payload):
            pass
        elif not block.proxy_url:
            # EDITED: registered webui temp files are served where they are.
            # Start from the current path so stream payloads, which are never
            # moved, still have a defined value below.
            temp_file_path = payload.path
            if not check_tmp_file(shared.demo, payload.path):
                # Remote (http) files are exempt from the upload-folder check.
                if check_in_upload_folder and not client_utils.is_http_url_like(payload.path):
                    path = os.path.abspath(payload.path)
                    if not is_in_or_equal(path, get_upload_folder()):
                        raise ValueError(
                            f"File {path} is not in the upload folder and cannot be accessed."
                        )
                if not payload.is_stream:
                    temp_file_path = await block.async_move_resource_to_block_cache(
                        payload.path
                    )
            if temp_file_path is None:
                raise ValueError("Did not determine a file path for the resource.")
            payload.path = temp_file_path
            if keep_in_cache:
                block.keep_in_cache.add(payload.path)

        url_prefix = "/stream/" if payload.is_stream else "/file="
        if block.proxy_url:
            proxy_url = block.proxy_url.rstrip("/")
            url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
        elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
            f"{url_prefix}"
        ):
            url = payload.path
        else:
            url = f"{url_prefix}{payload.path}"
        payload.url = url

        return payload.model_dump()

    if isinstance(data, (GradioRootModel, GradioModel)):
        data = data.model_dump()

    return await client_utils.async_traverse(
        data, _move_to_cache, client_utils.is_file_obj
    )


def install_ui_tempdir_override():
    """Monkey-patch gradio's ``processing_utils`` with this module's replacements.

    - ``save_pil_to_cache`` becomes ``save_pil_to_file``, which writes PNG
      text info into saved images.
    - ``async_move_files_to_cache`` becomes this module's version, which does
      not copy already-registered webui temp files into gradio's temp directory.
    """
    processing_utils = gradio.processing_utils
    processing_utils.save_pil_to_cache = save_pil_to_file
    processing_utils.async_move_files_to_cache = async_move_files_to_cache


def on_tmpdir_changed():
    """Create the configured temp dir and register it with the running gradio app.

    Does nothing when ``shared.opts.temp_dir`` is empty or no app exists yet.
    The dummy file name "x" exists only so that the path's dirname equals the
    temp dir. ``register_tmp_file`` adds that dirname to ``temp_dirs`` on
    gradio versions that have it.
    """
    configured_dir = shared.opts.temp_dir
    if configured_dir != "" and shared.demo is not None:
        os.makedirs(configured_dir, exist_ok=True)
        placeholder = os.path.join(configured_dir, "x")
        register_tmp_file(shared.demo, placeholder)


def cleanup_tmpdr():
    """Delete every file with a ``.png`` extension under ``shared.opts.temp_dir``.

    Subdirectories are walked bottom-up. Other files and the directories
    themselves are left in place. Nothing happens when the option is empty or
    does not name an existing directory.
    """
    configured_dir = shared.opts.temp_dir
    if not (configured_dir != "" and os.path.isdir(configured_dir)):
        return

    for dirpath, _subdirs, filenames in os.walk(configured_dir, topdown=False):
        png_names = (fn for fn in filenames if os.path.splitext(fn)[1] == ".png")
        for png_name in png_names:
            os.remove(os.path.join(dirpath, png_name))


def is_gradio_temp_path(path):
    """
    Check if the path is a temp dir used by gradio

    Returns True when *path* lies inside (or equals) one of:
    ``shared.opts.temp_dir`` (if set), ``$GRADIO_TEMP_DIR`` (if set), or
    ``<system temp dir>/gradio``.

    Both *path* and each candidate directory are resolved before comparing.
    This stops ``..`` segments from escaping a directory, lets relative
    settings match absolute paths, and handles symlinked temp roots. It is
    consistent with ``check_tmp_file``, which also resolves.
    """
    resolved = Path(path).resolve()

    def _is_inside(base):
        # True when `resolved` equals or is nested under the resolved `base`.
        return resolved.is_relative_to(Path(base).resolve())

    if shared.opts.temp_dir and _is_inside(shared.opts.temp_dir):
        return True
    if (gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR")) and _is_inside(gradio_temp_dir):
        return True
    if _is_inside(Path(tempfile.gettempdir()) / "gradio"):
        return True
    return False