mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
support for SD3: infinite prompt length, token counting
This commit is contained in:
parent
a8fba9af35
commit
d686e73daa
225
modules/models/sd3/sd3_cond.py
Normal file
225
modules/models/sd3/sd3_cond.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||||
|
|
||||||
|
from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
|
||||||
|
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class SafetensorsMapping(typing.Mapping):
    """Read-only mapping view over an open safetensors file handle.

    Keys are whatever ``file.keys()`` reports and each value is fetched on
    demand through ``file.get_tensor(key)``, so the handle can be handed to
    consumers that expect a mapping (this file passes it to
    ``load_state_dict``).
    """

    def __init__(self, file):
        # `file` only needs to provide `keys()` and `get_tensor(key)`.
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        yield from self.file.keys()

    def __getitem__(self, key):
        return self.file.get_tensor(key)
|
||||||
|
|
||||||
|
|
||||||
|
# Download location and architecture settings for the CLIP-L text encoder.
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

# Download location and architecture settings for the CLIP-G text encoder.
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

# Download location and architecture settings for the T5-XXL text encoder.
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}
|
||||||
|
|
||||||
|
|
||||||
|
class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    """Combined CLIP-L + CLIP-G text encoder for SD3.

    Prompt handling comes from ``TextConditionalModel``; this class supplies
    the tokenizer, the special token ids and the actual encoder call.
    """

    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # An empty prompt is expected to tokenize to just [start, end].
        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]  # padding reuses the end token id

        # The encoder output carries a `.pooled` attribute (see encode_with_transformers).
        self.return_pooled = True

    def tokenize(self, texts):
        """Return raw token ids for each text, without truncation or special tokens."""
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        """Encode a batch of token ids with both CLIP models.

        Returns the CLIP-L and CLIP-G outputs concatenated on the last
        dimension and zero-padded there to width 4096, with the concatenated
        pooled outputs attached as ``.pooled``.

        Raises ValueError if a row of `tokens` contains no end token.
        """
        tokens_g = tokens.clone()

        # CLIP-G gets its own copy in which everything after the first end
        # token is replaced by id 0 (presumably CLIP-G's padding id - confirm).
        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        lg_out = torch.cat([l_out, g_out], dim=-1)
        # Widen the feature dimension to 4096 so it can later be concatenated
        # with the 4096-wide T5 output along the sequence dimension.
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        """Placeholder: return an all-zero (nvpt, 768+1280) tensor regardless of `init_text`."""
        return torch.zeros((nvpt, 768+1280), device=devices.device)  # XXX
|
||||||
|
|
||||||
|
|
||||||
|
class Sd3T5(torch.nn.Module):
    """T5-XXL text encoder wrapper for SD3 with prompt-attention-aware tokenization."""

    def __init__(self, t5xxl):
        super().__init__()

        # May be None when T5 is disabled; forward() then returns zeros.
        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        # An empty prompt padded to length 2 is expected to give [end, pad].
        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        """Return raw token ids for each text, without truncation or special tokens."""
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        """Tokenize one prompt line into parallel token and weight lists.

        When emphasis is enabled the line is split by
        ``prompt_parser.parse_prompt_attention`` and every token inherits the
        weight of its fragment; ``BREAK`` markers (weight -1) are dropped.
        An end token with weight 1.0 is always appended.

        If `target_token_count` is given, both lists are padded (pad id,
        weight 1.0) or truncated to exactly that length.

        Returns:
            (tokens, multipliers) - two lists of equal length.
        """
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            # Compute the shortfall once, before either list is extended, so
            # tokens and multipliers always end up the same length.
            missing = target_token_count - len(tokens)
            if missing > 0:
                tokens += [self.id_pad] * missing
                multipliers += [1.0] * missing
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        """Encode `texts` with T5, each padded or truncated to `token_count` tokens.

        When no T5 model is present or the ``sd3_enable_t5`` option is off, an
        all-zero tensor of shape (len(texts), token_count, 4096) is returned.
        Token weights are computed by tokenize_line but not applied here.
        """
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, _ = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        # The pooled output of the T5 model is not used.
        t5_out, _ = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        """Placeholder: return an all-zero (nvpt, 4096) tensor regardless of `init_text`."""
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Cond(torch.nn.Module):
    """SD3 text conditioning: CLIP-L, CLIP-G and (optionally) T5-XXL encoders.

    The encoders are created on the CPU without weights; call load_weights()
    to download (if needed) and load them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

            # T5 is only instantiated when the option is enabled at construction time.
            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        """Encode `prompts` and return a dict with 'crossattn' and 'vector' entries.

        'crossattn' is the CLIP L+G output concatenated with the T5 output
        along dim -2; 'vector' is the pooled CLIP output.
        """
        lg_out, vector_out = self.model_lg(prompts)

        # T5 is asked for the same number of tokens as the CLIP output has on dim 1.
        token_count = lg_out.shape[1]

        t5_out = self.model_t5(prompts, token_count=token_count)
        lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

    def load_weights(self):
        """Fetch the encoder checkpoints into <models_path>/CLIP and load them; runs only once."""
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        # CLIP-G is loaded strictly; CLIP-L and T5 tolerate key mismatches (strict=False).
        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
        with safetensors.safe_open(clip_g_file, framework="pt") as file:
            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        """Placeholder: return a 1x1 zero tensor regardless of the arguments."""
        return torch.tensor([[0]], device=devices.device)  # XXX

    def medvram_modules(self):
        """Return the three encoder modules; the T5 entry is None when T5 was not created."""
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        """Return the token count reported by the CLIP L+G model's process_texts for `text`."""
        _, token_count = self.model_lg.process_texts([text])

        return token_count

    def get_target_prompt_token_count(self, token_count):
        """Delegate to the CLIP L+G model's get_target_prompt_token_count."""
        return self.model_lg.get_target_prompt_token_count(token_count)
|
@ -1,127 +1,12 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
import safetensors
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import k_diffusion
|
import k_diffusion
|
||||||
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
|
||||||
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
|
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
|
||||||
|
from modules.models.sd3.sd3_cond import SD3Cond
|
||||||
|
|
||||||
from modules import shared, modelloader, devices
|
from modules import shared, devices
|
||||||
|
|
||||||
CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
|
|
||||||
CLIPG_CONFIG = {
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_size": 1280,
|
|
||||||
"intermediate_size": 5120,
|
|
||||||
"num_attention_heads": 20,
|
|
||||||
"num_hidden_layers": 32,
|
|
||||||
}
|
|
||||||
|
|
||||||
CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
|
|
||||||
CLIPL_CONFIG = {
|
|
||||||
"hidden_act": "quick_gelu",
|
|
||||||
"hidden_size": 768,
|
|
||||||
"intermediate_size": 3072,
|
|
||||||
"num_attention_heads": 12,
|
|
||||||
"num_hidden_layers": 12,
|
|
||||||
}
|
|
||||||
|
|
||||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
|
||||||
T5_CONFIG = {
|
|
||||||
"d_ff": 10240,
|
|
||||||
"d_model": 4096,
|
|
||||||
"num_heads": 64,
|
|
||||||
"num_layers": 24,
|
|
||||||
"vocab_size": 32128,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SafetensorsMapping(Mapping):
|
|
||||||
def __init__(self, file):
|
|
||||||
self.file = file
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.file.keys())
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for key in self.file.keys():
|
|
||||||
yield key
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.file.get_tensor(key)
|
|
||||||
|
|
||||||
|
|
||||||
class SD3Cond(torch.nn.Module):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
self.tokenizer = SD3Tokenizer()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
|
|
||||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
|
||||||
|
|
||||||
if shared.opts.sd3_enable_t5:
|
|
||||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
|
|
||||||
else:
|
|
||||||
self.t5xxl = None
|
|
||||||
|
|
||||||
self.weights_loaded = False
|
|
||||||
|
|
||||||
def forward(self, prompts: list[str]):
|
|
||||||
res = []
|
|
||||||
|
|
||||||
for prompt in prompts:
|
|
||||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
|
||||||
l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
|
|
||||||
g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
|
|
||||||
|
|
||||||
if self.t5xxl and shared.opts.sd3_enable_t5:
|
|
||||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
|
|
||||||
else:
|
|
||||||
t5_out = torch.zeros(l_out.shape[0:2] + (4096,), dtype=l_out.dtype, device=l_out.device)
|
|
||||||
|
|
||||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
|
||||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
|
||||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
|
||||||
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
|
||||||
|
|
||||||
res.append({
|
|
||||||
'crossattn': lgt_out[0].to(devices.device),
|
|
||||||
'vector': vector_out[0].to(devices.device),
|
|
||||||
})
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def load_weights(self):
|
|
||||||
if self.weights_loaded:
|
|
||||||
return
|
|
||||||
|
|
||||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
|
||||||
|
|
||||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
|
||||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
|
||||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
|
||||||
|
|
||||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
|
||||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
|
||||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
|
||||||
|
|
||||||
if self.t5xxl:
|
|
||||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
|
||||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
|
||||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
|
||||||
|
|
||||||
self.weights_loaded = True
|
|
||||||
|
|
||||||
def encode_embedding_init_text(self, init_text, nvpt):
|
|
||||||
return torch.tensor([[0]], device=devices.device) # XXX
|
|
||||||
|
|
||||||
def medvram_modules(self):
|
|
||||||
return [self.clip_g, self.clip_l, self.t5xxl]
|
|
||||||
|
|
||||||
|
|
||||||
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||||
|
@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
|
|||||||
|
|
||||||
|
|
||||||
class DictWithShape(dict):
|
class DictWithShape(dict):
|
||||||
def __init__(self, x, shape):
|
def __init__(self, x, shape=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.update(x)
|
self.update(x)
|
||||||
|
|
||||||
|
@ -325,6 +325,9 @@ class StableDiffusionModelHijack:
|
|||||||
if self.clip is None:
|
if self.clip is None:
|
||||||
return "-", "-"
|
return "-", "-"
|
||||||
|
|
||||||
|
if hasattr(self.clip, 'get_token_count'):
|
||||||
|
token_count = self.clip.get_token_count(text)
|
||||||
|
else:
|
||||||
_, token_count = self.clip.process_texts([text])
|
_, token_count = self.clip.process_texts([text])
|
||||||
|
|
||||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||||
|
@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
|
|||||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
class TextConditionalModel(torch.nn.Module):
|
||||||
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
def __init__(self):
|
||||||
have unlimited prompt length and assign weights to tokens in prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, wrapped, hijack):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.wrapped = wrapped
|
self.hijack = sd_hijack.model_hijack
|
||||||
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
|
||||||
depending on model."""
|
|
||||||
|
|
||||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
|
||||||
self.chunk_length = 75
|
self.chunk_length = 75
|
||||||
|
|
||||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
self.is_trainable = False
|
||||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
self.input_key = 'txt'
|
||||||
self.legacy_ucg_val = None
|
self.return_pooled = False
|
||||||
|
|
||||||
|
self.comma_token = None
|
||||||
|
self.id_start = None
|
||||||
|
self.id_end = None
|
||||||
|
self.id_pad = None
|
||||||
|
|
||||||
def empty_chunk(self):
|
def empty_chunk(self):
|
||||||
"""creates an empty PromptChunk and returns it"""
|
"""creates an empty PromptChunk and returns it"""
|
||||||
@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if opts.use_old_emphasis_implementation:
|
|
||||||
import modules.sd_hijack_clip_old
|
|
||||||
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
|
||||||
|
|
||||||
batch_chunks, token_count = self.process_texts(texts)
|
batch_chunks, token_count = self.process_texts(texts)
|
||||||
|
|
||||||
used_embeddings = {}
|
used_embeddings = {}
|
||||||
@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||||
|
|
||||||
if getattr(self.wrapped, 'return_pooled', False):
|
if self.return_pooled:
|
||||||
return torch.hstack(zs), zs[0].pooled
|
return torch.hstack(zs), zs[0].pooled
|
||||||
else:
|
else:
|
||||||
return torch.hstack(zs)
|
return torch.hstack(zs)
|
||||||
@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
|
||||||
|
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||||
|
have unlimited prompt length and assign weights to tokens in prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hijack = hijack
|
||||||
|
|
||||||
|
self.wrapped = wrapped
|
||||||
|
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||||
|
depending on model."""
|
||||||
|
|
||||||
|
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||||
|
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||||
|
self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
|
||||||
|
|
||||||
|
self.legacy_ucg_val = None # for sgm codebase
|
||||||
|
|
||||||
|
def forward(self, texts):
|
||||||
|
if opts.use_old_emphasis_implementation:
|
||||||
|
import modules.sd_hijack_clip_old
|
||||||
|
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||||
|
|
||||||
|
return super().forward(texts)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__(wrapped, hijack)
|
super().__init__(wrapped, hijack)
|
||||||
|
@ -722,7 +722,12 @@ def get_empty_cond(sd_model):
|
|||||||
d = sd_model.get_learned_conditioning([""])
|
d = sd_model.get_learned_conditioning([""])
|
||||||
return d['crossattn']
|
return d['crossattn']
|
||||||
else:
|
else:
|
||||||
return sd_model.cond_stage_model([""])
|
d = sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
if isinstance(d, dict):
|
||||||
|
d = d['crossattn']
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
def send_model_to_cpu(m):
|
def send_model_to_cpu(m):
|
||||||
|
Loading…
Reference in New Issue
Block a user