mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-28 10:25:05 +08:00
initial SD3 support
This commit is contained in:
parent
a7116aa9a1
commit
5b2a60b8e2
@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
|
||||
## Credits
|
||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||
|
||||
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
|
||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
|
||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||
|
5
configs/sd3-inference.yaml
Normal file
5
configs/sd3-inference.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
model:
|
||||
target: modules.models.sd3.sd3_model.SD3Inferencer
|
||||
params:
|
||||
shift: 3
|
||||
state_dict: null
|
@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
|
||||
network_layer_mapping[network_name] = module
|
||||
module.network_layer_name = network_name
|
||||
else:
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
|
||||
|
||||
for name, module in cond_stage_model.named_modules():
|
||||
network_name = name.replace(".", "_")
|
||||
network_layer_mapping[network_name] = module
|
||||
module.network_layer_name = network_name
|
||||
|
@ -6,7 +6,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from other_impls import attention, Mlp
|
||||
from modules.models.sd3.other_impls import attention, Mlp
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import torch, math, einops
|
||||
from mmdit import MMDiT
|
||||
from modules.models.sd3.mmdit import MMDiT
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""Wrapper around the core MM-DiT model"""
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
|
||||
def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
|
||||
super().__init__()
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
|
||||
depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
|
||||
num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
|
||||
context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {
|
||||
|
166
modules/models/sd3/sd3_model.py
Normal file
166
modules/models/sd3/sd3_model.py
Normal file
@ -0,0 +1,166 @@
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Mapping
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
import k_diffusion
|
||||
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
|
||||
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
|
||||
|
||||
from modules import shared, modelloader, devices
|
||||
|
||||
CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
|
||||
CLIPG_CONFIG = {
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"intermediate_size": 5120,
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
}
|
||||
|
||||
CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
|
||||
CLIPL_CONFIG = {
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"intermediate_size": 3072,
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
}
|
||||
|
||||
T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
|
||||
T5_CONFIG = {
|
||||
"d_ff": 10240,
|
||||
"d_model": 4096,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"vocab_size": 32128,
|
||||
}
|
||||
|
||||
|
||||
class SafetensorsMapping(Mapping):
|
||||
def __init__(self, file):
|
||||
self.file = file
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file.keys())
|
||||
|
||||
def __iter__(self):
|
||||
for key in self.file.keys():
|
||||
yield key
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.file.get_tensor(key)
|
||||
|
||||
|
||||
class SD3Cond(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.tokenizer = SD3Tokenizer()
|
||||
|
||||
with torch.no_grad():
|
||||
self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
|
||||
self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
|
||||
self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)
|
||||
|
||||
self.weights_loaded = False
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
res = []
|
||||
|
||||
for prompt in prompts:
|
||||
tokens = self.tokenizer.tokenize_with_weights(prompt)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
|
||||
t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
vector_out = torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
res.append({
|
||||
'crossattn': lgt_out[0].to(devices.device),
|
||||
'vector': vector_out[0].to(devices.device),
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
def load_weights(self):
|
||||
if self.weights_loaded:
|
||||
return
|
||||
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
self.weights_loaded = True
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.tensor([[0]], device=devices.device) # XXX
|
||||
|
||||
|
||||
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
|
||||
def __init__(self, inner_model, sigmas):
|
||||
super().__init__(sigmas, quantize=shared.opts.enable_quantization)
|
||||
self.inner_model = inner_model
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
return self.inner_model.apply_model(input, sigma, **kwargs)
|
||||
|
||||
|
||||
class SD3Inferencer(torch.nn.Module):
|
||||
def __init__(self, state_dict, shift=3, use_ema=False):
|
||||
super().__init__()
|
||||
|
||||
self.shift = shift
|
||||
|
||||
with torch.no_grad():
|
||||
self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
|
||||
self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
|
||||
self.first_stage_model.dtype = self.model.diffusion_model.dtype
|
||||
|
||||
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
|
||||
|
||||
self.cond_stage_model = SD3Cond()
|
||||
self.cond_stage_key = 'txt'
|
||||
|
||||
self.parameterization = "eps"
|
||||
self.model.conditioning_key = "crossattn"
|
||||
|
||||
self.latent_format = SD3LatentFormat()
|
||||
self.latent_channels = 16
|
||||
|
||||
def after_load_weights(self):
|
||||
self.cond_stage_model.load_weights()
|
||||
|
||||
def ema_scope(self):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def get_learned_conditioning(self, batch: list[str]):
|
||||
return self.cond_stage_model(batch)
|
||||
|
||||
def apply_model(self, x, t, cond):
|
||||
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
|
||||
|
||||
def decode_first_stage(self, latent):
|
||||
latent = self.latent_format.process_out(latent)
|
||||
return self.first_stage_model.decode(latent)
|
||||
|
||||
def encode_first_stage(self, image):
|
||||
latent = self.first_stage_model.encode(image)
|
||||
return self.latent_format.process_in(latent)
|
||||
|
||||
def create_denoiser(self):
|
||||
return SD3Denoiser(self, self.model.model_sampling.sigmas)
|
@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||
|
||||
p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||
latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
|
||||
p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
|
@ -1,7 +1,9 @@
|
||||
import collections
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import re
|
||||
@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
from modules.shared import opts
|
||||
@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
SD1 = 1
|
||||
SD2 = 2
|
||||
SDXL = 3
|
||||
SSD = 4
|
||||
SD3 = 5
|
||||
|
||||
|
||||
def replace_key(d, key, new_key, value):
|
||||
keys = list(d.keys())
|
||||
|
||||
@ -368,6 +376,36 @@ def check_fp8(model):
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def set_model_type(model, state_dict):
|
||||
model.is_sd1 = False
|
||||
model.is_sd2 = False
|
||||
model.is_sdxl = False
|
||||
model.is_ssd = False
|
||||
model.is_ssd3 = False
|
||||
|
||||
if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
|
||||
model.is_sd3 = True
|
||||
model.model_type = ModelType.SD3
|
||||
elif hasattr(model, 'conditioner'):
|
||||
model.is_sdxl = True
|
||||
|
||||
if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
|
||||
model.is_ssd = True
|
||||
model.model_type = ModelType.SSD
|
||||
else:
|
||||
model.model_type = ModelType.SDXL
|
||||
elif hasattr(model.cond_stage_model, 'model'):
|
||||
model.is_sd2 = True
|
||||
model.model_type = ModelType.SD2
|
||||
else:
|
||||
model.is_sd1 = True
|
||||
model.model_type = ModelType.SD1
|
||||
|
||||
|
||||
def set_model_fields(model):
|
||||
if not hasattr(model, 'latent_channels'):
|
||||
model.latent_channels = 4
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.is_sdxl = hasattr(model, 'conditioner')
|
||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
||||
set_model_type(model, state_dict)
|
||||
set_model_fields(model)
|
||||
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
@ -552,8 +589,7 @@ def patch_given_betas():
|
||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
|
||||
|
||||
def repair_config(sd_config):
|
||||
|
||||
def repair_config(sd_config, state_dict=None):
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
@ -563,6 +599,7 @@ def repair_config(sd_config):
|
||||
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
if hasattr(sd_config.model.params, 'first_stage_config'):
|
||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
||||
|
||||
@ -580,6 +617,7 @@ def repair_config(sd_config):
|
||||
sd_config.model.params.unet_config.params.use_checkpoint = False
|
||||
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
@ -715,6 +753,25 @@ def send_model_to_trash(m):
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def instantiate_from_config(config, state_dict=None):
|
||||
constructor = get_obj_from_str(config["target"])
|
||||
|
||||
params = {**config.get("params", {})}
|
||||
|
||||
if state_dict and "state_dict" in params and params["state_dict"] is None:
|
||||
params["state_dict"] = state_dict
|
||||
|
||||
return constructor(**params)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
repair_config(sd_config, state_dict)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
sd_model = instantiate_from_config(sd_config.model, state_dict)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "creating model quickly", full_traceback=True)
|
||||
@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
sd_model = instantiate_from_config(sd_config.model, state_dict)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if hasattr(sd_model, "after_load_weights"):
|
||||
sd_model.after_load_weights()
|
||||
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
|
@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
|
||||
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if "model.diffusion_model.x_embedder.proj.weight" in sd:
|
||||
return config_sd3
|
||||
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_sdxl_inpainting
|
||||
else:
|
||||
return config_sdxl
|
||||
|
||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl_refiner
|
||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
return config_alt_diffusion_m18
|
||||
|
@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
|
||||
|
||||
is_sd1: bool
|
||||
"""True if the model's architecture is SD 1.x"""
|
||||
|
||||
is_sd3: bool
|
||||
"""True if the model's architecture is SD 3"""
|
||||
|
||||
latent_channels: int
|
||||
"""number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
|
||||
|
@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||
|
||||
return x_sample
|
||||
@ -246,7 +246,7 @@ class Sampler:
|
||||
self.eta_infotext_field = 'Eta'
|
||||
self.eta_default = 1.0
|
||||
|
||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||
self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')
|
||||
|
||||
self.p = None
|
||||
self.model_wrap_cfg = None
|
||||
|
@ -53,6 +53,11 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
|
||||
|
||||
if denoiser_constructor is not None:
|
||||
self.model_wrap = denoiser_constructor()
|
||||
else:
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||
|
||||
|
@ -8,9 +8,9 @@ sd_vae_approx_models = {}
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, latent_channels=4):
|
||||
super(VAEApprox, self).__init__()
|
||||
self.conv1 = nn.Conv2d(4, 8, (7, 7))
|
||||
self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
|
||||
self.conv2 = nn.Conv2d(8, 16, (5, 5))
|
||||
self.conv3 = nn.Conv2d(16, 32, (3, 3))
|
||||
self.conv4 = nn.Conv2d(32, 64, (3, 3))
|
||||
@ -40,7 +40,13 @@ def download_model(model_path, model_url):
|
||||
|
||||
|
||||
def model():
|
||||
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "vaeapprox-sd3.pt"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "vaeapprox-sdxl.pt"
|
||||
else:
|
||||
model_name = "model.pt"
|
||||
|
||||
loaded_model = sd_vae_approx_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
@ -52,7 +58,7 @@ def model():
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
|
||||
|
||||
loaded_model = VAEApprox()
|
||||
loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
|
||||
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
@ -64,7 +70,18 @@ def model():
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
if shared.sd_model.is_sdxl:
|
||||
if shared.sd_model.is_sd3:
|
||||
coeffs = [
|
||||
[-0.0645, 0.0177, 0.1052], [ 0.0028, 0.0312, 0.0650],
|
||||
[ 0.1848, 0.0762, 0.0360], [ 0.0944, 0.0360, 0.0889],
|
||||
[ 0.0897, 0.0506, -0.0364], [-0.0020, 0.1203, 0.0284],
|
||||
[ 0.0855, 0.0118, 0.0283], [-0.0539, 0.0658, 0.1047],
|
||||
[-0.0057, 0.0116, 0.0700], [-0.0412, 0.0281, -0.0039],
|
||||
[ 0.1106, 0.1171, 0.1220], [-0.0248, 0.0682, -0.0481],
|
||||
[ 0.0815, 0.0846, 0.1207], [-0.0120, -0.0055, -0.0867],
|
||||
[-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
|
||||
]
|
||||
elif shared.sd_model.is_sdxl:
|
||||
coeffs = [
|
||||
[ 0.3448, 0.4168, 0.4395],
|
||||
[-0.1953, -0.0290, 0.0250],
|
||||
|
@ -34,9 +34,9 @@ class Block(nn.Module):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
|
||||
def decoder():
|
||||
def decoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(4, 64), nn.ReLU(),
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
@ -44,13 +44,13 @@ def decoder():
|
||||
)
|
||||
|
||||
|
||||
def encoder():
|
||||
def encoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
|
||||
|
||||
@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, decoder_path="taesd_decoder.pth"):
|
||||
def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.decoder = decoder()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(decoder_path) else 4
|
||||
|
||||
self.decoder = decoder(latent_channels)
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth"):
|
||||
def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = encoder()
|
||||
|
||||
if latent_channels is None:
|
||||
latent_channels = 16 if "taesd3" in str(encoder_path) else 4
|
||||
|
||||
self.encoder = encoder(latent_channels)
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@ -87,7 +95,13 @@ def download_model(model_path, model_url):
|
||||
|
||||
|
||||
def decoder_model():
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_decoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_decoder.pth"
|
||||
else:
|
||||
model_name = "taesd_decoder.pth"
|
||||
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
@ -106,7 +120,13 @@ def decoder_model():
|
||||
|
||||
|
||||
def encoder_model():
|
||||
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
|
||||
if shared.sd_model.is_sd3:
|
||||
model_name = "taesd3_encoder.pth"
|
||||
elif shared.sd_model.is_sdxl:
|
||||
model_name = "taesdxl_encoder.pth"
|
||||
else:
|
||||
model_name = "taesd_encoder.pth"
|
||||
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
|
Loading…
Reference in New Issue
Block a user