diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py index d17febc68..fc1e91e9d 100644 --- a/modules/models/flux/flux.py +++ b/modules/models/flux/flux.py @@ -196,7 +196,8 @@ class BaseModel(torch.nn.Module): self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params) self.model_sampling = ModelSamplingFlux() - self.depth = 19 + self.depth = params['depth'] + self.depth_single_block = params['depth_single_blocks'] def apply_model(self, x, sigma, c_crossattn=None, y=None): dtype = self.get_dtype() @@ -326,13 +327,33 @@ class FLUX1Inferencer(torch.nn.Module): return width // 16 * 16, height // 16 * 16 def diffusers_weight_mapping(self): + # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py + # please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py for i in range(self.model.depth): - yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", 
f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj" - yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj" + yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj" + + yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj" + + yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2" + yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0" + yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2" + yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin" + yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin" + + for i in range(self.model.depth_single_block): + yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj" + yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj" + yield 
f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj" + + yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2" + yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin" diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py index aa830849e..2202f5dcb 100644 --- a/modules/models/flux/modules/layers.py +++ b/modules/models/flux/modules/layers.py @@ -82,13 +82,17 @@ class QKNorm(torch.nn.Module): return q.to(v), k.to(v) +class QkvLinear(torch.nn.Linear): + pass + + class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device) self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) @@ -217,7 +221,7 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) # proj and mlp_out self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)