Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-03-04 13:04:53 +08:00

Commit 219a0e2429: support Flux1
Parent: 9c0fd83b5e
@@ -24,6 +24,11 @@ class AutocastLinear(nn.Linear):
     def forward(self, x):
         return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
 
 
+class AutocastLayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)
+
+
 def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
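For readers unfamiliar with the pattern: these Autocast* wrappers cast the stored parameters to the activation's dtype on every forward call, so modules whose weights are kept in full precision can consume half-precision activations without wrapping the call site in torch.autocast. A minimal, self-contained sketch (bfloat16 chosen so it also runs on CPU):

    import torch
    import torch.nn as nn


    class AutocastLinear(nn.Linear):
        def forward(self, x):
            # cast weight/bias to the activation dtype at call time
            return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)


    class AutocastLayerNorm(nn.LayerNorm):
        def forward(self, x):
            return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None, self.eps)


    fc = AutocastLinear(8, 8)              # parameters stored in float32
    ln = AutocastLayerNorm(8)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    y = ln(fc(x))                          # no torch.autocast needed
    assert y.dtype == torch.bfloat16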
@@ -41,9 +46,9 @@ class Mlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
 
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
+        self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
         self.act = act_layer
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
+        self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
 
     def forward(self, x):
         x = self.fc1(x)
@@ -61,10 +66,10 @@ class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device):
         super().__init__()
         self.heads = heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
-        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.q_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.out_proj = AutocastLinear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
 
     def forward(self, x, mask=None):
         q = self.q_proj(x)
@@ -82,9 +87,11 @@ ACTIVATIONS = {
 class CLIPLayer(torch.nn.Module):
     def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):
         super().__init__()
-        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm1 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
         self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)
-        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.layer_norm2 = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
+        #self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
         #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)
         self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)
 
@@ -131,7 +138,7 @@ class CLIPTextModel_(torch.nn.Module):
         super().__init__()
         self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
-        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.final_layer_norm = AutocastLayerNorm(embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):
         x = self.embeddings(input_tokens)
@@ -150,7 +157,7 @@ class CLIPTextModel(torch.nn.Module):
         self.num_layers = config_dict["num_hidden_layers"]
         self.text_model = CLIPTextModel_(config_dict, dtype, device)
         embed_dim = config_dict["hidden_size"]
-        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
+        self.text_projection = AutocastLinear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
         self.text_projection.weight.copy_(torch.eye(embed_dim))
         self.dtype = dtype
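Aside (not part of the commit): text_projection is created as a bias-free linear layer whose weight is overwritten with the identity matrix, so it is a pass-through until a checkpoint supplies trained projection weights. A small sketch of that initialization:

    import torch
    import torch.nn as nn

    embed_dim = 4
    proj = nn.Linear(embed_dim, embed_dim, bias=False)
    with torch.no_grad():                      # keep the in-place copy out of autograd
        proj.weight.copy_(torch.eye(embed_dim))
    x = torch.randn(3, embed_dim)
    assert torch.allclose(proj(x), x, atol=1e-6)  # identity weights: proj(x) == x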
@@ -33,6 +33,7 @@ class ModelType(enum.Enum):
     SDXL = 3
     SSD = 4
     SD3 = 5
+    FLUX1 = 6
 
 
 def replace_key(d, key, new_key, value):
@@ -369,7 +370,7 @@ def check_fp8(model):
         enable_fp8 = False
     elif shared.opts.fp8_storage == "Enable":
         enable_fp8 = True
-    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+    elif any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1")) and shared.opts.fp8_storage == "Enable for SDXL":
         enable_fp8 = True
     else:
         enable_fp8 = False
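For illustration, here is the decision this hunk widens, pulled out into a standalone function; the function name is hypothetical, and the real check_fp8 also returns early when the device or torch build cannot store fp8 weights:

    def fp8_wanted(model, fp8_storage: str) -> bool:
        # hypothetical extraction of the option logic shown above
        if fp8_storage == "Enable":
            return True
        if fp8_storage == "Enable for SDXL":
            # the commit widens the SDXL-only branch to Flux1 checkpoints too
            return any(getattr(model, attr, False) for attr in ("is_sdxl", "is_flux1"))
        return False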
@@ -382,10 +383,14 @@ def set_model_type(model, state_dict):
     model.is_sdxl = False
     model.is_ssd = False
     model.is_sd3 = False
+    model.is_flux1 = False
 
     if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
         model.is_sd3 = True
         model.model_type = ModelType.SD3
+    elif "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict:
+        model.is_flux1 = True
+        model.model_type = ModelType.FLUX1
     elif hasattr(model, 'conditioner'):
         model.is_sdxl = True
 
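As a usage sketch (the helper below is hypothetical, not part of the commit): the same fingerprint keys can classify a checkpoint straight from a .safetensors key index, without loading any tensors:

    from safetensors import safe_open

    def classify_checkpoint(path: str) -> str:
        with safe_open(path, framework="pt") as f:   # reads only the header/key index
            keys = set(f.keys())
        if "model.diffusion_model.x_embedder.proj.weight" in keys:
            return "SD3"
        if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in keys:
            return "FLUX1"
        return "other"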
@@ -777,6 +782,9 @@ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embe
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
 sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
 sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
+clip_l_clip_weight = 'text_encoders.clip_l.transformer.text_model.final_layer_norm.weight'
+clip_g_clip_weight = 'text_encoders.clip_g.transformer.text_model.final_layer_norm.weight'
+t5xxl_clip_weight = 'text_encoders.t5xxl.transformer.encoder.final_layer_norm.weight'
 
 
 class SdModelData:
@@ -909,7 +917,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_
 
     if not checkpoint_config:
         checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight, clip_l_clip_weight, clip_g_clip_weight] if x in state_dict)
 
     timer.record("find config")
 
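Side note on the idiom: any(x for x in weights if x in state_dict) works because the key strings are non-empty and therefore truthy; the more direct equivalent is a plain membership test, as this tiny check shows:

    state_dict = {"text_encoders.clip_l.transformer.text_model.final_layer_norm.weight": None}
    weights = [
        "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight",
        "text_encoders.clip_l.transformer.text_model.final_layer_norm.weight",
    ]
    # both forms agree for non-empty keys; the second reads more clearly
    assert any(x for x in weights if x in state_dict) == any(x in state_dict for x in weights)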
@@ -25,6 +25,7 @@ config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
 config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+config_flux1 = os.path.join(sd_configs_path, "flux1-inference.yaml")
 
 
 def is_using_v_parameterization_for_sd2(state_dict):
@@ -78,6 +79,9 @@ def guess_model_config_from_state_dict(sd, filename):
     if "model.diffusion_model.x_embedder.proj.weight" in sd:
         return config_sd3
 
+    if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
+        return config_flux1
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
@@ -36,5 +36,8 @@ class WebuiSdModel(LatentDiffusion):
     is_sd3: bool
     """True if the model's architecture is SD 3"""
 
+    is_flux1: bool
+    """True if the model's architecture is FLUX 1"""
+
     latent_channels: int
     """number of channels in the latent image representation; 16 in SD3 and 4 in other versions"""
@@ -63,7 +63,7 @@ class TAESDDecoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+            latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
@@ -79,7 +79,7 @@ class TAESDEncoder(nn.Module):
         super().__init__()
 
         if latent_channels is None:
-            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+            latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
 
         self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
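The two identical changes above encode one fact: TAESD weights for 16-channel latent spaces (SD3's "taesd3" files and Flux1's "taef1" files) need a 16-channel head, while the SD1/SD2/SDXL variants use 4 channels. A standalone sketch of the filename sniffing (function name assumed):

    def taesd_latent_channels(weights_path) -> int:
        # 16-channel latents for SD3 ("taesd3") and Flux1 ("taef1"), else 4
        return 16 if any(typ in str(weights_path) for typ in ("taesd3", "taef1")) else 4

    assert taesd_latent_channels("taef1_decoder.pth") == 16
    assert taesd_latent_channels("taesdxl_decoder.pth") == 4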
@@ -97,6 +97,8 @@ def download_model(model_path, model_url):
 def decoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_decoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_decoder.pth"
     else:
@@ -122,6 +124,8 @@ def decoder_model():
 def encoder_model():
     if shared.sd_model.is_sd3:
         model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_flux1:
+        model_name = "taef1_encoder.pth"
     elif shared.sd_model.is_sdxl:
         model_name = "taesdxl_encoder.pth"
     else:
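Both selection chains follow the same first-match table; a compact equivalent sketch (helper name hypothetical, and the fallback filename is an assumption, since the final else branch is truncated above):

    def taesd_weight_name(model, kind: str) -> str:
        # first matching architecture flag wins, mirroring the elif chains above
        for flag, prefix in (("is_sd3", "taesd3"), ("is_flux1", "taef1"), ("is_sdxl", "taesdxl")):
            if getattr(model, flag, False):
                return f"{prefix}_{kind}.pth"
        return f"taesd_{kind}.pth"  # assumed default for SD1/SD2

    class _M: is_sd3 = False; is_flux1 = True; is_sdxl = False
    assert taesd_weight_name(_M(), "decoder") == "taef1_decoder.pth"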