image embedding data cache (#16556)

This commit is contained in:
w-e-w 2024-10-29 23:28:21 +09:00 committed by GitHub
parent d88a3c15f7
commit deb3803a3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 13 deletions

View File

@ -291,6 +291,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
"textual_inversion_image_embedding_data_cache": OptionInfo(False, 'Cache the data of image embeddings').info('potentially increase TI load time at the cost some disk space'),
}))
options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {

View File

@ -12,7 +12,7 @@ import safetensors.torch
import numpy as np
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@ -116,6 +116,7 @@ class EmbeddingDatabase:
self.expected_shape = -1
self.embedding_dirs = {}
self.previously_displayed_embeddings = ()
self.image_embedding_cache = cache.cache('image-embedding')
def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
@ -154,6 +155,31 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def read_embedding_from_image(self, path, name):
try:
ondisk_mtime = os.path.getmtime(path)
if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
# cache will only be used if the file has not been modified time matches
return cache_embedding.get('data', None), cache_embedding.get('name', None)
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name)
elif data := extract_image_data_embed(embed_image):
name = data.get('name', name)
if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
# data of image embeddings only will be cached if the option textual_inversion_image_embedding_data_cache is enabled
# results of images that are not embeddings will allways be cached to reduce unnecessary future disk reads
self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}
return data, name
except Exception:
errors.report(f"Error loading embedding {path}", exc_info=True)
return None, None
def load_from_file(self, path, filename):
name, ext = os.path.splitext(filename)
ext = ext.upper()
@ -163,17 +189,10 @@ class EmbeddingDatabase:
if second_ext.upper() == '.PREVIEW':
return
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
if data:
name = data.get('name', name)
else:
# if data is None, means this is not an embedding, just a preview image
return
data, name = self.read_embedding_from_image(path, name)
if data is None:
return
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
@ -191,7 +210,6 @@ class EmbeddingDatabase:
else:
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
def load_from_dir(self, embdir):
if not os.path.isdir(embdir.path):
return