support Flux schnell and cleanup

This commit is contained in:
Won-Kyu Park 2024-09-21 01:02:49 +09:00
parent 11c9bc719c
commit 8c9c139c65
No known key found for this signature in database
GPG Key ID: 53AA79C8C9535D15

View File

@@ -133,14 +133,10 @@ def flux_time_shift(mu: float, sigma: float, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, model_config=None):
def __init__(self, shift=1.15):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
self.set_parameters(shift=shift)
def set_parameters(self, shift=1.15, timesteps=10000):
self.shift = shift
@@ -175,29 +171,13 @@ class ModelSamplingFlux(torch.nn.Module):
class BaseModel(torch.nn.Module):
"""Wrapper around the core FLUX model"""
def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""):
def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs):
super().__init__()
params = dict(
image_model="flux",
in_channels=16,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10000,
qkv_bias=True,
guidance_embed=True,
)
self.diffusion_model = Flux(device=device, dtype=dtype, **params)
self.model_sampling = ModelSamplingFlux()
self.depth = params['depth']
self.depth_single_block = params['depth_single_blocks']
self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs)
self.model_sampling = ModelSamplingFlux(shift=shift)
self.depth = kwargs['depth']
self.depth_single_block = kwargs['depth_single_blocks']
def apply_model(self, x, sigma, c_crossattn=None, y=None):
dtype = self.get_dtype()
@@ -215,9 +195,9 @@ class BaseModel(torch.nn.Module):
class FLUX1LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
self.scale_factor = scale_factor
self.shift_factor = shift_factor
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
@@ -260,6 +240,22 @@ class FLUX1Inferencer(torch.nn.Module):
def __init__(self, state_dict, use_ema=False):
super().__init__()
params = dict(
image_model="flux",
in_channels=16,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10000,
qkv_bias=True,
guidance_embed=True,
)
# detect model_prefix
diffusion_model_prefix = "model.diffusion_model."
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
@@ -267,8 +263,15 @@ class FLUX1Inferencer(torch.nn.Module):
elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
diffusion_model_prefix = ""
shift=1.15
# check guidance_in to detect Flux schnell
if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict:
print("Flux schnell detected")
params.update(dict(guidance_embed=False,))
shift=1.0
with torch.no_grad():
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference)
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params)
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
self.first_stage_model.dtype = devices.dtype_vae
self.vae = self.first_stage_model # real vae