mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
make it possible for extensions/scripts to add their own embedding directories
This commit is contained in:
parent
a0c87f1fdf
commit
085427de0e
@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
|
|||||||
clip = None
|
clip = None
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
|
|||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
|
@ -66,17 +66,41 @@ class Embedding:
|
|||||||
return self.cached_checksum
|
return self.cached_checksum
|
||||||
|
|
||||||
|
|
||||||
|
class DirWithTextualInversionEmbeddings:
|
||||||
|
def __init__(self, path):
|
||||||
|
self.path = path
|
||||||
|
self.mtime = None
|
||||||
|
|
||||||
|
def has_changed(self):
|
||||||
|
if not os.path.isdir(self.path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
mt = os.path.getmtime(self.path)
|
||||||
|
if self.mtime is None or mt > self.mtime:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if not os.path.isdir(self.path):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.mtime = os.path.getmtime(self.path)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingDatabase:
|
class EmbeddingDatabase:
|
||||||
def __init__(self, embeddings_dir):
|
def __init__(self):
|
||||||
self.ids_lookup = {}
|
self.ids_lookup = {}
|
||||||
self.word_embeddings = {}
|
self.word_embeddings = {}
|
||||||
self.skipped_embeddings = {}
|
self.skipped_embeddings = {}
|
||||||
self.dir_mtime = None
|
|
||||||
self.embeddings_dir = embeddings_dir
|
|
||||||
self.expected_shape = -1
|
self.expected_shape = -1
|
||||||
|
self.embedding_dirs = {}
|
||||||
|
|
||||||
|
def add_embedding_dir(self, path):
|
||||||
|
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||||
|
|
||||||
|
def clear_embedding_dirs(self):
|
||||||
|
self.embedding_dirs.clear()
|
||||||
|
|
||||||
def register_embedding(self, embedding, model):
|
def register_embedding(self, embedding, model):
|
||||||
|
|
||||||
self.word_embeddings[embedding.name] = embedding
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||||
@ -93,69 +117,62 @@ class EmbeddingDatabase:
|
|||||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||||
return vec.shape[1]
|
return vec.shape[1]
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
def load_from_file(self, path, filename):
|
||||||
mt = os.path.getmtime(self.embeddings_dir)
|
name, ext = os.path.splitext(filename)
|
||||||
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
ext = ext.upper()
|
||||||
return
|
|
||||||
|
|
||||||
self.dir_mtime = mt
|
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||||
self.ids_lookup.clear()
|
_, second_ext = os.path.splitext(name)
|
||||||
self.word_embeddings.clear()
|
if second_ext.upper() == '.PREVIEW':
|
||||||
self.skipped_embeddings.clear()
|
|
||||||
self.expected_shape = self.get_expected_shape()
|
|
||||||
|
|
||||||
def process_file(path, filename):
|
|
||||||
name, ext = os.path.splitext(filename)
|
|
||||||
ext = ext.upper()
|
|
||||||
|
|
||||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
|
||||||
_, second_ext = os.path.splitext(name)
|
|
||||||
if second_ext.upper() == '.PREVIEW':
|
|
||||||
return
|
|
||||||
|
|
||||||
embed_image = Image.open(path)
|
|
||||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
|
||||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
|
||||||
name = data.get('name', name)
|
|
||||||
else:
|
|
||||||
data = extract_image_data_embed(embed_image)
|
|
||||||
name = data.get('name', name)
|
|
||||||
elif ext in ['.BIN', '.PT']:
|
|
||||||
data = torch.load(path, map_location="cpu")
|
|
||||||
else:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# textual inversion embeddings
|
embed_image = Image.open(path)
|
||||||
if 'string_to_param' in data:
|
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||||
param_dict = data['string_to_param']
|
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||||
if hasattr(param_dict, '_parameters'):
|
name = data.get('name', name)
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
# diffuser concepts
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
data = extract_image_data_embed(embed_image)
|
||||||
|
name = data.get('name', name)
|
||||||
|
elif ext in ['.BIN', '.PT']:
|
||||||
|
data = torch.load(path, map_location="cpu")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
# textual inversion embeddings
|
||||||
embedding = Embedding(vec, name)
|
if 'string_to_param' in data:
|
||||||
embedding.step = data.get('step', None)
|
param_dict = data['string_to_param']
|
||||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
if hasattr(param_dict, '_parameters'):
|
||||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
embedding.vectors = vec.shape[0]
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
embedding.shape = vec.shape[-1]
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
# diffuser concepts
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
emb = next(iter(data.values()))
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
if len(emb.shape) == 1:
|
||||||
else:
|
emb = emb.unsqueeze(0)
|
||||||
self.skipped_embeddings[name] = embedding
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
for root, dirs, fns in os.walk(self.embeddings_dir):
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||||
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
|
embedding.vectors = vec.shape[0]
|
||||||
|
embedding.shape = vec.shape[-1]
|
||||||
|
|
||||||
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
|
else:
|
||||||
|
self.skipped_embeddings[name] = embedding
|
||||||
|
|
||||||
|
def load_from_dir(self, embdir):
|
||||||
|
if not os.path.isdir(embdir.path):
|
||||||
|
return
|
||||||
|
|
||||||
|
for root, dirs, fns in os.walk(embdir.path):
|
||||||
for fn in fns:
|
for fn in fns:
|
||||||
try:
|
try:
|
||||||
fullfn = os.path.join(root, fn)
|
fullfn = os.path.join(root, fn)
|
||||||
@ -163,12 +180,32 @@ class EmbeddingDatabase:
|
|||||||
if os.stat(fullfn).st_size == 0:
|
if os.stat(fullfn).st_size == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
process_file(fullfn, fn)
|
self.load_from_file(fullfn, fn)
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||||
|
if not force_reload:
|
||||||
|
need_reload = False
|
||||||
|
for path, embdir in self.embedding_dirs.items():
|
||||||
|
if embdir.has_changed():
|
||||||
|
need_reload = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not need_reload:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
self.skipped_embeddings.clear()
|
||||||
|
self.expected_shape = self.get_expected_shape()
|
||||||
|
|
||||||
|
for path, embdir in self.embedding_dirs.items():
|
||||||
|
self.load_from_dir(embdir)
|
||||||
|
embdir.update()
|
||||||
|
|
||||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||||
if len(self.skipped_embeddings) > 0:
|
if len(self.skipped_embeddings) > 0:
|
||||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||||
@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|||||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||||
assert steps, "Max steps is empty or 0"
|
assert steps, "Max steps is empty or 0"
|
||||||
assert isinstance(steps, int), "Max steps must be integer"
|
assert isinstance(steps, int), "Max steps must be integer"
|
||||||
assert steps > 0 , "Max steps must be positive"
|
assert steps > 0, "Max steps must be positive"
|
||||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
||||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
assert create_image_every >= 0, "Create image must be positive or 0"
|
||||||
if save_model_every or create_image_every:
|
if save_model_every or create_image_every:
|
||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
|
Loading…
Reference in New Issue
Block a user