From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: [PATCH 01/23] getting SD2.1 to run on SDXL repo --- modules/launch_utils.py | 3 ++ modules/paths.py | 1 + modules/prompt_parser.py | 64 +++++++++++++++++++++++++------ modules/sd_hijack.py | 9 +++++ modules/sd_hijack_open_clip.py | 4 ++ modules/sd_models.py | 8 +++- modules/sd_models_config.py | 2 + modules/sd_models_xl.py | 40 +++++++++++++++++++ modules/sd_samplers_kdiffusion.py | 45 +++++++++++++++++----- 9 files changed, 152 insertions(+), 24 deletions(-) create mode 100644 modules/sd_models_xl.py diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 0e0dbca49..3b740dbd4 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -235,11 +235,13 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -297,6 +299,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index bada804e6..f509a85f7 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,6 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 0069d8b0e..d7f9e9a9c 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps): cond_schedule = [] for i, (end_at_step, _) in enumerate(prompt_schedule): - cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) + if isinstance(conds, dict): + cond = {k: v[i] for k, v in conds.items()} + else: + cond = conds[i] + + cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) cache[prompt] = cond_schedule res.append(cond_schedule) @@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne return MulticondLearnedConditioning(shape=(len(prompts),), batch=res) +class DictWithShape(dict): + def __init__(self, x, shape): + super().__init__() + self.update(x) + + @property + def shape(self): + return self["crossattn"].shape + + def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step): param = c[0][0].cond - res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + is_dict = isinstance(param, dict) + + if is_dict: + dict_cond = param + res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()} + res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape) + else: + res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) + for i, cond_schedule in enumerate(c): target_index = 0 for current, entry in enumerate(cond_schedule): if current_step <= entry.end_at_step: target_index = current break - res[i] = cond_schedule[target_index].cond + + if is_dict: + for k, param in cond_schedule[target_index].cond.items(): + res[k][i] = param + else: + res[i] = cond_schedule[target_index].cond return res +def stack_conds(tensors): + # if prompts have wildly different lengths above the limit we'll get tensors of different shapes + # and won't be able to torch.stack them. So this fixes that. + token_count = max([x.shape[0] for x in tensors]) + for i in range(len(tensors)): + if tensors[i].shape[0] != token_count: + last_vector = tensors[i][-1:] + last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) + tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + + return torch.stack(tensors) + + + def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): param = c.batch[0][0].schedules[0].cond @@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step): conds_list.append(conds_for_batch) - # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes - # and won't be able to torch.stack them. So this fixes that. - token_count = max([x.shape[0] for x in tensors]) - for i in range(len(tensors)): - if tensors[i].shape[0] != token_count: - last_vector = tensors[i][-1:] - last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1]) - tensors[i] = torch.vstack([tensors[i], last_vector_repeated]) + if isinstance(tensors[0], dict): + keys = list(tensors[0].keys()) + stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys} + stacked = DictWithShape(stacked, stacked['crossattn'].shape) + else: + stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype) - return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype) + return conds_list, stacked re_attention = re.compile(r""" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce2..c4b9211fd 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -166,6 +166,15 @@ class StableDiffusionModelHijack: undo_optimizations() def hijack(self, m): + conditioner = getattr(m, 'conditioner', None) + if conditioner: + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index f733e8529..6ac5bda6e 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,6 +16,10 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007c..8d639583c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if hasattr(model, 'conditioner'): + sd_models_xl.extend_sdxl(model) + model.load_state_dict(state_dict, strict=False) del state_dict timer.record("apply weights to model") @@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training + if hasattr(model, 'logvar'): + model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9bfe1237d..96501569b 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -6,11 +6,13 @@ from modules import shared, paths, sd_disable_initialization sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py new file mode 100644 index 000000000..d43b88686 --- /dev/null +++ b/modules/sd_models_xl.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + c = self.conditioner({'txt': batch}) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + return self.model(x, t, cond) + + +def extend_sdxl(model): + dtype = next(model.model.diffusion_model.parameters()).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + + model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_key = model.cond_stage_model.input_key + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 71581b763..73289ce47 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -53,6 +53,28 @@ k_diffusion_scheduler = { } +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) @@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) @@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: - tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: - uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: - cond_in = torch.cat([tensor, uncond]) + cond_in = catenate_conds([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: - c_crossattn = [tensor[a:b]] + c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: [PATCH 02/23] SDXL support --- modules/launch_utils.py | 17 ++++++++++ modules/lowvram.py | 49 ++++++++++++++++++++-------- modules/paths.py | 9 +++++- modules/processing.py | 7 ++-- modules/prompt_parser.py | 23 ++++++++++++-- modules/sd_hijack.py | 23 +++++++++++++- modules/sd_hijack_clip.py | 16 +++++++--- modules/sd_hijack_open_clip.py | 38 +++++++++++++++++++--- modules/sd_hijack_optimizations.py | 51 +++++++++++++++++++++++++----- modules/sd_models.py | 14 ++++++-- modules/sd_models_config.py | 5 ++- modules/sd_models_xl.py | 27 +++++++++++++--- modules/sd_samplers_kdiffusion.py | 2 +- modules/shared.py | 2 ++ requirements.txt | 1 + requirements_versions.txt | 1 + 16 files changed, 241 insertions(+), 44 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 3b740dbd4..aa9d18802 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,6 +224,20 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + import importlib + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -319,11 +333,14 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) + mute_sdxl_imports() + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) + def configure_for_tests(): if "--api" not in sys.argv: sys.argv.append("--api") diff --git a/modules/lowvram.py b/modules/lowvram.py index d95bcfbf0..da4f33a8a 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram): send_me_to_gpu(first_stage_model, None) return first_stage_model_decode(z) - # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model + to_remain_in_cpu = [ + (sd_model, 'first_stage_model'), + (sd_model, 'depth_model'), + (sd_model, 'embedder'), + (sd_model, 'model'), + (sd_model, 'embedder'), + ] - # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then - # send the model to GPU. Then put modules back. the modules will be in CPU. - stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None + is_sdxl = hasattr(sd_model, 'conditioner') + is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') + + if is_sdxl: + to_remain_in_cpu.append((sd_model, 'conditioner')) + elif is_sd2: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) + else: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer')) + + # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model + stored = [] + for obj, field in to_remain_in_cpu: + module = getattr(obj, field, None) + stored.append(module) + setattr(obj, field, None) + + # send the model to GPU. sd_model.to(devices.device) - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored + + # put modules back. the modules will be in CPU. + for (obj, field), module in zip(to_remain_in_cpu, stored): + setattr(obj, field, module) # register hooks for those the first three models - sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + if is_sdxl: + sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) + elif is_sd2: + sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) + else: + sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap @@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer - del sd_model.cond_stage_model.transformer - if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) else: diff --git a/modules/paths.py b/modules/paths.py index f509a85f7..1100a8dc6 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), @@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) + elif "sgm" in options: + # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. + + sys.path.insert(0, d) + import sgm + sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/processing.py b/modules/processing.py index cd568a208..85d354232 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -343,10 +343,13 @@ class StableDiffusionProcessing: return cache[1] def setup_conds(self): + prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index d7f9e9a9c..33810669f 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import namedtuple from typing import List @@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -def get_learned_conditioning(model, prompts, steps): +class SdConditioning(list): + """ + A list with prompts for stable diffusion's conditioner model. + Can also specify width and height of created image - SDXL needs it. + """ + def __init__(self, prompts, width=None, height=None): + super().__init__() + self.extend(prompts) + self.width = width or getattr(prompts, 'width', None) + self.height = height or getattr(prompts, 'height', None) + + +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b") re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") -def get_multicond_prompt_list(prompts): + +def get_multicond_prompt_list(prompts: SdConditioning | list[str]): res_indexes = [] - prompt_flat_list = [] prompt_indexes = {} + prompt_flat_list = SdConditioning(prompts) + prompt_flat_list.clear() for prompt in prompts: subprompts = re_AND.split(prompt) @@ -201,6 +217,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c4b9211fd..266811f95 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules +import sgm.modules.attention +import sgm.modules.diffusionmodules.model +import sgm.modules.diffusionmodules.openaimodel +import sgm.modules.encoders.modules + attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward @@ -56,6 +61,9 @@ def apply_optimizations(option=None): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + sgm.modules.diffusionmodules.model.nonlinearity = silu + sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + if current_optimizer is not None: current_optimizer.undo() current_optimizer = None @@ -89,6 +97,10 @@ def undo_optimizations(): ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + def fix_checkpoint(): """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want @@ -170,10 +182,19 @@ class StableDiffusionModelHijack: if conditioner: for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] - if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + typename = type(embedder).__name__ + if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenCLIPEmbedder': + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenOpenCLIPEmbedder2': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a76667..6c17a81d2 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will - be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ @@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) self.hijack.comments.append(f"Used embeddings: {embeddings_list}") - return torch.hstack(zs) + if getattr(self.wrapped, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) def process_tokens(self, remade_batch_tokens, batch_multipliers): """ @@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z = z * (original_mean / new_mean) + z *= (original_mean / new_mean) return z diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index 6ac5bda6e..fcf5ad07d 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None - def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' @@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) return embedded + + +class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + d = self.wrapped.encode_with_transformer(tokens) + z = d[self.wrapped.layer] + + pooled = d.get("pooled") + if pooled is not None: + z.pooled = pooled + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 53e27adea..e99c9ba5f 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork import ldm.modules.attention import ldm.modules.diffusionmodules.model +import sgm.modules.attention +import sgm.modules.diffusionmodules.model + diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward class SdOptimization: @@ -39,6 +43,9 @@ class SdOptimization: ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward + class SdOptimizationXformers(SdOptimization): name = "xformers" @@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + sgm.modules.attention.CrossAttention.forward = xformers_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward class SdOptimizationSdpNoMem(SdOptimization): @@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward class SdOptimizationSdp(SdOptimizationSdpNoMem): @@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward class SdOptimizationSubQuad(SdOptimization): @@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward class SdOptimizationV1(SdOptimization): @@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization): cmd_opt = "opt_split_attention_v1" priority = 10 - def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 class SdOptimizationInvokeAI(SdOptimization): @@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI class SdOptimizationDoggettx(SdOptimization): @@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def list_optimizers(res): @@ -155,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None): +def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None): # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- mem_total_gb = psutil.virtual_memory().total // (1 << 30) + def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) s = s.softmax(dim=-1, dtype=s.dtype) return einsum('b i j, b j d -> b i d', s, v) + def einsum_op_slice_0(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): @@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size): r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) return r + def einsum_op_slice_1(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): @@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size): r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) return r + def einsum_op_mps_v1(q, k, v): if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 return einsum_op_compvis(q, k, v) @@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v): slice_size -= 1 return einsum_op_slice_1(q, k, v, slice_size) + def einsum_op_mps_v2(q, k, v): if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: return einsum_op_compvis(q, k, v) else: return einsum_op_slice_0(q, k, v, 1) + def einsum_op_tensor_mem(q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: @@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div) return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + def einsum_op_cuda(q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] @@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v): # Divide factor of safety as there's copying and fragmentation return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def einsum_op(q, k, v): if q.device.type == 'cuda': return einsum_op_cuda(q, k, v) @@ -328,7 +354,8 @@ def einsum_op(q, k, v): # Tested on i7 with 8MB L3 cache. return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q = self.to_q(x) @@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None): +def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x + def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape @@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None): +def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) + # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x): return h3 + def xformers_attnblock_forward(self, x): try: h_ = x @@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) + def sdp_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x): out = self.proj_out(out) return x + out + def sdp_no_mem_attnblock_forward(self, x): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8d639583c..e4aae597c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -411,6 +411,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' +sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' class SdModelData: @@ -445,6 +446,15 @@ class SdModelData: model_data = SdModelData() +def get_empty_cond(sd_model): + if hasattr(sd_model, 'conditioner'): + d = sd_model.get_learned_conditioning([""]) + return d['crossattn'] + else: + return sd_model.cond_stage_model([""]) + + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict timer.record("find config") @@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") with devices.autocast(), torch.no_grad(): - sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""]) + sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 96501569b..2e92479a1 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: + return config_sdxl + elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: return config_unclip diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d43b88686..e8e270c38 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,18 +1,30 @@ from __future__ import annotations +import sys + import torch import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer -from modules import devices +from modules import devices, shared, prompt_parser -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - c = self.conditioner({'txt': batch}) + width = getattr(self, 'target_width', 1024) + height = getattr(self, 'target_height', 1024) + + sdxl_conds = { + "txt": batch, + "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + } + + c = self.conditioner(sdxl_conds) return c @@ -26,7 +38,7 @@ def extend_sdxl(model): model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] model.cond_stage_key = model.cond_stage_model.input_key model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" @@ -34,7 +46,14 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.is_xl = True + sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.modules.attention.print = lambda *args: None +sgm.modules.diffusionmodules.model.print = lambda *args: None +sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None +sgm.modules.encoders.modules.print = lambda *args: None + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 73289ce47..5552a8dc7 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size diff --git a/modules/shared.py b/modules/shared.py index b7518de68..71afd94f1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), + "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { diff --git a/requirements.txt b/requirements.txt index 3142085ea..b3f8a7f41 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ kornia lark numpy omegaconf +open-clip-torch piexif psutil diff --git a/requirements_versions.txt b/requirements_versions.txt index f71b9d6c5..b826bf431 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -15,6 +15,7 @@ kornia==0.6.7 lark==1.1.2 numpy==1.23.5 omegaconf==2.2.3 +open-clip-torch==2.20.0 piexif==1.1.3 psutil~=5.9.5 pytorch_lightning==1.9.4 From 5cf623c58ef3c158e8b25f7c3d516ffc16769fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:08:19 +0300 Subject: [PATCH 03/23] linter --- modules/paths.py | 2 +- modules/sd_models_xl.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/paths.py b/modules/paths.py index 1100a8dc6..c6f8904ea 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -41,7 +41,7 @@ for d, must_exist, what, options in path_dirs: # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. sys.path.insert(0, d) - import sgm + import sgm # noqa: F401 sys.path.pop(0) else: sys.path.append(d) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index e8e270c38..9224c1a35 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sys - import torch import sgm.models.diffusion From a04c95512148fc6df64535a995fbc8f499cae206 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 00:12:25 +0300 Subject: [PATCH 04/23] fix importlib.machinery issue on github's autotests #yolo --- modules/launch_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index aa9d18802..4f48f3a13 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -227,13 +227,14 @@ def run_extensions_installers(settings_file): def mute_sdxl_imports(): """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - import importlib + class Dummy: + pass - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module = Dummy() module.LPIPS = None sys.modules['taming.modules.losses.lpips'] = module - module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module = Dummy() module.StableDataModuleFromConfig = None sys.modules['sgm.data'] = module From b717eb7e56a4e620e77a2225e80223c89cb4f0d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 08:29:37 +0300 Subject: [PATCH 05/23] mute unneeded SDXL imports for tests too --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 4f48f3a13..56b972d51 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -334,8 +334,6 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) - mute_sdxl_imports() - if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) @@ -357,6 +355,8 @@ def configure_for_tests(): def start(): + mute_sdxl_imports() + print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: From ac4ccfa1369e74492b467294eab96c3f558b297b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:30:33 +0300 Subject: [PATCH 06/23] get attention optimizations to work --- modules/hypernetworks/hypernetwork.py | 2 +- modules/launch_utils.py | 1 + modules/sd_hijack_optimizations.py | 14 +++++++------- modules/sd_models_xl.py | 3 +++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b877..c4821d21a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None): return context_k, context_v -def attention_CrossAttention_forward(self, x, context=None, mask=None): +def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 56b972d51..183730d29 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -239,6 +239,7 @@ def mute_sdxl_imports(): sys.modules['sgm.data'] = module + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index e99c9ba5f..b5f85ba51 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -173,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) @@ -355,7 +355,7 @@ def einsum_op(q, k, v): return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): h = self.heads q = self.to_q(x) @@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 9224c1a35..4d1aa497b 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -55,3 +55,6 @@ sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.encoders.modules.print = lambda *args: None +# this gets the code to load the vanilla attention that we override +sgm.modules.attention.SDP_IS_AVAILABLE = True +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file From 21aec6f567f52271efbbe33a2ab6561f9a47b787 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 09:38:54 +0300 Subject: [PATCH 07/23] lint --- modules/sd_models_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 4d1aa497b..1dd4459fe 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -57,4 +57,4 @@ sgm.modules.encoders.modules.print = lambda *args: None # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True -sgm.modules.attention.XFORMERS_IS_AVAILABLE = False \ No newline at end of file +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False From 594c8e7b263d9b37f4b18b56b159aeb6d1bba1b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 11:35:52 +0300 Subject: [PATCH 08/23] fix CLIP doing the unneeded normalization revert SD2.1 back to use the original repo add SDXL's force_zero_embeddings to negative prompt --- modules/processing.py | 2 +- modules/prompt_parser.py | 14 ++++++++++---- modules/sd_hijack.py | 2 +- modules/sd_hijack_clip.py | 15 +++++++++++++++ modules/sd_models_config.py | 1 - modules/sd_models_xl.py | 3 ++- 6 files changed, 29 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 85d354232..f01a6907f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -344,7 +344,7 @@ class StableDiffusionProcessing: def setup_conds(self): prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) - negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 33810669f..b29d079d5 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -116,11 +116,17 @@ class SdConditioning(list): A list with prompts for stable diffusion's conditioner model. Can also specify width and height of created image - SDXL needs it. """ - def __init__(self, prompts, width=None, height=None): + def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): super().__init__() self.extend(prompts) - self.width = width or getattr(prompts, 'width', None) - self.height = height or getattr(prompts, 'height', None) + + if copy_from is None: + copy_from = prompts + + self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) + self.width = width or getattr(copy_from, 'width', None) + self.height = height or getattr(copy_from, 'height', None) + def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): @@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): res.append(cached) continue - texts = [x[1] for x in prompt_schedule] + texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) conds = model.get_learned_conditioning(texts) cond_schedule = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 266811f95..647cdfbed 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -190,7 +190,7 @@ class StableDiffusionModelHijack: if typename == 'FrozenCLIPEmbedder': model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6c17a81d2..b3771909f 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -323,3 +323,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded + + +class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 2e92479a1..04c09ab00 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -12,7 +12,6 @@ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "conf config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1dd4459fe..b799ff46f 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), } - c = self.conditioner(sdxl_conds) + force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c From 76ebb175ca996e93c063e7109c9f478a268952b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 12:59:31 +0300 Subject: [PATCH 09/23] lora support --- extensions-builtin/Lora/lora.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cd46e6c7e..03f1ef856 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -179,6 +179,11 @@ def load_lora(name, lora_on_disk): if m: sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts: + key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_diffusers] = key continue From 6f23da603d3cbba82262a3c62cc44c8d5cb9e6db Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 16:18:39 +0300 Subject: [PATCH 10/23] fix broken img2img --- modules/sd_models_xl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b799ff46f..b19036f11 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -32,6 +32,9 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): return self.model(x, t, cond) +def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility + return x + def extend_sdxl(model): dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype @@ -50,6 +53,7 @@ def extend_sdxl(model): sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None From b8159d0919dcaa3a1a8f29e3aa30c25fe8e5f13b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:24:54 +0300 Subject: [PATCH 11/23] add XL support for live previews: approx and TAESD --- modules/sd_models_xl.py | 2 +- modules/sd_vae_approx.py | 37 ++++++++++++++++++++++++++----------- modules/sd_vae_taesd.py | 26 +++++++++++++------------- 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index b19036f11..af445a610 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,7 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_xl = True + model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f004683..b348f3aee 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch from torch import nn -from modules import devices, paths +from modules import devices, paths, shared -sd_vae_approx_model = None +sd_vae_approx_models = {} class VAEApprox(nn.Module): @@ -31,19 +31,34 @@ class VAEApprox(nn.Module): return x +def download_model(model_path, model_url): + if not os.path.exists(model_path): + os.makedirs(os.path.dirname(model_path), exist_ok=True) + + print(f'Downloading VAEApprox model to: {model_path}') + torch.hub.download_url_to_file(model_url, model_path) + + def model(): - global sd_vae_approx_model + model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt" + loaded_model = sd_vae_approx_models.get(model_name) - if sd_vae_approx_model is None: - model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt") - sd_vae_approx_model = VAEApprox() + if loaded_model is None: + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) if not os.path.exists(model_path): - model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt") - sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - sd_vae_approx_model.eval() - sd_vae_approx_model.to(devices.device, devices.dtype) + model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name) - return sd_vae_approx_model + if not os.path.exists(model_path): + model_path = os.path.join(paths.models_path, "VAE-approx", model_name) + download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name) + + loaded_model = VAEApprox() + loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None)) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_approx_models[model_name] = loaded_model + + return loaded_model def cheap_approximation(sample): diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5e8496e87..5bf7c76e1 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -8,9 +8,9 @@ import os import torch import torch.nn as nn -from modules import devices, paths_internal +from modules import devices, paths_internal, shared -sd_vae_taesd = None +sd_vae_taesd_models = {} def conv(n_in, n_out, **kwargs): @@ -61,9 +61,7 @@ class TAESD(nn.Module): return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) -def download_model(model_path): - model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth' - +def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) @@ -72,17 +70,19 @@ def download_model(model_path): def model(): - global sd_vae_taesd + model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) - if sd_vae_taesd is None: - model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth") - download_model(model_path) + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - sd_vae_taesd = TAESD(model_path) - sd_vae_taesd.eval() - sd_vae_taesd.to(devices.device, devices.dtype) + loaded_model = TAESD(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return sd_vae_taesd.decoder + return loaded_model.decoder From e16ebc917dfc902f041963df0d4e99e8141cf82f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:32:35 +0300 Subject: [PATCH 12/23] repair --no-half for SDXL --- modules/sd_models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index e4aae597c..9e8cb3cfc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -395,10 +395,11 @@ def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + if hasattr(sd_config.model.params, 'unet_config'): + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" From ff73841c608f5f02e6352bb235d9dbf63d922990 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:42:16 +0300 Subject: [PATCH 13/23] mute SDXL imports in the place there SDXL is imported for the first time instead of launch.py --- modules/launch_utils.py | 18 ------------------ modules/paths.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 183730d29..01ea7c916 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,22 +224,6 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) -def mute_sdxl_imports(): - """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" - - class Dummy: - pass - - module = Dummy() - module.LPIPS = None - sys.modules['taming.modules.losses.lpips'] = module - - module = Dummy() - module.StableDataModuleFromConfig = None - sys.modules['sgm.data'] = module - - - def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -356,8 +340,6 @@ def configure_for_tests(): def start(): - mute_sdxl_imports() - print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: diff --git a/modules/paths.py b/modules/paths.py index c6f8904ea..250523399 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio import modules.safe # noqa: F401 +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + class Dummy: + pass + + module = Dummy() + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = Dummy() + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + # data_path = cmd_opts_pre.data sys.path.insert(0, script_path) @@ -18,6 +33,8 @@ for possible_sd_path in possible_sd_paths: assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" +mute_sdxl_imports() + path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), From 6c5f83b19b331d51bde28c5033d13d0d64c11e54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:17:50 +0300 Subject: [PATCH 14/23] add support for SDXL loras with te1/te2 modules --- extensions-builtin/Lora/lora.py | 41 +++++++++++++++++++++++++-------- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 1 - 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 03f1ef856..4b5da7b52 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2): return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + return key @@ -142,10 +150,20 @@ class LoraUpDownModule: def assign_lora_names_to_compvis_modules(sd_model): lora_layer_mapping = {} - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + lora_name = f'{i}_{name.replace(".", "_")}' + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name for name, module in shared.sd_model.model.named_modules(): lora_name = name.replace(".", "_") @@ -168,10 +186,10 @@ def load_lora(name, lora_on_disk): keys_failed_to_match = {} is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping - for key_diffusers, weight in sd.items(): - key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) - key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) + for key_lora, weight in sd.items(): + key_lora_without_lora_parts, lora_key = key_lora.split(".", 1) + key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: @@ -180,12 +198,15 @@ def load_lora(name, lora_on_disk): sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" - if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts: - key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model") + if sd_module is None and "lora_unet" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: - keys_failed_to_match[key_diffusers] = key + keys_failed_to_match[key_lora] = key continue lora_module = lora.modules.get(key, None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9e8cb3cfc..077021755 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if hasattr(model, 'conditioner'): + model.is_sdxl = hasattr(model, 'conditioner') + if model.is_sdxl: sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index af445a610..a7240dc08 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,6 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning From dc3906185656dae75fcefe96625b1dcd0d31579c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:19:41 +0300 Subject: [PATCH 15/23] thank you linter --- extensions-builtin/Lora/lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 4b5da7b52..302490fbd 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -229,9 +229,9 @@ def load_lora(name, lora_on_disk): elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}') continue - raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}") + raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}") with torch.no_grad(): module.weight.copy_(weight) @@ -243,7 +243,7 @@ def load_lora(name, lora_on_disk): elif lora_key == "lora_down.weight": lora_module.down = module else: - raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") + raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha") if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: [PATCH 16/23] initial SDXL refiner support --- modules/sd_hijack.py | 18 ++++++++---- modules/sd_models.py | 3 +- modules/sd_models_config.py | 3 ++ modules/sd_models_xl.py | 57 ++++++++++++++++++++++++++++++------- modules/shared.py | 9 ++++-- 5 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbed..2b274c182 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_models.py b/modules/sd_models.py index 077021755..267f4d8ed 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -414,6 +414,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' +sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: @@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict + clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) timer.record("find config") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 04c09ab00..8266fa397 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename): if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index a7240dc08..01320c7a5 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: width = getattr(self, 'target_width', 1024) height = getattr(self, 'target_height', 1024) + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=devices.device, dtype=devices.dtype) sdxl_conds = { "txt": batch, - "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), } - force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c @@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility return x + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - - model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] - model.cond_stage_key = model.cond_stage_model.input_key + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.conditioner.wrapped = torch.nn.Module() -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None diff --git a/modules/shared.py b/modules/shared.py index 71afd94f1..234ede0d7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), - "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { From b7dbeda0d9e475aafa9db0cfe015bf724502ec20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:19:08 +0300 Subject: [PATCH 17/23] linter --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 267f4d8ed..729f03d7f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -478,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) timer.record("find config") From abb948dab09841571dd24c6be9ff9d6b212778ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:28:01 +0300 Subject: [PATCH 18/23] raise maximum Negative Guidance minimum sigma due to request in PR discussion --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 234ede0d7..89b7132e4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -439,7 +439,7 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), From 9a3f35b028a8026291679c35e1df5b2aea327a1d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:56:01 +0300 Subject: [PATCH 19/23] repair medvram and lowvram --- modules/lowvram.py | 4 +++- modules/sd_hijack_open_clip.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/lowvram.py b/modules/lowvram.py index da4f33a8a..6bbc11eb7 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -100,7 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) if sd_model.embedder: sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) - parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model + + if hasattr(sd_model, 'cond_stage_model'): + parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index fcf5ad07d..bb0b96c72 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded @@ -66,6 +66,6 @@ class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWi def encode_embedding_init_text(self, init_text, nvpt): ids = tokenizer.encode(init_text) ids = torch.asarray([ids], device=devices.device, dtype=torch.int) - embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) return embedded From 471a5a66b73921d569242daccc5275cb195e3f06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 17:54:09 +0300 Subject: [PATCH 20/23] add more relevant fields to caching conds --- modules/processing.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f01a6907f..f68e010dd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -330,8 +330,21 @@ class StableDiffusionProcessing: caches is a list with items described above. """ + + cached_params = ( + required_prompts, + steps, + opts.CLIP_stop_at_last_layers, + shared.sd_model.sd_checkpoint_info, + extra_network_data, + opts.sdxl_crop_left, + opts.sdxl_crop_top, + self.width, + self.height, + ) + for cache in caches: - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + if cache[0] is not None and cached_params == cache[0]: return cache[1] cache = caches[0] @@ -339,7 +352,7 @@ class StableDiffusionProcessing: with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) + cache[0] = cached_params return cache[1] def setup_conds(self): From ac2d47ff4c00b041cae3d882c2832662c2c64935 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 20:27:41 +0300 Subject: [PATCH 21/23] add cheap VAE approximation coeffs for SDXL --- modules/sd_vae_approx.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index b348f3aee..86bd658ad 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -64,12 +64,22 @@ def model(): def cheap_approximation(sample): # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2 - coefs = torch.tensor([ - [0.298, 0.207, 0.208], - [0.187, 0.286, 0.173], - [-0.158, 0.189, 0.264], - [-0.184, -0.271, -0.473], - ]).to(sample.device) + if shared.sd_model.is_sdxl: + coeffs = [ + [ 0.3448, 0.4168, 0.4395], + [-0.1953, -0.0290, 0.0250], + [ 0.1074, 0.0886, -0.0163], + [-0.3730, -0.2499, -0.2088], + ] + else: + coeffs = [ + [ 0.298, 0.207, 0.208], + [ 0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473], + ] + + coefs = torch.tensor(coeffs).to(sample.device) x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) From 5dee0fa1f812cf9f5fa6675c22c9a57afad39983 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 21:41:21 +0300 Subject: [PATCH 22/23] add a message about unsupported samplers --- modules/sd_samplers.py | 3 +++ modules/sd_samplers_compvis.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f22aad8f2..bea2684c4 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -28,6 +28,9 @@ def create_sampler(name, model): assert config is not None, f'bad sampler name: {name}' + if model.is_sdxl and config.options.get("no_sdxl", False): + raise Exception(f"Sampler {config.name} is not supported for SDXL") + sampler = config.constructor(model) sampler.config = config diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index bdae8b404..4a8396f97 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), + sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), + sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), ] From 14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 07:33:16 +0300 Subject: [PATCH 23/23] fix an issue in live previews that happens when you use SDXL with fp16 VAE --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f68e010dd..eb4a60eba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -539,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): - with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae)) return x