diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 3e9f53f59..f577fc000 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -60,9 +60,57 @@ def save_pil_to_file(pil_image, cache_dir=None, format="png"): return file_obj.name +def move_files_to_cache(data, block, postprocess=False): + """Move files to cache and replace the file path with the cache path. + + Runs after postprocess and before preprocess. + + Copied from gradio's processing_utils.py + + Args: + data: The input or output data for a component. Can be a dictionary or a dataclass + block: The component + postprocess: Whether its running from postprocessing + """ + + from gradio import FileData + from gradio.processing_utils import move_resource_to_block_cache + from gradio.data_classes import GradioRootModel + from gradio.data_classes import GradioModel + from gradio_client import utils as client_utils + + def _move_to_cache(d: dict): + payload = FileData(**d) + # If the gradio app developer is returning a URL from + # postprocess, it means the component can display a URL + # without it being served from the gradio server + # This makes it so that the URL is not downloaded and speeds up event processing + if payload.url and postprocess: + temp_file_path = payload.url + else: + + if check_tmp_file(shared.demo, payload.path): + temp_file_path = payload.path + else: + temp_file_path = move_resource_to_block_cache(payload.path, block) + assert temp_file_path is not None + payload.path = temp_file_path + return payload.model_dump() + + if isinstance(data, (GradioRootModel, GradioModel)): + data = data.model_dump() + + return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj) + + def install_ui_tempdir_override(): - """override save to file function so that it also writes PNG info""" + """ + override save to file function so that it also writes PNG info. + override gradio4's move_files_to_cache function to prevent it from writing a copy into a temporary directory. + """ + gradio.processing_utils.save_pil_to_cache = save_pil_to_file + gradio.processing_utils.move_files_to_cache = move_files_to_cache def on_tmpdir_changed():