Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
commit 4be7b620c2
extensions-builtin/Lora/lora_logger.py (new file, 33 additions)
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
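
For orientation, a minimal usage sketch of this new module (illustrative only, not part of the diff): other files in the Lora extension import the module-level logger, as the networks changes further down do with logger.warning.

from lora_logger import logger

logger.info("loading bundled embeddings")      # rendered green by ColoredFormatter
logger.warning("duplicate embedding skipped")  # rendered yellow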

@@ -93,6 +93,7 @@ class Network:  # LoraModule
         self.unet_multiplier = 1.0
         self.dyn_dim = None
         self.modules = {}
+        self.bundle_embeddings = {}
         self.mtime = None
 
         self.mentioned_name = None

@@ -16,6 +16,9 @@ import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
 
+from lora_logger import logger
+
 module_types = [
     network_lora.ModuleTypeLora(),

@@ -151,9 +154,19 @@ def load_network(name, network_on_disk):
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
     matched_networks = {}
+    bundle_embeddings = {}
 
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
+        if key_network_without_network_parts == "bundle_emb":
+            emb_name, vec_name = network_part.split(".", 1)
+            emb_dict = bundle_embeddings.get(emb_name, {})
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
+            bundle_embeddings[emb_name] = emb_dict
 
         key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
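
To make the new key convention concrete, here is an illustrative sketch (not part of the diff; the key names and tensor shapes are invented) of what a bundled LoRA state dict looks like and how the branch above splits its keys:

import torch

# Hypothetical contents of a bundled LoRA state dict:
sd = {
    "lora_unet_down_blocks_0_attentions_0.lora_up.weight": torch.zeros(4, 320),  # ordinary LoRA key
    "bundle_emb.my_style.string_to_param.*": torch.zeros(2, 768),                # bundled embedding
}

for key_network in sd:
    prefix, network_part = key_network.split(".", 1)
    if prefix == "bundle_emb":
        emb_name, vec_name = network_part.split(".", 1)
        print(emb_name, vec_name)  # -> my_style string_to_param.*

# load_network() groups such keys into:
#   bundle_embeddings == {"my_style": {"string_to_param": {"*": <2x768 tensor>}}}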

@@ -197,6 +210,14 @@ def load_network(name, network_on_disk):
 
         net.modules[key] = net_module
 
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
+
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

@@ -212,11 +233,15 @@ def purge_networks_from_memory():
 
 
 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    emb_db = sd_hijack.model_hijack.embedding_db
     already_loaded = {}
 
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
|
|||||||
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
|
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
|
||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
|
for emb_name, embedding in net.bundle_embeddings.items():
|
||||||
|
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
|
||||||
|
logger.warning(
|
||||||
|
f'Skip bundle embedding: "{emb_name}"'
|
||||||
|
' as it was already loaded from embeddings folder'
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
embedding.loaded = False
|
||||||
|
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
|
||||||
|
embedding.loaded = True
|
||||||
|
emb_db.register_embedding(embedding, shared.sd_model)
|
||||||
|
else:
|
||||||
|
emb_db.skipped_embeddings[name] = embedding
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||

@@ -567,6 +607,7 @@ extra_network_lora = None
 available_networks = {}
 available_network_aliases = {}
 loaded_networks = []
+loaded_bundle_embeddings = {}
 networks_in_memory = {}
 available_network_hash_lookup = {}
 forbidden_network_aliases = {}

@@ -181,40 +181,7 @@ class EmbeddingDatabase:
         else:
             return
 
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)

@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data:  # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
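
A hedged usage sketch of the extracted helper (illustrative only; the embedding name and tensor size are invented, and it assumes it runs inside the webui where devices.device is available), mirroring how the Lora extension calls it for each bundled embedding:

import torch

data = {"string_to_param": {"*": torch.zeros(2, 768)}}  # minimal textual-inversion payload
embedding = create_embedding_from_data(data, "my_style", filename="bundle.safetensors/my_style")
print(embedding.vectors, embedding.shape)  # -> 2 768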