mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-07 06:02:53 +08:00
image embedding data cache (#16556)
This commit is contained in:
parent
d88a3c15f7
commit
deb3803a3a
@ -291,6 +291,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||
"textual_inversion_image_embedding_data_cache": OptionInfo(False, 'Cache the data of image embeddings').info('potentially increase TI load time at the cost some disk space'),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {
|
||||
|
@ -12,7 +12,7 @@ import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@ -116,6 +116,7 @@ class EmbeddingDatabase:
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
self.previously_displayed_embeddings = ()
|
||||
self.image_embedding_cache = cache.cache('image-embedding')
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
@ -154,6 +155,31 @@ class EmbeddingDatabase:
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
def read_embedding_from_image(self, path, name):
|
||||
try:
|
||||
ondisk_mtime = os.path.getmtime(path)
|
||||
|
||||
if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
|
||||
# cache will only be used if the file has not been modified time matches
|
||||
return cache_embedding.get('data', None), cache_embedding.get('name', None)
|
||||
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
elif data := extract_image_data_embed(embed_image):
|
||||
name = data.get('name', name)
|
||||
|
||||
if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
|
||||
# data of image embeddings only will be cached if the option textual_inversion_image_embedding_data_cache is enabled
|
||||
# results of images that are not embeddings will allways be cached to reduce unnecessary future disk reads
|
||||
self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}
|
||||
|
||||
return data, name
|
||||
except Exception:
|
||||
errors.report(f"Error loading embedding {path}", exc_info=True)
|
||||
return None, None
|
||||
|
||||
def load_from_file(self, path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
@ -163,17 +189,10 @@ class EmbeddingDatabase:
|
||||
if second_ext.upper() == '.PREVIEW':
|
||||
return
|
||||
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
if data:
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
# if data is None, means this is not an embedding, just a preview image
|
||||
return
|
||||
data, name = self.read_embedding_from_image(path, name)
|
||||
if data is None:
|
||||
return
|
||||
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
elif ext in ['.SAFETENSORS']:
|
||||
@ -191,7 +210,6 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
|
||||
|
||||
|
||||
def load_from_dir(self, embdir):
|
||||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user