fix for flux

This commit is contained in:
Won-Kyu Park 2024-08-31 16:47:17 +09:00
parent d38732efae
commit 2d1db1a2d0
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15
5 changed files with 113 additions and 66 deletions

View File

@ -0,0 +1,4 @@
model:
target: modules.models.flux.FLUX1Inferencer
params:
state_dict: null

View File

@ -22,7 +22,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.float() return out.float()
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]

View File

@ -1,9 +1,13 @@
# original code from https://github.com/black-forest-labs/flux
#
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from einops import rearrange, repeat
from torch import Tensor, nn from torch import Tensor, nn
from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock, MLPEmbedder, SingleStreamBlock,
timestep_embedding) timestep_embedding)
@ -18,7 +22,7 @@ class FluxParams:
num_heads: int num_heads: int
depth: int depth: int
depth_single_blocks: int depth_single_blocks: int
axes_dim: list[int] axes_dim: list
theta: int theta: int
qkv_bias: bool qkv_bias: bool
guidance_embed: bool guidance_embed: bool
@ -29,11 +33,13 @@ class Flux(nn.Module):
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
""" """
def __init__(self, params: FluxParams): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs):
super().__init__() super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params self.params = params
self.in_channels = params.in_channels self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0: if params.hidden_size % params.num_heads != 0:
raise ValueError( raise ValueError(
@ -45,13 +51,13 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size self.hidden_size = params.hidden_size
self.num_heads = params.num_heads self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device)
self.guidance_in = ( self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity()
) )
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList( self.double_blocks = nn.ModuleList(
[ [
@ -60,6 +66,7 @@ class Flux(nn.Module):
self.num_heads, self.num_heads,
mlp_ratio=params.mlp_ratio, mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias, qkv_bias=params.qkv_bias,
dtype=dtype, device=device,
) )
for _ in range(params.depth) for _ in range(params.depth)
] ]
@ -67,14 +74,15 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList( self.single_blocks = nn.ModuleList(
[ [
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device)
for _ in range(params.depth_single_blocks) for _ in range(params.depth_single_blocks)
] ]
) )
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device)
def forward( def forward_orig(
self, self,
img: Tensor, img: Tensor,
img_ids: Tensor, img_ids: Tensor,
@ -82,18 +90,18 @@ class Flux(nn.Module):
txt_ids: Tensor, txt_ids: Tensor,
timesteps: Tensor, timesteps: Tensor,
y: Tensor, y: Tensor,
guidance: Tensor | None = None, guidance: Tensor = None,
) -> Tensor: ) -> Tensor:
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.") raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img # running on sequences img
img = self.img_in(img) img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256)) vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed: if self.params.guidance_embed:
if guidance is None: if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.") raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y) vec = vec + self.vector_in(y)
txt = self.txt_in(txt) txt = self.txt_in(txt)
@ -108,5 +116,32 @@ class Flux(nn.Module):
img = block(img, vec=vec, pe=pe) img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) if self.final_layer:
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
# from comfy/ldm/common_dit.py
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
bs, c, h, w = x.shape
patch_size = 2
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@ -5,11 +5,11 @@ import torch
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from flux.math import attention, rope from ..math import attention, rope
class EmbedND(nn.Module): class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list[int]): def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.theta = theta self.theta = theta
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
""" """
t = time_factor * t t = time_factor * t
half = dim // 2 half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
t.device
)
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@ -50,20 +48,20 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
class MLPEmbedder(nn.Module): class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int): def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None):
super().__init__() super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.silu = nn.SiLU() self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x))) return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, dim: int): def __init__(self, dim: int, dtype=None, device=None):
super().__init__() super().__init__()
self.scale = nn.Parameter(torch.ones(dim)) self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor): def forward(self, x: Tensor):
x_dtype = x.dtype x_dtype = x.dtype
@ -73,26 +71,26 @@ class RMSNorm(torch.nn.Module):
class QKNorm(torch.nn.Module): class QKNorm(torch.nn.Module):
def __init__(self, dim: int): def __init__(self, dim: int, dtype=None, device=None):
super().__init__() super().__init__()
self.query_norm = RMSNorm(dim) self.query_norm = RMSNorm(dim, dtype=dtype, device=device)
self.key_norm = RMSNorm(dim) self.key_norm = RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q) q = self.query_norm(q)
k = self.key_norm(k) k = self.key_norm(k)
return q.to(v), k.to(v) return q.to(v), k.to(v)
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim) self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.proj = nn.Linear(dim, dim) self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor: def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x) qkv = self.qkv(x)
@ -111,13 +109,13 @@ class ModulationOut:
class Modulation(nn.Module): class Modulation(nn.Module):
def __init__(self, dim: int, double: bool): def __init__(self, dim: int, double: bool, dtype=None, device=None):
super().__init__() super().__init__()
self.is_double = double self.is_double = double
self.multiplier = 6 if double else 3 self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return ( return (
@ -127,35 +125,35 @@ class Modulation(nn.Module):
class DoubleStreamBlock(nn.Module): class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None):
super().__init__() super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio) mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True) self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential( self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"), nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
self.txt_mod = Modulation(hidden_size, double=True) self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential( self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"), nn.GELU(approximate="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True), nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
) )
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec) img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -188,6 +186,11 @@ class DoubleStreamBlock(nn.Module):
# calculate the txt bloks # calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt return img, txt
@ -202,7 +205,9 @@ class SingleStreamBlock(nn.Module):
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
qk_scale: float | None = None, qk_scale: float = None,
dtype=None,
device=None,
): ):
super().__init__() super().__init__()
self.hidden_dim = hidden_size self.hidden_dim = hidden_size
@ -212,17 +217,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio) self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in # qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out # proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim) self.norm = QKNorm(head_dim, dtype=dtype, device=device)
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh") self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False) self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec) mod, _ = self.modulation(vec)
@ -236,15 +241,18 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe) attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output x += mod.gate * output
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module): class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None):
super().__init__() super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor: def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)

View File

@ -7,9 +7,9 @@ from huggingface_hub import hf_hub_download
from imwatermark import WatermarkEncoder from imwatermark import WatermarkEncoder
from safetensors.torch import load_file as load_sft from safetensors.torch import load_file as load_sft
from flux.model import Flux, FluxParams from .model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder from .modules.conditioner import HFEmbedder
@dataclass @dataclass