From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: [PATCH] initial SDXL refiner support --- modules/sd_hijack.py | 18 ++++++++---- modules/sd_models.py | 3 +- modules/sd_models_config.py | 3 ++ modules/sd_models_xl.py | 57 ++++++++++++++++++++++++++++++------- modules/shared.py | 9 ++++-- 5 files changed, 71 insertions(+), 19 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbed..2b274c182 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_models.py b/modules/sd_models.py index 077021755..267f4d8ed 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -414,6 +414,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' +sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: @@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict + clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) timer.record("find config") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 04c09ab00..8266fa397 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename): if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index a7240dc08..01320c7a5 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: width = getattr(self, 'target_width', 1024) height = getattr(self, 'target_height', 1024) + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=devices.device, dtype=devices.dtype) sdxl_conds = { "txt": batch, - "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), } - force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c @@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility return x + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - - model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] - model.cond_stage_key = model.cond_stage_model.input_key + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.conditioner.wrapped = torch.nn.Module() -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None diff --git a/modules/shared.py b/modules/shared.py index 71afd94f1..234ede0d7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), - "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) options_templates.update(options_section(('optimizations', "Optimizations"), {