From 8c9c139c654f6a41f1b5757bcf38930a0186371d Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Sat, 21 Sep 2024 01:02:49 +0900
Subject: [PATCH] support Flux schnell and cleanup

---
 modules/models/flux/flux.py | 65 +++++++++++++++++++------------------
 1 file changed, 34 insertions(+), 31 deletions(-)

diff --git a/modules/models/flux/flux.py b/modules/models/flux/flux.py
index 109c72904..6326e97e6 100644
--- a/modules/models/flux/flux.py
+++ b/modules/models/flux/flux.py
@@ -133,14 +133,10 @@ def flux_time_shift(mu: float, sigma: float, t):
     return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
 class ModelSamplingFlux(torch.nn.Module):
-    def __init__(self, model_config=None):
+    def __init__(self, shift=1.15):
         super().__init__()
-        if model_config is not None:
-            sampling_settings = model_config.sampling_settings
-        else:
-            sampling_settings = {}
 
-        self.set_parameters(shift=sampling_settings.get("shift", 1.15))
+        self.set_parameters(shift=shift)
 
     def set_parameters(self, shift=1.15, timesteps=10000):
         self.shift = shift
@@ -175,29 +171,13 @@ class ModelSamplingFlux(torch.nn.Module):
 
 class BaseModel(torch.nn.Module):
     """Wrapper around the core FLUX model"""
-    def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""):
+    def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs):
         super().__init__()
 
-        params = dict(
-            image_model="flux",
-            in_channels=16,
-            vec_in_dim=768,
-            context_in_dim=4096,
-            hidden_size=3072,
-            mlp_ratio=4.0,
-            num_heads=24,
-            depth=19,
-            depth_single_blocks=38,
-            axes_dim=[16, 56, 56],
-            theta=10000,
-            qkv_bias=True,
-            guidance_embed=True,
-        )
-
-        self.diffusion_model = Flux(device=device, dtype=dtype, **params)
-        self.model_sampling = ModelSamplingFlux()
-        self.depth = params['depth']
-        self.depth_single_block = params['depth_single_blocks']
+        self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs)
+        self.model_sampling = ModelSamplingFlux(shift=shift)
+        self.depth = kwargs['depth']
+        self.depth_single_block = kwargs['depth_single_blocks']
 
     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()
@@ -215,9 +195,9 @@ class BaseModel(torch.nn.Module):
 
 class FLUX1LatentFormat:
     """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
-    def __init__(self):
-        self.scale_factor = 0.3611
-        self.shift_factor = 0.1159
+    def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
+        self.scale_factor = scale_factor
+        self.shift_factor = shift_factor
 
     def process_in(self, latent):
         return (latent - self.shift_factor) * self.scale_factor
@@ -260,6 +240,22 @@ class FLUX1Inferencer(torch.nn.Module):
     def __init__(self, state_dict, use_ema=False):
         super().__init__()
 
+        params = dict(
+            image_model="flux",
+            in_channels=16,
+            vec_in_dim=768,
+            context_in_dim=4096,
+            hidden_size=3072,
+            mlp_ratio=4.0,
+            num_heads=24,
+            depth=19,
+            depth_single_blocks=38,
+            axes_dim=[16, 56, 56],
+            theta=10000,
+            qkv_bias=True,
+            guidance_embed=True,
+        )
+
         # detect model_prefix
         diffusion_model_prefix = "model.diffusion_model."
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: @@ -267,8 +263,15 @@ class FLUX1Inferencer(torch.nn.Module): elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict: diffusion_model_prefix = "" + shift=1.15 + # check guidance_in to detect Flux schnell + if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict: + print("Flux schnell detected") + params.update(dict(guidance_embed=False,)) + shift=1.0 + with torch.no_grad(): - self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference) + self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params) self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae) self.first_stage_model.dtype = devices.dtype_vae self.vae = self.first_stage_model # real vae