textual inversion support for SDXL

This commit is contained in:
AUTOMATIC1111 2023-07-29 15:15:06 +03:00
parent 4ca9f70b59
commit 6f0abbb71a
4 changed files with 29 additions and 9 deletions

View File

@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i]) text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2': if typename == 'FrozenOpenCLIPEmbedder2':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i]) text_cond_models.append(conditioner.embedders[i])
@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module): class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings): def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.embeddings = embeddings self.embeddings = embeddings
self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids): def forward(self, input_ids):
batch_fixes = self.embeddings.fixes batch_fixes = self.embeddings.fixes
@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = [] vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = devices.cond_cast_unet(embedding.vec) vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

View File

@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += 1 position += 1
continue continue
emb_len = int(embedding.vec.shape[0]) emb_len = int(embedding.vectors)
if len(chunk.tokens) + emb_len > self.chunk_length: if len(chunk.tokens) + emb_len > self.chunk_length:
next_chunk() next_chunk()

View File

@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
return torch.cat(res, dim=1) return torch.cat(res, dim=1)
def tokenize(self: sgm.modules.GeneralConditioner, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
return embedder.tokenize(texts)
raise AssertionError('no tokenizer available')
def process_texts(self, texts): def process_texts(self, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
return embedder.process_texts(texts) return embedder.process_texts(texts)
@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count

View File

@ -181,29 +181,38 @@ class EmbeddingDatabase:
else: else:
return return
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
param_dict = data['string_to_param'] param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it' assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1] emb = next(iter(param_dict.items()))[1]
# diffuser concepts vec = emb.detach().to(devices.device, dtype=torch.float32)
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it' assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values())) emb = next(iter(data.values()))
if len(emb.shape) == 1: if len(emb.shape) == 1:
emb = emb.unsqueeze(0) emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else: else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name) embedding = Embedding(vec, name)
embedding.step = data.get('step', None) embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0] embedding.vectors = vectors
embedding.shape = vec.shape[-1] embedding.shape = shape
embedding.filename = path embedding.filename = path
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')