mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
support Flux schnell and cleanup
This commit is contained in:
parent
11c9bc719c
commit
8c9c139c65
@ -133,14 +133,10 @@ def flux_time_shift(mu: float, sigma: float, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
class ModelSamplingFlux(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
def __init__(self, shift=1.15):
|
||||
super().__init__()
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
|
||||
self.set_parameters(shift=shift)
|
||||
|
||||
def set_parameters(self, shift=1.15, timesteps=10000):
|
||||
self.shift = shift
|
||||
@ -175,29 +171,13 @@ class ModelSamplingFlux(torch.nn.Module):
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core FLUX model"""
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float16, state_dict=None, prefix=""):
|
||||
def __init__(self, shift=1.15, device=None, dtype=torch.float16, state_dict=None, prefix="", **kwargs):
|
||||
super().__init__()
|
||||
|
||||
params = dict(
|
||||
image_model="flux",
|
||||
in_channels=16,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
)
|
||||
|
||||
self.diffusion_model = Flux(device=device, dtype=dtype, **params)
|
||||
self.model_sampling = ModelSamplingFlux()
|
||||
self.depth = params['depth']
|
||||
self.depth_single_block = params['depth_single_blocks']
|
||||
self.diffusion_model = Flux(device=device, dtype=dtype, **kwargs)
|
||||
self.model_sampling = ModelSamplingFlux(shift=shift)
|
||||
self.depth = kwargs['depth']
|
||||
self.depth_single_block = kwargs['depth_single_blocks']
|
||||
|
||||
def apply_model(self, x, sigma, c_crossattn=None, y=None):
|
||||
dtype = self.get_dtype()
|
||||
@ -215,9 +195,9 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
class FLUX1LatentFormat:
|
||||
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
def __init__(self, scale_factor=0.3611, shift_factor=0.1159):
|
||||
self.scale_factor = scale_factor
|
||||
self.shift_factor = shift_factor
|
||||
|
||||
def process_in(self, latent):
|
||||
return (latent - self.shift_factor) * self.scale_factor
|
||||
@ -260,6 +240,22 @@ class FLUX1Inferencer(torch.nn.Module):
|
||||
def __init__(self, state_dict, use_ema=False):
|
||||
super().__init__()
|
||||
|
||||
params = dict(
|
||||
image_model="flux",
|
||||
in_channels=16,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
)
|
||||
|
||||
# detect model_prefix
|
||||
diffusion_model_prefix = "model.diffusion_model."
|
||||
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
@ -267,8 +263,15 @@ class FLUX1Inferencer(torch.nn.Module):
|
||||
elif "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
|
||||
diffusion_model_prefix = ""
|
||||
|
||||
shift=1.15
|
||||
# check guidance_in to detect Flux schnell
|
||||
if f"{diffusion_model_prefix}guidance_in.in_layer.weight" not in state_dict:
|
||||
print("Flux schnell detected")
|
||||
params.update(dict(guidance_embed=False,))
|
||||
shift=1.0
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = BaseModel(state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference)
|
||||
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=diffusion_model_prefix, device="cpu", dtype=devices.dtype_inference, **params)
|
||||
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
|
||||
self.first_stage_model.dtype = devices.dtype_vae
|
||||
self.vae = self.first_stage_model # real vae
|
||||
|
Loading…
Reference in New Issue
Block a user