fix for flux

Won-Kyu Park 2024-08-31 16:47:17 +09:00
parent d38732efae
commit 2d1db1a2d0
5 changed files with 113 additions and 66 deletions

View File

@@ -0,0 +1,4 @@
model:
target: modules.models.flux.FLUX1Inferencer
params:
state_dict: null
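This new config follows the usual target/params convention: "target" is the dotted path of the class to instantiate and "params" are its constructor kwargs. A minimal loader sketch, assuming a PyYAML file and a generic instantiate_from_config helper (both are assumptions; the actual webui wiring is not part of this diff):

import importlib
import yaml

def instantiate_from_config(config):
    # resolve the dotted "target" path to a class and call it with "params"
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**(config.get("params") or {}))

with open("flux1.yaml") as f:  # file name is illustrative
    cfg = yaml.safe_load(f)
inferencer = instantiate_from_config(cfg["model"])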

View File

@@ -22,7 +22,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
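For reference, the remainder of apply_rope (cut off by this hunk) mirrors the query branch for the keys and casts back to the input dtype in the upstream flux code; roughly:

xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)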

View File

@@ -1,9 +1,13 @@
# original code from https://github.com/black-forest-labs/flux
#
from dataclasses import dataclass
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)
@@ -18,7 +22,7 @@ class FluxParams:
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
axes_dim: list
theta: int
qkv_bias: bool
guidance_embed: bool
@@ -29,11 +33,13 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences.
"""
def __init__(self, params: FluxParams):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
@@ -45,13 +51,13 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
@@ -60,6 +66,7 @@ class Flux(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device,
)
for _ in range(params.depth)
]
@@ -67,14 +74,15 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device)
def forward(
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
@@ -82,18 +90,18 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
@@ -108,5 +116,32 @@ class Flux(nn.Module):
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
if self.final_layer:
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
# from comfy/ldm/common_dit.py
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
bs, c, h, w = x.shape
patch_size = 2
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
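A rough shape walk-through of this new wrapper, assuming a 16-channel 64x64 Flux latent (the channel count and sizes are illustrative assumptions, not values taken from this diff):

import torch

x = torch.randn(1, 16, 64, 64)  # (bs, c, h, w); already divisible by patch_size=2
# pad_to_patch_size is a no-op here, then the 2x2 patchify yields:
#   img:     (1, 32*32, 16*2*2) = (1, 1024, 64)  one token per 2x2 patch
#   img_ids: (1, 1024, 3)  holding (0, row, col) positions per patch
#   txt_ids: (bs, n_txt, 3)  all zeros for the text stream
# forward_orig runs the transformer over the token sequence, and the final
# rearrange unpatchifies back to (1, 16, 64, 64), cropped to the original h, w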

View File

@@ -5,11 +5,11 @@ import torch
from einops import rearrange
from torch import Tensor, nn
from flux.math import attention, rope
from ..math import attention, rope
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
self.dim = dim
self.theta = theta
@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
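For intuition, a self-contained sketch of the sinusoidal embedding this computes (time_factor is assumed to be 1000.0, the upstream default):

import math
import torch

t = torch.tensor([0.25, 0.75])  # two timesteps in [0, 1]
dim, max_period, time_factor = 256, 10000, 1000.0
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = (time_factor * t)[:, None].float() * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # shape (2, 256)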
@@ -50,20 +48,20 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int):
def __init__(self, dim: int, dtype=None, device=None):
super().__init__()
self.scale = nn.Parameter(torch.ones(dim))
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
@@ -73,26 +71,26 @@ class RMSNorm(torch.nn.Module):
class QKNorm(torch.nn.Module):
def __init__(self, dim: int):
def __init__(self, dim: int, dtype=None, device=None):
super().__init__()
self.query_norm = RMSNorm(dim)
self.key_norm = RMSNorm(dim)
self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
@@ -111,13 +109,13 @@ class ModulationOut:
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
def __init__(self, dim: int, double: bool, dtype=None, device=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
@@ -127,35 +125,35 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -188,6 +186,11 @@ class DoubleStreamBlock(nn.Module):
# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
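The 65504 bound used in these nan_to_num guards is the largest finite float16 value; an equivalent formulation via torch.finfo, shown purely as an illustration rather than code from this commit:

import torch

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0
txt = torch.randn(2, 77, 3072, dtype=torch.float16)  # hypothetical (batch, tokens, hidden) activations
txt = torch.nan_to_num(txt, nan=0.0, posinf=FP16_MAX, neginf=-FP16_MAX)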
@@ -202,7 +205,9 @@ class SingleStreamBlock(nn.Module):
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float | None = None,
qk_scale: float = None,
dtype=None,
device=None,
):
super().__init__()
self.hidden_dim = hidden_size
@@ -212,17 +217,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim)
self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False)
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
@@ -236,15 +241,18 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
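The rest of LastLayer.forward (cut off by this hunk) applies the adaLN shift/scale and the final projection in the upstream flux code; roughly:

x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x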

View File

@@ -7,9 +7,9 @@ from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder
from .model import Flux, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
@dataclass