diff --git a/modules/shared_options.py b/modules/shared_options.py
index efede7067..ebe59d4da 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -291,6 +291,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
     "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
     "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
     "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
+    "textual_inversion_image_embedding_data_cache": OptionInfo(False, 'Cache the data of image embeddings').info('potentially reduces TI load time at the cost of some disk space'),
 }))
 
 options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index dc7833e93..f209b8834 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -12,7 +12,7 @@ import safetensors.torch
 
 import numpy as np
 from PIL import Image, PngImagePlugin
 
-from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
@@ -116,6 +116,7 @@ class EmbeddingDatabase:
         self.expected_shape = -1
         self.embedding_dirs = {}
         self.previously_displayed_embeddings = ()
+        self.image_embedding_cache = cache.cache('image-embedding')
 
     def add_embedding_dir(self, path):
         self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
@@ -154,6 +155,31 @@ class EmbeddingDatabase:
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         return vec.shape[1]
 
+    def read_embedding_from_image(self, path, name):
+        try:
+            ondisk_mtime = os.path.getmtime(path)
+
+            if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
+                # the cached entry is only used if the file's modification time still matches
+                return cache_embedding.get('data', None), cache_embedding.get('name', None)
+
+            embed_image = Image.open(path)
+            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+                name = data.get('name', name)
+            elif data := extract_image_data_embed(embed_image):
+                name = data.get('name', name)
+
+            if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
+                # embedding data read from an image is only cached if the textual_inversion_image_embedding_data_cache option is enabled
+                # images that are not embeddings are always cached to avoid unnecessary future disk reads
+                self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}
+
+            return data, name
+        except Exception:
+            errors.report(f"Error loading embedding {path}", exc_info=True)
+            return None, None
+
     def load_from_file(self, path, filename):
         name, ext = os.path.splitext(filename)
         ext = ext.upper()
@@ -163,17 +189,10 @@ class EmbeddingDatabase:
             if second_ext.upper() == '.PREVIEW':
                 return
 
-            embed_image = Image.open(path)
-            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
-                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
-                name = data.get('name', name)
-            else:
-                data = extract_image_data_embed(embed_image)
-                if data:
-                    name = data.get('name', name)
-                else:
-                    # if data is None, means this is not an embedding, just a preview image
-                    return
+            data, name = self.read_embedding_from_image(path, name)
+            if data is None:
+                return
+
         elif ext in ['.BIN', '.PT']:
             data = torch.load(path, map_location="cpu")
         elif ext in ['.SAFETENSORS']:
             data = safetensors.torch.load_file(path, device="cpu")
         else:
             return
@@ -191,7 +210,6 @@
         else:
             print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
 
-
     def load_from_dir(self, embdir):
         if not os.path.isdir(embdir.path):
             return
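
For reviewers, a minimal self-contained sketch of the caching pattern the new `read_embedding_from_image` method uses: a path-keyed dict whose entries are only trusted while the stored mtime matches the file on disk. The helper names `read_with_cache` and `expensive_extract` below are illustrative stand-ins, not webui APIs; in the actual change the dict comes from `cache.cache('image-embedding')` and the toggle is the new `textual_inversion_image_embedding_data_cache` option.

```python
import os


def expensive_extract(path):
    """Hypothetical stand-in for the image parsing done by embedding_from_b64()
    / extract_image_data_embed(); returns None when the file carries no
    embedded data."""
    with open(path, 'rb') as file:
        payload = file.read()
    return payload or None


def read_with_cache(path, cache, cache_data=False):
    """mtime-validated lookup mirroring read_embedding_from_image().

    `cache` is any dict-like object; `cache_data` plays the role of the new
    textual_inversion_image_embedding_data_cache option.
    """
    mtime = os.path.getmtime(path)

    entry = cache.get(path)
    if entry and entry.get('mtime', 0) == mtime:
        # cache hit: the file has not been modified since the entry was written
        return entry.get('data')

    data = expensive_extract(path)

    # "not an embedding" results are always cached to skip future disk reads;
    # actual embedding data is cached only when the option is enabled
    if data is None or cache_data:
        cache[path] = {'data': data, 'mtime': mtime}

    return data
```

Keying by path and validating against mtime means an edited or replaced image is transparently re-read and its stale entry overwritten, so no explicit invalidation step is needed.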