add diffusers weight mapping for flux lora

* add QkvLinear class for Flux lora
Won-Kyu Park 2024-09-05 20:15:37 +09:00
parent 477ff35517
commit d6a609a539
2 changed files with 36 additions and 11 deletions


@@ -196,7 +196,8 @@ class BaseModel(torch.nn.Module):
             self.diffusion_model = Flux(device=device, dtype=devices.dtype, **params)
             self.model_sampling = ModelSamplingFlux()
-            self.depth = 19
+            self.depth = params['depth']
+            self.depth_single_block = params['depth_single_blocks']
 
     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()
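The hardcoded depth of 19 is replaced by values read from the model params, so the mapping below can iterate over the actual block counts. As a rough illustration (the dict shape is an assumption, not taken from this repository), the entries for Flux.1-dev would carry the published 19/38 split:

params = {
    "depth": 19,                  # double (joint image/text) blocks
    "depth_single_blocks": 38,    # single blocks
    # ... other architecture hyperparameters passed to Flux(**params)
}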
@@ -326,13 +327,33 @@ class FLUX1Inferencer(torch.nn.Module):
         return width // 16 * 16, height // 16 * 16
 
     def diffusers_weight_mapping(self):
         # https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
+        # please see also https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
         for i in range(self.model.depth):
-            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_double_blocks_{i}_txt_attn_qkv_v_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
-            yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_add_out", f"diffusion_model_double_blocks_{i}_txt_attn_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_double_blocks_{i}_img_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_double_blocks_{i}_img_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_double_blocks_{i}_img_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_double_blocks_{i}_img_attn_proj"
+            yield f"transformer.transformer_blocks.{i}.ff.net.0.proj", f"diffusion_model_double_blocks_{i}_img_mlp_0"
+            yield f"transformer.transformer_blocks.{i}.ff.net.2", f"diffusion_model_double_blocks_{i}_img_mlp_2"
+            yield f"transformer.transformer_blocks.{i}.ff_context.net.0.proj", f"diffusion_model_double_blocks_{i}_txt_mlp_0"
+            yield f"transformer.transformer_blocks.{i}.ff_context.net.2", f"diffusion_model_double_blocks_{i}_txt_mlp_2"
+            yield f"transformer.transformer_blocks.{i}.norm1.linear", f"diffusion_model_double_blocks_{i}_img_mod_lin"
+            yield f"transformer.transformer_blocks.{i}.norm1_context.linear", f"diffusion_model_double_blocks_{i}_txt_mod_lin"
+
+        for i in range(self.model.depth_single_block):
+            yield f"transformer.single_transformer_blocks.{i}.attn.to_q", f"diffusion_model_single_blocks_{i}_linear1_q_proj"
+            yield f"transformer.single_transformer_blocks.{i}.attn.to_k", f"diffusion_model_single_blocks_{i}_linear1_k_proj"
+            yield f"transformer.single_transformer_blocks.{i}.attn.to_v", f"diffusion_model_single_blocks_{i}_linear1_v_proj"
+            yield f"transformer.single_transformer_blocks.{i}.proj_mlp", f"diffusion_model_single_blocks_{i}_linear1_mlp_proj"
+            yield f"transformer.single_transformer_blocks.{i}.proj_out", f"diffusion_model_single_blocks_{i}_linear2"
+            yield f"transformer.single_transformer_blocks.{i}.norm.linear", f"diffusion_model_single_blocks_{i}_modulation_lin"


@@ -82,13 +82,17 @@ class QKNorm(torch.nn.Module):
         return q.to(v), k.to(v)
 
 
+class QkvLinear(torch.nn.Linear):
+    pass
+
+
 class SelfAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None):
         super().__init__()
 
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device)
         self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
@@ -217,7 +221,7 @@ class SingleStreamBlock(nn.Module):
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         # qkv and mlp_in
-        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = QkvLinear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
         # proj and mlp_out
         self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
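
QkvLinear changes no behavior; it is a marker subclass that lets LoRA code recognize layers whose weight fuses the q, k, and v projections, so a per-projection delta can be applied to just the matching rows. A rough sketch of that idea (the helper and its slicing convention are illustrative assumptions, not this repository's loader):

import torch

def apply_qkv_lora_delta(layer, proj, delta):
    # `layer.weight` stacks the q, k, v row blocks of `dim` rows each
    # (SingleStreamBlock.linear1 appends the mlp rows after them).
    assert isinstance(layer, QkvLinear)
    dim = layer.in_features
    start = {"q_proj": 0, "k_proj": dim, "v_proj": 2 * dim}[proj]
    with torch.no_grad():
        layer.weight[start:start + dim] += delta

This is also why the internal names in the mapping above carry _q_proj/_k_proj/_v_proj suffixes on a single fused module: each suffix addresses one slice of the fused qkv weight.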