From 2d1db1a2d0878cd56521b71def2de1c5e7ca279c Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 31 Aug 2024 16:47:17 +0900 Subject: [PATCH] fix for flux --- configs/flux1-inference.yaml | 4 ++ modules/models/flux/math.py | 2 +- modules/models/flux/model.py | 67 ++++++++++++----- modules/models/flux/modules/layers.py | 100 ++++++++++++++------------ modules/models/flux/util.py | 6 +- 5 files changed, 113 insertions(+), 66 deletions(-) create mode 100644 configs/flux1-inference.yaml diff --git a/configs/flux1-inference.yaml b/configs/flux1-inference.yaml new file mode 100644 index 000000000..f9bbe9073 --- /dev/null +++ b/configs/flux1-inference.yaml @@ -0,0 +1,4 @@ +model: + target: modules.models.flux.FLUX1Inferencer + params: + state_dict: null diff --git a/modules/models/flux/math.py b/modules/models/flux/math.py index 0156bb6a2..4ad99818e 100644 --- a/modules/models/flux/math.py +++ b/modules/models/flux/math.py @@ -22,7 +22,7 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: return out.float() -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: +def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] diff --git a/modules/models/flux/model.py b/modules/models/flux/model.py index f33ab8323..c87397827 100644 --- a/modules/models/flux/model.py +++ b/modules/models/flux/model.py @@ -1,9 +1,13 @@ +# original code from https://github.com/black-forest-labs/flux +# from dataclasses import dataclass import torch + +from einops import rearrange, repeat from torch import Tensor, nn -from flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, +from .modules.layers import (DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding) @@ -18,7 +22,7 @@ class FluxParams: num_heads: int depth: int depth_single_blocks: int - axes_dim: list[int] + axes_dim: list theta: int qkv_bias: bool guidance_embed: bool @@ -29,11 +33,13 @@ class Flux(nn.Module): Transformer model for flow matching on sequences. 
""" - def __init__(self, params: FluxParams): + def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, **kwargs): super().__init__() + self.dtype = dtype + params = FluxParams(**kwargs) self.params = params - self.in_channels = params.in_channels + self.in_channels = params.in_channels * 2 * 2 self.out_channels = self.in_channels if params.hidden_size % params.num_heads != 0: raise ValueError( @@ -45,13 +51,13 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device) self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device) if params.guidance_embed else nn.Identity() ) - self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) self.double_blocks = nn.ModuleList( [ @@ -60,6 +66,7 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + dtype=dtype, device=device, ) for _ in range(params.depth) ] @@ -67,14 +74,15 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device) for _ in range(params.depth_single_blocks) ] ) - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + if final_layer: + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device) - def forward( + def forward_orig( self, img: Tensor, img_ids: Tensor, @@ -82,18 +90,18 @@ class Flux(nn.Module): txt_ids: Tensor, timesteps: Tensor, y: Tensor, - guidance: Tensor | None = None, + guidance: Tensor = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) - vec = self.time_in(timestep_embedding(timesteps, 256)) + vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) if self.params.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") - vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.vector_in(y) txt = self.txt_in(txt) @@ -108,5 +116,32 @@ class Flux(nn.Module): img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] 
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + if self.final_layer: + img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img + + def forward(self, x, timestep, context, y, guidance, **kwargs): + # from comfy/ldm/common_dit.py + def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"): + if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting(): + padding_mode = "reflect" + pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0] + pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1] + return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode) + + bs, c, h, w = x.shape + patch_size = 2 + x = pad_to_patch_size(x, (patch_size, patch_size)) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance) + return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w] diff --git a/modules/models/flux/modules/layers.py b/modules/models/flux/modules/layers.py index 091ddf624..aa830849e 100644 --- a/modules/models/flux/modules/layers.py +++ b/modules/models/flux/modules/layers.py @@ -5,11 +5,11 @@ import torch from einops import rearrange from torch import Tensor, nn -from flux.math import attention, rope +from ..math import attention, rope class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: list): super().__init__() self.dim = dim self.theta = theta @@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 """ t = time_factor * t half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - t.device - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) @@ -50,20 +48,20 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int): + def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None): super().__init__() - self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) class RMSNorm(torch.nn.Module): - def __init__(self, dim: int): + def __init__(self, dim: int, dtype=None, device=None): super().__init__() - 
self.scale = nn.Parameter(torch.ones(dim)) + self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) def forward(self, x: Tensor): x_dtype = x.dtype @@ -73,26 +71,26 @@ class RMSNorm(torch.nn.Module): class QKNorm(torch.nn.Module): - def __init__(self, dim: int): + def __init__(self, dim: int, dtype=None, device=None): super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) + self.query_norm = RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = RMSNorm(dim, dtype=dtype, device=device) - def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.norm = QKNorm(head_dim) - self.proj = nn.Linear(dim, dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.norm = QKNorm(head_dim, dtype=dtype, device=device) + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) @@ -111,13 +109,13 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool): + def __init__(self, dim: int, double: bool, dtype=None, device=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) - def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: + def forward(self, vec: Tensor) -> tuple: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) return ( @@ -127,35 +125,35 @@ class Modulation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True) - self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) - self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_mlp = nn.Sequential( - nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), nn.GELU(approximate="tanh"), - nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - 
self.txt_mod = Modulation(hidden_size, double=True) - self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device) - self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_mlp = nn.Sequential( - nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), nn.GELU(approximate="tanh"), - nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor): img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -188,6 +186,11 @@ class DoubleStreamBlock(nn.Module): # calculate the txt bloks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) + + if txt.dtype == torch.float16: + txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) + + return img, txt @@ -202,7 +205,9 @@ class SingleStreamBlock(nn.Module): hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, - qk_scale: float | None = None, + qk_scale: float = None, + dtype=None, + device=None, ): super().__init__() self.hidden_dim = hidden_size @@ -212,17 +217,17 @@ class SingleStreamBlock(nn.Module): self.mlp_hidden_dim = int(hidden_size * mlp_ratio) # qkv and mlp_in - self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - self.norm = QKNorm(head_dim) + self.norm = QKNorm(head_dim, dtype=dtype, device=device) self.hidden_size = hidden_size - self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False) + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device) def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) @@ -236,15 +241,18 @@ class SingleStreamBlock(nn.Module): attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - return x + mod.gate * output + x += mod.gate * output + if x.dtype == torch.float16: + x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + def __init__(self, hidden_size: int, 
patch_size: int, out_channels: int, dtype=None, device=None): super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor) -> Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) diff --git a/modules/models/flux/util.py b/modules/models/flux/util.py index 77fc76c09..9303eb7cf 100644 --- a/modules/models/flux/util.py +++ b/modules/models/flux/util.py @@ -7,9 +7,9 @@ from huggingface_hub import hf_hub_download from imwatermark import WatermarkEncoder from safetensors.torch import load_file as load_sft -from flux.model import Flux, FluxParams -from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams -from flux.modules.conditioner import HFEmbedder +from .model import Flux, FluxParams +from .modules.autoencoder import AutoEncoder, AutoEncoderParams +from .modules.conditioner import HFEmbedder @dataclass
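
Usage note (not part of the patch): a minimal sketch of how the reworked Flux.forward could be exercised after this change. The import path follows this patch's module layout, the hyperparameters are toy values chosen only so the shape constraints hold (hidden_size divisible by num_heads, sum(axes_dim) equal to hidden_size // num_heads), and the weights stay uninitialized, so only the tensor shapes are meaningful.

import torch
from modules.models.flux.model import Flux  # module path assumed from this patch

# Toy hyperparameters; real FLUX.1 checkpoints are far larger.
model = Flux(
    in_channels=16,           # latent channels; the patch expands this to 16 * 2 * 2 internally
    vec_in_dim=16,
    context_in_dim=32,
    hidden_size=64,
    mlp_ratio=4.0,
    num_heads=4,              # pe_dim = 64 // 4 = 16
    depth=1,
    depth_single_blocks=1,
    axes_dim=[4, 6, 6],       # must sum to pe_dim
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,
    dtype=torch.float32,
    device="cpu",
)

bs, h, w = 1, 32, 32                      # latent spatial size; forward pads to the 2x2 patch grid if needed
x = torch.randn(bs, 16, h, w)             # latent image
timestep = torch.full((bs,), 0.5)
context = torch.randn(bs, 8, 32)          # (batch, txt_len, context_in_dim)
y = torch.randn(bs, 16)                   # pooled conditioning vector, (batch, vec_in_dim)
guidance = torch.full((bs,), 3.5)         # required because guidance_embed=True

out = model(x, timestep, context, y, guidance)
print(out.shape)                          # torch.Size([1, 16, 32, 32]) -- cropped back to the input size

The new configs/flux1-inference.yaml presumably routes construction through the modules.models.flux.FLUX1Inferencer target once a real state_dict is supplied; the sketch above only checks that the patched patchify/unpatchify round trip preserves the latent shape.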