From 219a0e242900388c9f68628294a1881821c9e920 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Sat, 31 Aug 2024 16:46:34 +0900
Subject: [PATCH] support Flux1

---
 modules/models/sd3/other_impls.py | 27 +++++++++++++++++----------
 modules/sd_models.py              | 12 ++++++++++--
 modules/sd_models_config.py      |  4 ++++
 modules/sd_models_types.py       |  3 +++
 modules/sd_vae_taesd.py          |  8 ++++++--
 5 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
index 78c1dc687..695d356cb 100644
--- a/modules/models/sd3/other_impls.py
+++ b/modules/models/sd3/other_impls.py
@@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear):
     def forward(self, x):
         return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
 
+class AutocastLayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        return torch.nn.functional.layer_norm(
+            x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)
+
 
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
@@ -41,9 +46,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device):
         super().__init__()
         self.heads = heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, x, mask=None):
         q = self.q_proj(x)
@@ -82,9 +87,11 @@ ACTIVATIONS = {
 class CLIPLayer(torch.nn.Module):
     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
         super().__init__()
-        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
-        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
         #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
         self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
 
@@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module):
         super().__init__()
         self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
-        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
         x = self.embeddings(input_tokens)
@@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module):
         self.num_layers = config_dict["num_hidden_layers"]
         self.text_model = CLIPTextModel_(config_dict, dtype, device)
         embed_dim = config_dict["hidden_size"]
-        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+        self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
         self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype
 
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b4702151a..7e48c2328 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -33,6 +33,7 @@ class ModelType(enum.Enum):
     SDXL = 3
     SSD = 4
     SD3 = 5
+    FLUX1 = 6
 
 
 def replace_key(d, key, new_key, value):
@@ -369,7 +370,7 @@ def check_fp8(model):
         enable_fp8 = False
     elif shared.opts.fp8_storage == "Enable":
         enable_fp8 = True
-    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+    elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL":
         enable_fp8 = True
     else:
         enable_fp8 = False
@@ -382,10 +383,14 @@ def set_model_type(model, state_dict):
     model.is_sdxl = False
     model.is_ssd = False
     model.is_sd3 = False
+    model.is_flux1 = False
 
     if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
         model.is_sd3 = True
         model.model_type = ModelType.SD3
+    elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
+        model.is_flux1 = True
+        model.model_type = ModelType.FLUX1
 
     elif hasattr(model, 'conditioner'):
         model.is_sdxl = True
@@ -777,6 +782,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
 sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
+clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight'
+clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight'
+t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight'
 
 
 class SdModelData:
@@ -909,7 +917,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
     if not checkpoint_config:
         checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
 
-    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight] if x in state_dict)
 
     timer.record("find config")
 
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 3c1e4a151..4251062c8 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
 config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml")
 
 
 def is_using_v_parameterization_for_sd2(state_dict):
@@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename):
     if "model.diffusion_model.x_embedder.proj.weight" in sd:
         return config_sd3
 
+    if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+        return config_flux1
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index 2fce2777b..867f8b6e2 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion):
     is_sd3: bool
     """True if the model's architecture is SD 3"""
 
+    is_flux1: bool
+    """True if the model's architecture is FLUX 1"""
+
     latent_channels: int
     """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index d06253d2a..76771e95e 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+            latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
@@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+            latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
@@ -97,6 +97,8 @@ def download_model(model_path, model_url):
 def decoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_decoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_decoder.pth"
     else:
@@ -122,6 +124,8 @@
 def encoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_encoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_encoder.pth"
     else: