Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2025-01-21 13:50:12 +08:00)

textual inversion support for SDXL

parent 4ca9f70b59
commit 6f0abbb71a
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
                     conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])
                 if typename == 'FrozenOpenCLIPEmbedder2':
-                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                     conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                     text_cond_models.append(conditioner.embedders[i])
 
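Note (not part of the diff): SDXL's GeneralConditioner carries two text encoders, and this hunk passes textual_inversion_key='clip_g' so the wrapped OpenCLIP token embedding pulls the matching half of a two-part embedding, while the CLIP-L encoder keeps the 'clip_l' default introduced in the next hunk. A hypothetical summary of the mapping, for orientation only:

# Hypothetical summary (names taken from the hunks in this commit): which part of an
# SDXL textual-inversion embedding each wrapped text encoder reads.
encoder_to_key = {
    'FrozenCLIPEmbedder': 'clip_l',       # CLIP ViT-L text encoder, default key
    'FrozenOpenCLIPEmbedder2': 'clip_g',  # OpenCLIP ViT-bigG text encoder, set explicitly above
}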
@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
         self.wrapped = wrapped
         self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
 
     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         vecs = []
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
-                emb = devices.cond_cast_unet(embedding.vec)
+                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = devices.cond_cast_unet(vec)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
 
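Note (illustration, not part of the diff): with the key stored on the wrapper, forward() can accept both the old plain-tensor embeddings and the new per-encoder dict. A minimal, self-contained sketch of that selection logic, using made-up shapes (768 and 1280 are the usual widths of SDXL's CLIP-L and OpenCLIP-bigG text encoders):

import torch

class FakeEmbedding:  # stand-in for the webui's Embedding object, illustration only
    def __init__(self, vec):
        self.vec = vec

sd15_style = FakeEmbedding(torch.zeros(2, 768))                # plain tensor, as before
sdxl_style = FakeEmbedding({'clip_l': torch.zeros(2, 768),
                            'clip_g': torch.zeros(2, 1280)})   # per-encoder dict

for key in ('clip_l', 'clip_g'):
    for emb in (sd15_style, sdxl_style):
        # same expression as the new line in forward()
        vec = emb.vec[key] if isinstance(emb.vec, dict) else emb.vec
        print(key, tuple(vec.shape))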
@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                     position += 1
                     continue
 
-                emb_len = int(embedding.vec.shape[0])
+                emb_len = int(embedding.vectors)
                 if len(chunk.tokens) + emb_len > self.chunk_length:
                     next_chunk()
 
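Note (not part of the diff): the prompt-chunking code used to read the token count from embedding.vec.shape[0], which no longer works once vec can be a dict, so it now uses embedding.vectors, which the loader precomputes (see the EmbeddingDatabase hunk below). A toy illustration with assumed shapes:

import torch

vec = {'clip_l': torch.zeros(3, 768), 'clip_g': torch.zeros(3, 1280)}  # a dict has no single .shape
vectors = vec['clip_g'].shape[0]  # 3 prompt tokens occupied by this embedding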
@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
     return torch.cat(res, dim=1)
 
 
+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+        return embedder.tokenize(texts)
+
+    raise AssertionError('no tokenizer available')
+
+
 def process_texts(self, texts):
     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
         return embedder.process_texts(texts)
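Note (illustration, not part of the diff): the new tokenize delegates to the first embedder that implements it, the same pattern process_texts already uses below. A self-contained sketch of that delegation with dummy embedders (all names here are made up):

class DummyCLIP:
    def tokenize(self, texts):
        return [[49406, 320, 2368, 49407] for _ in texts]  # fake token ids

class DummyOther:
    pass  # no tokenize method

class DummyConditioner:
    def __init__(self, embedders):
        self.embedders = embedders

    def tokenize(self, texts):
        # return the result of the first embedder that can tokenize
        for embedder in [e for e in self.embedders if hasattr(e, 'tokenize')]:
            return embedder.tokenize(texts)
        raise AssertionError('no tokenizer available')

print(DummyConditioner([DummyOther(), DummyCLIP()]).tokenize(['a cat']))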
@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
 
 # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
 sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
 sgm.modules.GeneralConditioner.process_texts = process_texts
 sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
 
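Note (assumption, not part of the diff): per the comment above, these assignments let SD1.5-era code that talks to model.cond_stage_model keep working when that attribute is SDXL's GeneralConditioner; code that calls a tokenize method would now resolve to the function patched in here.

# Hypothetical call site (assumes the webui has an SDXL checkpoint loaded and that
# cond_stage_model points at the patched GeneralConditioner):
tokens = shared.sd_model.cond_stage_model.tokenize(['a photo of a cat'])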
@@ -181,29 +181,38 @@ class EmbeddingDatabase:
         else:
             return
 
         # textual inversion embeddings
         if 'string_to_param' in data:
             param_dict = data['string_to_param']
             param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts
             assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
 
             emb = next(iter(data.values()))
             if len(emb.shape) == 1:
                 emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
         else:
             raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
 
-        vec = emb.detach().to(devices.device, dtype=torch.float32)
         embedding = Embedding(vec, name)
         embedding.step = data.get('step', None)
         embedding.sd_checkpoint = data.get('sd_checkpoint', None)
         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vec.shape[0]
-        embedding.shape = vec.shape[-1]
+        embedding.vectors = vectors
+        embedding.shape = shape
         embedding.filename = path
         embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
 
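Note (illustration, not part of the diff): the new branch accepts an embedding file that stores one tensor per SDXL text encoder. A minimal sketch of such a file and of the shape/vectors bookkeeping, with assumed sizes (768 for CLIP-L, 1280 for OpenCLIP-bigG) and the devices.device transfer dropped so it runs stand-alone:

import torch

data = {
    'clip_l': torch.zeros(4, 768),    # 4 vectors for the CLIP-L encoder
    'clip_g': torch.zeros(4, 1280),   # 4 vectors for the OpenCLIP-bigG encoder
}

if type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
    vec = {k: v.detach().to(dtype=torch.float32) for k, v in data.items()}
    shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]   # 2048: summed encoder widths
    vectors = data['clip_g'].shape[0]                             # 4: prompt tokens taken up

print(shape, vectors)  # 2048 4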