import torch from modules import prompt_parser, devices, sd_samplers_common from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback def catenate_conds(conds): if not isinstance(conds[0], dict): return torch.cat(conds) return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} def subscript_cond(cond, a, b): if not isinstance(cond, dict): return cond[a:b] return {key: vec[a:b] for key, vec in cond.items()} def pad_cond(tensor, repeats, empty): if not isinstance(tensor, dict): return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) return tensor class CFGDenoiser(torch.nn.Module): """ Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) that can take a noisy picture and produce a noise-free picture using two guidances (prompts) instead of one. Originally, the second prompt is just an empty string, but we use non-empty negative prompt. """ def __init__(self, sampler): super().__init__() self.model_wrap = None self.mask = None self.nmask = None self.mask_blend_power = 1 self.mask_blend_scale = 1 self.inpaint_detail_preservation = 16 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" self.total_steps = None """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler""" self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False self.sampler = sampler self.model_wrap = None self.p = None # NOTE: masking before denoising can cause the original latents to be oversmoothed # as the original latents do not have noise self.mask_before_denoising = False @property def inner_model(self): raise NotImplementedError() def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) for i, conds in enumerate(conds_list): for cond_index, weight in conds: denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) return denoised def combine_denoised_for_edit_model(self, x_out, cond_scale): out_cond, out_img_cond, out_uncond = x_out.chunk(3) denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) return denoised def get_pred_x0(self, x_in, x_out, sigma): return x_out def update_inner_model(self): self.model_wrap = None c, uc = self.p.get_conds() self.sampler.sampler_extra_args['cond'] = c self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): def latent_blend(a, b, t): """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. The "detail_preservation" factor biases the magnitude interpolation towards the larger of the two magnitudes. """ # Record the original latent vector magnitudes. # We bring them to a power so that larger magnitudes are favored over smaller ones. # 64-bit operations are used here to allow large exponents. a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation one_minus_t = 1 - t # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) # Linearly interpolate the image vectors. image_interp = a * one_minus_t + b * t # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) return image_interp def get_modified_nmask(nmask, _sigma): """ Converts a negative mask representing the transparency of the original latent vectors being overlayed to a mask that is scaled according to the denoising strength for this step. Where: 0 = fully opaque, infinite density, fully masked 1 = fully transparent, zero density, fully unmasked We bring this transparency to a power, as this allows one to simulate N number of blending operations where N can be any positive real value. Using this one can control the balance of influence between the denoiser and the original latents according to the sigma value. NOTE: "mask" is not used """ return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException if sd_samplers_common.apply_refiner(self): cond = self.sampler.sampler_extra_args['cond'] uncond = self.sampler.sampler_extra_args['uncond'] # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] if shared.sd_model.model.conditioning_key == "crossattn-adm": image_uncond = torch.zeros_like(image_cond) make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} else: image_uncond = image_cond if isinstance(uncond, dict): make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} else: make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) else: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond sigma_in = denoiser_params.sigma tensor = denoiser_params.text_cond uncond = denoiser_params.text_uncond skip_uncond = False # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: skip_uncond = True x_in = x_in[:-batch_size] sigma_in = sigma_in[:-batch_size] self.padded_cond_uncond = False if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: empty = shared.sd_model.cond_stage_model_empty_prompt num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] if num_repeats < 0: tensor = pad_cond(tensor, -num_repeats, empty) self.padded_cond_uncond = True elif num_repeats > 0: uncond = pad_cond(uncond, num_repeats, empty) self.padded_cond_uncond = True if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: cond_in = catenate_conds([tensor, uncond, uncond]) elif skip_uncond: cond_in = tensor else: cond_in = catenate_conds([tensor, uncond]) if shared.opts.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) if not is_edit_model: c_crossattn = subscript_cond(tensor, a, b) else: c_crossattn = torch.cat([tensor[a:b]], uncond) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) if not skip_uncond: x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) denoised_image_indexes = [x[0][0] for x in conds_list] if skip_uncond: fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) cfg_denoised_callback(denoised_params) devices.test_for_nans(x_out, "unet") if is_edit_model: denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) elif skip_uncond: denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) if opts.live_preview_content == "Prompt": preview = self.sampler.last_latent elif opts.live_preview_content == "Negative prompt": preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) else: preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) sd_samplers_common.store_latent(preview) after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) cfg_after_cfg_callback(after_cfg_callback_params) denoised = after_cfg_callback_params.x self.step += 1 return denoised