support Flux1

commit 219a0e2429
parent 9c0fd83b5e
Author: Won-Kyu Park
Date:   2024-08-31 16:46:34 +09:00
5 changed files with 40 additions and 14 deletions

View File

@@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear):
     def forward(self, x):
         return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
 
 
+class AutocastLayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        return torch.nn.functional.layer_norm(
+            x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)
+
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
@@ -41,9 +46,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device):
         super().__init__()
         self.heads = heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, x, mask=None):
         q = self.q_proj(x)
@@ -82,9 +87,11 @@ ACTIVATIONS = {
 
 class CLIPLayer(torch.nn.Module):
     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
         super().__init__()
-        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
-        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
         #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
         self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
@@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module):
         super().__init__()
         self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
-        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
         x = self.embeddings(input_tokens)
@@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module):
         self.num_layers = config_dict["num_hidden_layers"]
         self.text_model = CLIPTextModel_(config_dict, dtype, device)
         embed_dim = config_dict["hidden_size"]
-        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+        self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
         self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype
 
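
The pattern behind every swap in this file: AutocastLinear and AutocastLayerNorm cast their weights to the activation's dtype at call time, so weights can be stored in a cheaper dtype (e.g. fp8) without wrapping inference in torch.autocast or needing gradient scaling. A minimal standalone sketch of that behavior (plain PyTorch, not code from this repo; the fp8 dtype requires torch 2.1+):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class AutocastLinear(nn.Linear):
        # cast weight/bias to the activation dtype on every call
        def forward(self, x):
            return F.linear(x, self.weight.to(x.dtype),
                            self.bias.to(x.dtype) if self.bias is not None else None)

    layer = AutocastLinear(16, 16, bias=True)
    layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)  # fp8 storage
    layer.bias.data = layer.bias.data.to(torch.float16)

    x = torch.randn(2, 16, dtype=torch.float16)
    out = layer(x)                     # weight is cast fp8 -> fp16 for the matmul
    assert out.dtype == torch.float16  # output follows the activation dtype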

View File

@@ -33,6 +33,7 @@ class ModelType(enum.Enum):
     SDXL = 3
     SSD = 4
     SD3 = 5
+    FLUX1 = 6
 
 
 def replace_key(d, key, new_key, value):
@@ -369,7 +370,7 @@ def check_fp8(model):
         enable_fp8 = False
     elif shared.opts.fp8_storage == "Enable":
         enable_fp8 = True
-    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+    elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL":
         enable_fp8 = True
     else:
         enable_fp8 = False
@@ -382,10 +383,14 @@ def set_model_type(model, state_dict):
     model.is_sdxl = False
     model.is_ssd = False
     model.is_sd3 = False
+    model.is_flux1 = False
 
     if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
         model.is_sd3 = True
         model.model_type = ModelType.SD3
+    elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
+        model.is_flux1 = True
+        model.model_type = ModelType.FLUX1
     elif hasattr(model, 'conditioner'):
         model.is_sdxl = True
 
@@ -777,6 +782,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
 sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
+clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight'
+clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight'
+t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight'
 
 
 class SdModelData:
@@ -909,7 +917,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
     if not checkpoint_config:
         checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
 
-    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight] if x in state_dict)
 
     timer.record("find config")
 
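
The through-line of this file's changes: a checkpoint's architecture is fingerprinted by state-dict keys that only that model family contains, and the Flux.1 fingerprint is the QK-norm scale inside the first double block. A condensed sketch of the dispatch (key strings copied from the hunks above; the function name is illustrative, and the real code also sets model_type and the is_* flags):

    def guess_architecture(state_dict):
        # each family is identified by a key no other family's checkpoints have;
        # SD3 and Flux.1 must be probed before the older SDXL/SD1 heuristics
        if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
            return "SD3"
        if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
            return "FLUX1"
        if "conditioner.embedders.1.model.ln_final.weight" in state_dict:
            return "SDXL"
        return "unknown"  # fall through to the remaining heuristics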

View File

@@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
 config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml")
 
 
 def is_using_v_parameterization_for_sd2(state_dict):
@@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename):
     if "model.diffusion_model.x_embedder.proj.weight" in sd:
         return config_sd3
 
+    if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+        return config_flux1
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting

View File

@@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion):
     is_sd3: bool
     """True if the model's architecture is SD 3"""
 
+    is_flux1: bool
+    """True if the model's architecture is FLUX 1"""
+
     latent_channels: int
     """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""

View File

@@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+            latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
@@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+            latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
@@ -97,6 +97,8 @@ def download_model(model_path, model_url):
 def decoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_decoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_decoder.pth"
     else:
@@ -122,6 +124,8 @@ def decoder_model():
 def encoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_encoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_encoder.pth"
     else:
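
Both TAESD hunks key off the same fact: the Flux.1 ("taef1") and SD3 ("taesd3") tiny autoencoders work on 16-channel latents, while the SD1/SD2/SDXL variants use 4, and the weights file name follows the model family. The channel rule, extracted from the hunks above into a standalone helper (illustrative only; file names as in the diff):

    def taesd_latent_channels(model_path):
        # "taesd3" (SD3) and "taef1" (Flux.1) decode 16-channel latents;
        # every other TAESD variant decodes 4
        return 16 if any(tag in str(model_path) for tag in ("taesd3", "taef1")) else 4

    assert taesd_latent_channels("taef1_decoder.pth") == 16
    assert taesd_latent_channels("taesdxl_decoder.pth") == 4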