From 13df5754650b60e09c8031810780986453ee6f08 Mon Sep 17 00:00:00 2001 From: Andray Date: Fri, 17 May 2024 02:37:35 +0400 Subject: [PATCH] add flag use mime for batches --- modules/cmd_args.py | 1 + modules/img2img.py | 2 +- modules/launch_utils.py | 8 +++----- modules/postprocessing.py | 2 +- modules/shared.py | 1 + modules/util.py | 36 +++++++++++++++++++++++++++++------- requirements.txt | 3 ++- requirements_versions.txt | 1 + 8 files changed, 39 insertions(+), 15 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index d71982b2c..1cd0caac6 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -126,3 +126,4 @@ parser.add_argument("--skip-load-model-at-start", action='store_true', help="if parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system") parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system') parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file") +parser.add_argument("--use-mime-file-filtering-for-batch-from-dir", action='store_true', help="allows passing images with no or with incorrect extension in batch from directory") diff --git a/modules/img2img.py b/modules/img2img.py index 24f869f5c..5aa73fc74 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -22,7 +22,7 @@ def process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, processing.fix_seed(p) if isinstance(input, str): - batch_images = list(shared.walk_files(input, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) + batch_images = list(shared.walk_image_files(input)) else: batch_images = [os.path.abspath(x.name) for x in input] diff --git a/modules/launch_utils.py b/modules/launch_utils.py index e22da4ec6..d8309efa0 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -300,16 +300,14 @@ def requirements_met(requirements_file): package = m.group(1).strip() version_required = (m.group(2) or "").strip() - if version_required == "": - continue - try: version_installed = importlib.metadata.version(package) except Exception: return False - if packaging.version.parse(version_required) != packaging.version.parse(version_installed): - return False + if version_required != "": + if packaging.version.parse(version_required) != packaging.version.parse(version_installed): + return False return True diff --git a/modules/postprocessing.py b/modules/postprocessing.py index a413d1027..b84b1b5f1 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -27,7 +27,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' assert input_dir, 'input directory not selected' - image_list = shared.listfiles(input_dir) + image_list = shared.walk_image_files(input_dir) for filename in image_list: yield filename, filename else: diff --git a/modules/shared.py b/modules/shared.py index 2a3787f99..62fbe5a74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,6 +82,7 @@ listfiles = util.listfiles html_path = util.html_path html = util.html walk_files = util.walk_files +walk_image_files = util.walk_image_files ldm_print = util.ldm_print reload_gradio_theme = shared_gradio_themes.reload_gradio_theme diff --git a/modules/util.py b/modules/util.py index 7911b0db7..60406783b 100644 --- a/modules/util.py +++ b/modules/util.py @@ -1,5 +1,6 @@ import os import re +import filetype from modules import shared from modules.paths_internal import script_path, cwd @@ -28,7 +29,7 @@ def html(filename): return "" -def walk_files(path, allowed_extensions=None): +def walk_files(path, allowed_extensions=None, allowed_mime=None): if not os.path.exists(path): return @@ -40,15 +41,36 @@ def walk_files(path, allowed_extensions=None): for root, _, files in items: for filename in sorted(files, key=natural_sort_key): - if allowed_extensions is not None: - _, ext = os.path.splitext(filename) - if ext.lower() not in allowed_extensions: - continue - + filepath = os.path.join(root, filename) if not shared.opts.list_hidden_files and ("/." in root or "\\." in root): continue - yield os.path.join(root, filename) + if allowed_extensions or allowed_mime: + file_allowed = False + + if allowed_extensions is not None: + _, ext = os.path.splitext(filename) + if ext.lower() in allowed_extensions: + file_allowed = True + + if allowed_mime is not None: + if filetype.guess(filepath).mime.startswith(allowed_mime): + file_allowed = True + else: + file_allowed = True + + if not file_allowed: + continue + + yield filepath + + +def walk_image_files(path): + if shared.cmd_opts.use_mime_file_filtering_for_batch_from_dir: + return walk_files(path, allowed_mime='image/') + else: + return walk_files(path, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")) + def ldm_print(*args, **kwargs): diff --git a/requirements.txt b/requirements.txt index 0d6bac600..56e6112c1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,4 +31,5 @@ torch torchdiffeq torchsde transformers==4.30.2 -pillow-avif-plugin==1.4.3 \ No newline at end of file +pillow-avif-plugin==1.4.3 +filetype diff --git a/requirements_versions.txt b/requirements_versions.txt index 0306ce94f..3560f3cfc 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -33,3 +33,4 @@ torchsde==0.2.6 transformers==4.30.2 httpx==0.24.1 pillow-avif-plugin==1.4.3 +filetype