mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-18 04:10:11 +08:00
bddebe09ed
Also add check to load only .PT and .BIN files as embeddings. (since we add .optim files in the same directory)
515 lines
23 KiB
Python
515 lines
23 KiB
Python
import os
|
|
import sys
|
|
import traceback
|
|
|
|
import torch
|
|
import tqdm
|
|
import html
|
|
import datetime
|
|
import csv
|
|
|
|
from PIL import Image, PngImagePlugin
|
|
|
|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
|
|
import modules.textual_inversion.dataset
|
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|
|
|
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
|
|
insert_image_data_embed, extract_image_data_embed,
|
|
caption_image_overlay)
|
|
|
|
class Embedding:
    """A textual inversion embedding tensor together with its metadata.

    Attributes:
        vec: the embedding tensor.
        name: name of the embedding.
        step: number of training steps already performed, or None.
        shape, vectors: filled in by the loader from the tensor's shape.
        cached_checksum: memoized result of checksum(), or None.
        sd_checkpoint, sd_checkpoint_name: hash / name of the checkpoint recorded at save time.
        optimizer_state_dict: optimizer state to store next to the embedding, or None.
    """

    def __init__(self, vec, name, step=None):
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None
        self.vectors = 0
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None

    def save(self, filename):
        """Write the embedding to `filename` with torch.save.

        When the `save_optimizer_state` option is enabled and an optimizer state
        is attached, it is additionally written to `filename + '.optim'` together
        with the embedding checksum, so a loader can tell whether the two files
        belong together.
        """
        embedding_data = {
            # fixed placeholder mapping kept exactly as the file format expects it
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, filename + '.optim')

    def checksum(self):
        """Return a 4-hex-digit checksum of the embedding values (memoized).

        The result is cached in `cached_checksum`; reset that attribute to None
        to force a recomputation after the tensor changed.
        """
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        # Convert the whole tensor to Python numbers in one go; iterating the
        # tensor itself would synchronize with the device once per element.
        scaled_values = (self.vec.detach().reshape(-1) * 100).tolist()

        self.cached_checksum = f'{const_hash(scaled_values) & 0xffff:04x}'
        return self.cached_checksum
|
|
|
|
|
|
class EmbeddingDatabase:
    """Registry of textual inversion embeddings loaded from one directory.

    Attributes:
        ids_lookup: maps the first token id of an embedding name to a list of
            (token ids, embedding) pairs, longest token sequence first.
        word_embeddings: maps embedding name to the registered embedding.
        skipped_embeddings: embeddings whose vector width did not match the
            loaded model, keyed by name.
        dir_mtime: modification time of the directory at the last (re)load.
        embeddings_dir: directory scanned for embedding files.
        expected_shape: vector width expected by the model, or -1 for "any".
    """

    def __init__(self, embeddings_dir):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir
        self.expected_shape = -1

    def register_embedding(self, embedding, model):
        """Register `embedding` under its name and index it by its token ids.

        Returns the embedding that was passed in.
        """
        self.word_embeddings[embedding.name] = embedding

        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
        ids = model.cond_stage_model.tokenize([embedding.name])[0]

        first_id = ids[0]
        candidates = self.ids_lookup.get(first_id, [])

        # longest token sequence first, so the most specific name wins in find_embedding_at_position
        self.ids_lookup[first_id] = sorted(candidates + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def get_expected_shape(self):
        """Return the second dimension of the model's encoding of a one-vector init text."""
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_textual_inversion_embeddings(self, force_reload = False):
        """(Re)load every embedding file from `embeddings_dir`.

        Nothing happens when the directory's mtime has not advanced since the
        last load, unless `force_reload` is true. Files that fail to load are
        reported on stderr and skipped.
        """
        mt = os.path.getmtime(self.embeddings_dir)
        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
            name, ext = os.path.splitext(filename)
            ext = ext.upper()

            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                # context manager so the image file handle is released after decoding
                with Image.open(path) as embed_image:
                    if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                        data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                    else:
                        data = extract_image_data_embed(embed_image)
                name = data.get('name', name)
            elif ext in ['.BIN', '.PT']:
                # NOTE(review): torch.load unpickles the file; confirm a restricted
                # unpickler is in effect before loading embeddings from untrusted sources.
                data = torch.load(path, map_location="cpu")
            else:
                # any other file (e.g. the '.optim' files written next to embeddings) is ignored
                return

            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                if hasattr(param_dict, '_parameters'):
                    param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                emb = next(iter(param_dict.items()))[1]
            # diffuser concepts
            elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):
                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)
            else:
                raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            embedding.vectors = vec.shape[0]
            embedding.shape = vec.shape[-1]

            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings[name] = embedding

        for fn in os.listdir(self.embeddings_dir):
            try:
                fullfn = os.path.join(self.embeddings_dir, fn)

                if os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                # best effort: one broken file must not prevent the others from loading
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        """Find a registered embedding whose token ids start at `tokens[offset]`.

        Returns (embedding, number of tokens matched), or (None, None) when no
        registered embedding matches at that position.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
|
|
|
|
|
|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new embedding file initialised from `init_text` and return its path.

    The encoded init text is spread over `num_vectors_per_token` rows, the name
    is reduced to alphanumerics plus "._- ", and the result is saved with step 0
    as `<embeddings_dir>/<name>.pt`. Unless `overwrite_old` is true, an already
    existing file triggers an AssertionError.
    """
    text_encoder = shared.sd_model.cond_stage_model

    with devices.autocast():
        text_encoder([""])  # will send cond model to GPU if lowvram/medvram is active

    init_vectors = text_encoder.encode_embedding_init_text(init_text, num_vectors_per_token)
    source_rows = int(init_vectors.shape[0])
    vec = torch.zeros((num_vectors_per_token, init_vectors.shape[1]), device=devices.device)

    # pick source rows at evenly spaced positions across the encoded init text
    for row in range(num_vectors_per_token):
        vec[row] = init_vectors[row * source_rows // num_vectors_per_token]

    # Remove illegal characters from name.
    allowed_punctuation = "._- "
    name = "".join(filter(lambda ch: ch.isalnum() or ch in allowed_punctuation, name))

    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    Embedding(vec, name, step=0).save(fn)

    return fn
|
|
|
|
|
|
def write_loss(log_directory, filename, step, epoch_len, values):
    """Append one row of training statistics to a CSV file.

    Nothing is written when the `training_write_csv_every` option is 0 or when
    `step` is not a multiple of it. A header row is emitted only when the file
    does not exist yet.

    Args:
        log_directory: directory holding the CSV file; created if missing.
        filename: name of the CSV file inside `log_directory`.
        step: 1-based global training step.
        epoch_len: number of steps per epoch, used to derive epoch / epoch_step.
        values: mapping of additional column names to values.
    """
    write_every = shared.opts.training_write_csv_every
    if write_every == 0:
        return

    if step % write_every != 0:
        return

    # The log directory is not guaranteed to exist (the caller only creates
    # sub-directories when periodic saving or previews are enabled), and
    # open(..., "a+") does not create missing parent directories.
    if log_directory:
        os.makedirs(log_directory, exist_ok=True)

    csv_path = os.path.join(log_directory, filename)
    write_csv_header = not os.path.exists(csv_path)

    with open(csv_path, "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        epoch = (step - 1) // epoch_len
        epoch_step = (step - 1) % epoch_len

        csv_writer.writerow({
            "step": step,
            "epoch": epoch,
            "epoch_step": epoch_step,
            **values,
        })
|
|
|
|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Check the user supplied training settings before training starts.

    Raises AssertionError with a user-readable message for the first setting
    that is missing or invalid. `name` is the kind of model being trained and
    is only used to word the messages.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file, "Prompt template file is empty"
    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # these two messages interpolate `name`, so they have to be f-strings
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    # a log directory is only needed when something will be written periodically
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
|
|
|
|
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the already registered embedding `embedding_name`.

    The embedding is looked up in the hijack's embedding database, optimised
    with AdamW (with gradient accumulation over `gradient_step` batches) until
    `steps` optimizer steps are reached, the scheduler finishes or the job is
    interrupted, and finally saved to `<embeddings_dir>/<embedding_name>.pt`.

    Args:
        embedding_name: name of the embedding to train.
        learn_rate: learning rate specification handed to LearnRateScheduler.
        batch_size, gradient_step: requested batch size / accumulation steps;
            the values actually used are taken from the dataset object.
        data_root, template_file, training_width, training_height,
        shuffle_tags, tag_drop_out, latent_sampling_method: forwarded to the
            training dataset.
        log_directory: base directory; a `<date>/<embedding_name>` sub-directory
            of it receives intermediate embeddings, preview images and the loss CSV.
        steps: total number of optimizer steps to reach.
        create_image_every: create a preview image every N steps (0/None disables).
        save_embedding_every: save an intermediate embedding every N steps (0/None disables).
        save_image_with_stored_embedding: additionally write a PNG that carries
            the most recently saved intermediate embedding.
        preview_from_txt2img and the other preview_* arguments: when
            `preview_from_txt2img` is true the preview uses these settings,
            otherwise the first prompt of the current batch with 20 steps at
            the training resolution.

    Returns:
        Tuple (embedding, filename of the final embedding file).

    Raises:
        AssertionError: from validate_train_inputs for invalid settings.
        Errors raised inside the training loop are printed to stderr and
        swallowed; the function then returns normally.
    """
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = sd_hijack.model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            # only reuse the optimizer state when it was saved for exactly these embedding values
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    embedding_yet_to_be_embedded = False

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for _ in range((steps - initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, embedding.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    # the first stage model may have been moved to the CPU above; bring it back for the preview
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                        do_not_reload_embeddings=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.current_image = image
                        # saved exactly once per preview step
                        last_saved_image, _ = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                            last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                            info = PngImagePlugin.PngInfo()
                            data = torch.load(last_saved_file)
                            info.add_text("sd-ti-embedding", embedding_to_b64(data))

                            title = "<{}>".format(data.get('name', '???'))

                            try:
                                vector_size = list(data['string_to_param'].values())[0].shape[0]
                            except Exception:
                                # the caption is cosmetic; fall back to a placeholder
                                vector_size = '?'

                            checkpoint = sd_models.select_checkpoint()
                            footer_left = checkpoint.model_name
                            footer_mid = '[{}]'.format(checkpoint.hash)
                            footer_right = '{}v {}s'.format(vector_size, steps_done)

                            captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                            captioned_image = insert_image_data_embed(captioned_image, data)

                            captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                            embedding_yet_to_be_embedded = False

                shared.state.job_no = embedding.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        # best effort: report the failure but still return what has been trained so far
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed

    return embedding, filename
|
|
|
|
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Save `embedding` to `filename` under `embedding_name`.

    Before saving, the embedding is stamped with the checkpoint's hash and
    model name, the given name and the optimizer's current state dict; when
    `remove_cached_checksum` is true its cached checksum is cleared so it gets
    recomputed. On success these new values stay on the embedding object.

    If anything fails, every attribute touched here is put back to the value
    it had on entry and the error is re-raised.
    """
    old_embedding_name = embedding.name
    # getattr defaults cover embedding objects that lack the newer attributes
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    old_optimizer_state_dict = getattr(embedding, "optimizer_state_dict", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except BaseException:
        # roll back everything, including the optimizer state, then propagate
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        embedding.optimizer_state_dict = old_optimizer_state_dict
        raise
|