add flux model wrapper

This commit is contained in:
Won-Kyu Park 2024-08-31 16:40:28 +09:00
parent 853551bd6e
commit d38732efae
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
2 changed files with 343 additions and 0 deletions

View File

@ -0,0 +1,5 @@
from .flux import FLUX1Inferencer
__all__ = [
"FLUX1Inferencer",
]

338
modules/models/flux/flux.py Normal file
View File

@ -0,0 +1,338 @@
import contextlib
import os
import safetensors
import torch
import math
import k_diffusion
from transformers import CLIPTokenizer
from modules import shared, devices, modelloader, sd_hijack_clip
from modules.models.sd3.sd3_impls import SDVAE
from modules.models.sd3.sd3_cond import CLIPL_CONFIG, T5_CONFIG, CLIPL_URL, T5_URL, SafetensorsMapping, Sd3T5
from modules.models.sd3.other_impls import SDClipModel, T5XXLModel, SDTokenizer, T5XXLTokenizer
from PIL import Image
from .model import Flux
class FluxTokenizer:
def __init__(self):
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.t5xxl = T5XXLTokenizer()
def tokenize_with_weights(self, text:str):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)
return out
class Flux1ClipL(sd_hijack_clip.TextConditionalModel):
def __init__(self, clip_l):
super().__init__()
self.clip_l = clip_l
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
empty = self.tokenizer('')["input_ids"]
self.id_start = empty[0]
self.id_end = empty[1]
self.id_pad = empty[1]
self.return_pooled = True
def tokenize(self, texts):
return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
def encode_with_transformers(self, tokens):
tokens_g = tokens.clone()
for batch_pos in range(tokens_g.shape[0]):
index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0
l_out, l_pooled = self.clip_l(tokens)
l_out = torch.cat([l_out], dim=-1)
l_out = torch.nn.functional.pad(l_out, (0, 4096 - l_out.shape[-1]))
vector_out = torch.cat([l_pooled], dim=-1)
l_out.pooled = vector_out
return l_out
def encode_embedding_init_text(self, init_text, nvpt):
return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX
class FluxCond(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer = FluxTokenizer()
with torch.no_grad():
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
if shared.opts.sd3_enable_t5:
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
else:
self.t5xxl = None
self.model_l = Flux1ClipL(self.clip_l)
self.model_t5 = Sd3T5(self.t5xxl)
def forward(self, prompts: list[str]):
with devices.without_autocast():
l_out, vector_out = self.model_l(prompts)
t5_out = self.model_t5(prompts, token_count=l_out.shape[1])
lt_out = torch.cat([l_out, t5_out], dim=-2)
return {
'crossattn': lt_out,
'vector': vector_out,
}
def before_load_weights(self, state_dict):
clip_path = os.path.join(shared.models_path, "CLIP")
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
with safetensors.safe_open(clip_l_file, framework="pt") as file:
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
with safetensors.safe_open(t5_file, framework="pt") as file:
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
def encode_embedding_init_text(self, init_text, nvpt):
return self.model_l.encode_embedding_init_text(init_text, nvpt)
def tokenize(self, texts):
return self.model_l.tokenize(texts)
def medvram_modules(self):
return [self.clip_l, self.t5xxl]
def get_token_count(self, text):
_, token_count = self.model_l.process_texts([text])
return token_count
def get_target_prompt_token_count(self, token_count):
return self.model_l.get_target_prompt_token_count(token_count)
def flux_time_shift(mu: float, sigma: float, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
def set_parameters(self, shift=1.15, timesteps=10000):
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma
def sigma(self, timestep):
return flux_time_shift(self.shift, 1.0, timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
class BaseModel(torch.nn.Module):
"""Wrapper around the core FLUX model"""
def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""):
super().__init__()
params = dict(
image_model="flux",
in_channels=16,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10000,
qkv_bias=True,
guidance_embed=True,
)
self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params)
self.model_sampling = ModelSamplingFlux()
self.depth = 19
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()
timestep = self.model_sampling.timestep(sigma).float()
guidance = torch.FloatTensor([3.5]).to(device=devices.device, dtype=dtype)
model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype), guidance=guidance).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
def get_dtype(self):
return self.diffusion_model.dtype
class FLUX1LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor
def decode_latent_to_preview(self, x0):
"""Quick RGB approximate preview of sd3 latents"""
factors = torch.tensor([
[-0.0404, 0.0159, 0.0609], [ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503], [-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530], [ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262], [-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001], [ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013], [ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976], [-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020], [-0.1273, -0.0932, -0.0680],
], device="cpu")
latent_image = x0[0].permute(1, 2, 0).cpu() @ factors
latents_ubyte = (((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()).cpu()
return Image.fromarray(latents_ubyte.numpy())
class FLUX1Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
self.inner_model = inner_model
def forward(self, input, sigma, **kwargs):
return self.inner_model.apply_model(input, sigma, **kwargs)
class FLUX1Inferencer(torch.nn.Module):
def __init__(self, state_dict, use_ema=False):
super().__init__()
# detect model_prefix
diffusion_model_prefix = "model.diffusion_model."
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
diffusion_model_prefix = "model.diffusion_model."
elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
diffusion_model_prefix = ""
with torch.no_grad():
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = self.model.diffusion_model.dtype
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
self.text_encoders = FluxCond()
self.cond_stage_key = 'txt'
self.parameterization = "eps"
self.model.conditioning_key = "crossattn"
self.latent_format = FLUX1LatentFormat()
self.latent_channels = 16
@property
def cond_stage_model(self):
return self.text_encoders
def before_load_weights(self, state_dict):
self.cond_stage_model.before_load_weights(state_dict)
def ema_scope(self):
return contextlib.nullcontext()
def get_learned_conditioning(self, batch: list[str]):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
return self.first_stage_model.decode(latent)
def encode_first_stage(self, image):
latent = self.first_stage_model.encode(image)
return self.latent_format.process_in(latent)
def get_first_stage_encoding(self, x):
return x
def create_denoiser(self):
return FLUX1Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'text_encoders'),
(self, 'model'),
]
def add_noise_to_latent(self, x, noise, amount):
return x * (1 - amount) + noise * amount
def fix_dimensions(self, width, height):
return width // 16 * 16, height // 16 * 16
def diffusers_weight_mapping(self):
for i in range(self.model.depth):
yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"