From f166868df6d5b263974d0381dd49e474323f272f Mon Sep 17 00:00:00 2001
From: aria1th <35677394+aria1th@users.noreply.github.com>
Date: Wed, 6 Dec 2023 01:35:30 +0900
Subject: [PATCH] DeepCache Implementation

Mark 2 maybe...

some invalid interrupt method?

move to paper implementation

fix descriptions, KeyError

handle sgm for XL

fix ruff, change default for out_block

Implement Deepcache Optimization
---
 extensions-builtin/deepcache/deepcache.py    | 179 ++++++++++++++++++
 .../deepcache/scripts/deepcache_script.py    |  75 ++++++++
 .../deepcache/scripts/deepcache_xyz.py       |  63 ++++++
 .../scripts/forward_timestep_embed_patch.py  |  62 ++++++
 4 files changed, 379 insertions(+)
 create mode 100644 extensions-builtin/deepcache/deepcache.py
 create mode 100644 extensions-builtin/deepcache/scripts/deepcache_script.py
 create mode 100644 extensions-builtin/deepcache/scripts/deepcache_xyz.py
 create mode 100644 extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py

diff --git a/extensions-builtin/deepcache/deepcache.py b/extensions-builtin/deepcache/deepcache.py
new file mode 100644
index 000000000..a26671abe
--- /dev/null
+++ b/extensions-builtin/deepcache/deepcache.py
@@ -0,0 +1,179 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional
+from collections import defaultdict
+
+import torch
+from ldm.modules.diffusionmodules.openaimodel import timestep_embedding
+from scripts.forward_timestep_embed_patch import forward_timestep_embed
+
+@dataclass
+class DeepCacheParams:
+    cache_in_level: int = 0
+    cache_enable_step: int = 0
+    full_run_step_rate: int = 5
+    # cache_latents_cpu: bool = False
+    # cache_latents_hires: bool = False
+
+class DeepCacheSession:
+    """
+    Session for DeepCache, which holds cache data and provides functions for hooking the model.
+    """
+    def __init__(self) -> None:
+        self.CACHE_LAST = {"timestep": {0}}
+        self.stored_forward = None
+        self.unet_reference = None
+        self.cache_success_count = 0
+        self.cache_fail_count = 0
+        self.fail_reasons = defaultdict(int)
+        self.success_reasons = defaultdict(int)
+        self.enumerated_timestep = {"value": -1}
+
+    def log_skip(self, reason: str = 'disabled_by_default'):
+        self.fail_reasons[reason] += 1
+        self.cache_fail_count += 1
+
+    def report(self):
+        # report cache success rate
+        total = self.cache_success_count + self.cache_fail_count
+        if total == 0:
+            return
+        print(f"DeepCache success rate: {self.cache_success_count / total * 100:.2f}% ({self.cache_success_count}/{total})")
+        for fail_reason, count in self.fail_reasons.items():
+            print(f" {fail_reason}: {count}")
+        for success_reason, count in self.success_reasons.items():
+            print(f" {success_reason}: {count}")
+
+    def deepcache_hook_model(self, unet, params: DeepCacheParams):
+        """
+        Hooks the given unet model to use DeepCache.
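+
+        The hook replaces unet.forward with a cached variant: every
+        full_run_step_rate-th step runs the full model and stores the hidden
+        state of the configured output block, while the steps in between
+        (once cache_enable_step has passed) skip the deeper input blocks, the
+        middle block and the earlier output blocks, reusing the stored hidden
+        state instead. Call detach() to restore the original forward.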
+        """
+        caching_level = params.cache_in_level
+        # caching level 0 = no caching, idx for resnet layers
+        cache_enable_step = params.cache_enable_step
+        full_run_step_rate = params.full_run_step_rate # '5' means run full model every 5 steps
+        if full_run_step_rate < 1:
+            print(f"DeepCache is enabled by the user but disabled because full_run_step_rate is {full_run_step_rate} (< 1)")
+            return # disabled
+        if getattr(unet, '_deepcache_hooked', False):
+            return # already hooked
+        CACHE_LAST = self.CACHE_LAST
+        self.stored_forward = unet.forward
+        self.enumerated_timestep["value"] = -1
+        valid_caching_in_level = min(caching_level, len(unet.input_blocks) - 1)
+        valid_caching_out_level = min(valid_caching_in_level, len(unet.output_blocks) - 1)
+        # clamp the caching level to the number of available blocks
+        caching_level = valid_caching_out_level
+        valid_cache_timestep_range = 50 # of the 1000 diffusion timesteps, reuse caches at most 50 timesteps old
+        def put_cache(h: torch.Tensor, timestep: int, real_timestep: float):
+            """
+            Stores the hidden state for the given step in the session cache.
+            """
+            CACHE_LAST["timestep"].add(timestep)
+            assert h is not None, "Cannot cache None"
+            # maybe move to cpu and load later for low vram?
+            CACHE_LAST["last"] = h
+            CACHE_LAST[f"timestep_{timestep}"] = h
+            CACHE_LAST["real_timestep"] = real_timestep
+        def get_cache(current_timestep: int, real_timestep: float) -> Optional[torch.Tensor]:
+            """
+            Returns the cached tensor for the given timestep and cache key.
+            """
+            if current_timestep < cache_enable_step:
+                self.fail_reasons['disabled'] += 1
+                self.cache_fail_count += 1
+                return None
+            elif full_run_step_rate < 1:
+                self.fail_reasons['full_run_step_rate_disabled'] += 1
+                self.cache_fail_count += 1
+                return None
+            elif current_timestep % full_run_step_rate == 0:
+                if f"timestep_{current_timestep}" in CACHE_LAST:
+                    self.cache_success_count += 1
+                    self.success_reasons['cached_exact'] += 1
+                    CACHE_LAST["last"] = CACHE_LAST[f"timestep_{current_timestep}"] # update last
+                    return CACHE_LAST[f"timestep_{current_timestep}"]
+                else:
+                    print(f"Cache not found for timestep {current_timestep}\n available: {list(CACHE_LAST.keys())}")
+                    self.fail_reasons['full_run_step_rate_division'] += 1
+                    self.cache_fail_count += 1
+                    return None
+            elif CACHE_LAST.get("real_timestep", 0) + valid_cache_timestep_range < real_timestep:
+                self.fail_reasons['cache_outdated'] += 1
+                self.cache_fail_count += 1
+                return None
+            # check if cache exists
+            if "last" in CACHE_LAST:
+                self.success_reasons['cached_last'] += 1
+                self.cache_success_count += 1
+                return CACHE_LAST["last"]
+            self.fail_reasons['not_cached'] += 1
+            self.cache_fail_count += 1
+            return None
+        def hijacked_unet_forward(x, timesteps=None, context=None, y=None, **kwargs):
+            nonlocal CACHE_LAST
+            def cache_cond():
+                return self.enumerated_timestep["value"] % full_run_step_rate == 0 or self.enumerated_timestep["value"] > cache_enable_step
+            def use_cache_cond():
+                return self.enumerated_timestep["value"] > cache_enable_step and self.enumerated_timestep["value"] % full_run_step_rate != 0
+            assert (y is not None) == (
+                hasattr(unet, 'num_classes') and unet.num_classes is not None # v2 or xl
+            ), "must specify y if and only if the model is class-conditional"
+            hs = []
+            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
+            emb = unet.time_embed(t_emb)
+            if hasattr(unet, 'num_classes') and unet.num_classes is not None:
+                assert y.shape[0] == x.shape[0]
+                emb = emb + unet.label_emb(y)
+            real_timestep = timesteps[0].item()
+            h = x.type(unet.dtype)
+            cached_h = get_cache(self.enumerated_timestep["value"], real_timestep)
+            for index, module in enumerate(unet.input_blocks):
+                self.log_skip('run_before_cache_input_block')
+                h = forward_timestep_embed(module, h, emb, context)
+                hs.append(h)
+                if cached_h is not None and use_cache_cond() and index == caching_level:
+                    break
+            if not use_cache_cond():
+                self.log_skip('run_before_cache_middle_block')
+                h = forward_timestep_embed(unet.middle_block, h, emb, context)
+            relative_cache_level = len(unet.output_blocks) - caching_level - 1
+            for idx, module in enumerate(unet.output_blocks):
+                if cached_h is not None and use_cache_cond() and idx == relative_cache_level:
+                    # use cache
+                    h = cached_h
+                elif cache_cond() and idx == relative_cache_level:
+                    # put cache
+                    put_cache(h, self.enumerated_timestep["value"], real_timestep)
+                elif cached_h is not None and use_cache_cond() and idx < relative_cache_level:
+                    # skip, h is already cached
+                    continue
+                hsp = hs.pop()
+                h = torch.cat([h, hsp], dim=1)
+                del hsp
+                if len(hs) > 0:
+                    output_shape = hs[-1].shape
+                else:
+                    output_shape = None
+                h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
+            h = h.type(x.dtype)
+            self.enumerated_timestep["value"] += 1
+            if unet.predict_codebook_ids:
+                return unet.id_predictor(h)
+            else:
+                return unet.out(h)
+        unet.forward = hijacked_unet_forward
+        unet._deepcache_hooked = True
+        self.unet_reference = unet
+
+    def detach(self):
+        if self.unet_reference is None:
+            return
+        if not getattr(self.unet_reference, '_deepcache_hooked', False):
+            return
+        # detach
+        self.unet_reference.forward = self.stored_forward
+        self.unet_reference._deepcache_hooked = False
+        self.unet_reference = None
+        self.stored_forward = None
+        self.CACHE_LAST.clear()
+        self.cache_fail_count = self.cache_success_count = 0
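Note: the interaction of cache_cond and use_cache_cond above determines which steps refresh the cache and which reuse it. A minimal sketch of that schedule, outside the patch and with illustrative values (the session's step counter starts at -1 and increments after each forward pass; this sketch simplifies that to plain step indices):

    # Illustrative only: mirrors the two step predicates from
    # hijacked_unet_forward; the values are examples, not defaults.
    cache_enable_step = 4
    full_run_step_rate = 5
    for step in range(12):
        full_refresh = step % full_run_step_rate == 0               # store a fresh hidden state
        reuse = step > cache_enable_step and not full_refresh       # use_cache_cond
        print(step, "refresh" if full_refresh else "reuse" if reuse else "full run")

With these values, steps 0, 5 and 10 run the full model and store the hidden state, steps 6-9 and 11 reuse it, and steps 1-4 run fully because caching is not yet enabled.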
diff --git a/extensions-builtin/deepcache/scripts/deepcache_script.py b/extensions-builtin/deepcache/scripts/deepcache_script.py
new file mode 100644
index 000000000..ad161b217
--- /dev/null
+++ b/extensions-builtin/deepcache/scripts/deepcache_script.py
@@ -0,0 +1,75 @@
+from typing import Optional
+
+from modules import scripts, script_callbacks, shared, processing
+from deepcache import DeepCacheSession, DeepCacheParams
+from scripts.deepcache_xyz import add_axis_options
+
+class ScriptDeepCache(scripts.Script):
+
+    name = "DeepCache"
+    session: Optional[DeepCacheSession] = None
+
+    def title(self):
+        return self.name
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def get_deepcache_params(self, steps: int) -> DeepCacheParams:
+        return DeepCacheParams(
+            cache_in_level=shared.opts.deepcache_cache_resnet_level,
+            cache_enable_step=int(shared.opts.deepcache_cache_enable_step_percentage * steps),
+            full_run_step_rate=shared.opts.deepcache_full_run_step_rate,
+        )
+
+    def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
+        print("DeepCache process")
+        self.detach_deepcache()
+        if shared.opts.deepcache_enable:
+            self.configure_deepcache(self.get_deepcache_params(p.steps))
+
+    def before_hr(self, p: processing.StableDiffusionProcessing, *args):
+        print("DeepCache before_hr")
+        if self.session is not None:
+            self.session.enumerated_timestep["value"] = -1 # reset enumerated timestep
+        if not shared.opts.deepcache_hr_reuse:
+            self.detach_deepcache()
+        if shared.opts.deepcache_enable:
+            self.configure_deepcache(self.get_deepcache_params(getattr(p, 'hr_second_pass_steps', 0) or p.steps)) # use second pass steps if available
+
+    def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
+        print("DeepCache postprocess")
+        self.detach_deepcache()
+
+    def configure_deepcache(self, params: DeepCacheParams):
+        if self.session is None:
+            self.session = DeepCacheSession()
+        self.session.deepcache_hook_model(
+            shared.sd_model.model.diffusion_model, # unet model
+            params
+        )
+
+    def detach_deepcache(self):
+        print("Detaching DeepCache")
+        if self.session is None:
+            return
+        self.session.report()
+        self.session.detach()
+        self.session = None
+
+def on_ui_settings():
+    import gradio as gr
+    options = {
+        "deepcache_explanation": shared.OptionHTML("""
+    DeepCache optimizes by caching the results of mid-blocks, which are known to encode high-level features, and reusing them in the next forward pass.
+    """),
+        "deepcache_enable": shared.OptionInfo(False, "Enable DeepCache").info("may noticeably change details of the generated picture"),
+        "deepcache_cache_resnet_level": shared.OptionInfo(0, "Cache Resnet level", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("Deeper = fewer layers cached"),
+        "deepcache_cache_enable_step_percentage": shared.OptionInfo(0.4, "DeepCache is enabled after this percentage of steps", gr.Slider, {"minimum": 0, "maximum": 1}).info("percentage of initial steps during which DeepCache is disabled"),
+        "deepcache_full_run_step_rate": shared.OptionInfo(5, "Refresh caches when the step is divisible by this number", gr.Slider, {"minimum": 0, "maximum": 1000, "step": 1}).info("5 = refresh caches every 5 steps"),
+        "deepcache_hr_reuse": shared.OptionInfo(False, "Reuse for HR").info("reuses cache information for the HR pass"),
+    }
+    for name, opt in options.items():
+        opt.section = ('deepcache', "DeepCache")
+        shared.opts.add_option(name, opt)
+
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_before_ui(add_axis_options)
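Note: get_deepcache_params turns the settings above into per-run parameters; a worked example with the option defaults, assuming a hypothetical 30-step run:

    # With deepcache_cache_enable_step_percentage = 0.4 and full_run_step_rate = 5:
    steps = 30
    cache_enable_step = int(0.4 * steps)   # = 12: DeepCache stays off for the first 12 steps
    full_run_step_rate = 5                 # afterwards, every 5th step refreshes the cache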
diff --git a/extensions-builtin/deepcache/scripts/deepcache_xyz.py b/extensions-builtin/deepcache/scripts/deepcache_xyz.py
new file mode 100644
index 000000000..ad0af8501
--- /dev/null
+++ b/extensions-builtin/deepcache/scripts/deepcache_xyz.py
@@ -0,0 +1,63 @@
+from modules import scripts
+from modules.shared import opts
+
+xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
+
+def int_applier(value_name: str, min_range: int = -1, max_range: int = -1):
+    """
+    Returns a function that applies the given value to the given value_name in opts.data.
+    A min_range or max_range of -1 disables that bound check.
+    """
+    def validate(value_name: str, value: str):
+        value = int(value)
+        # validate value
+        if min_range != -1:
+            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
+        if max_range != -1:
+            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
+    def apply_int(p, x, xs):
+        validate(value_name, x)
+        opts.data[value_name] = int(x)
+    return apply_int
+
+def bool_applier(value_name: str):
+    """
+    Returns a function that applies the given value to the given value_name in opts.data.
+    """
+    def validate(value_name: str, value: str):
+        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
+    def apply_bool(p, x, xs):
+        validate(value_name, x)
+        value_boolean = x.lower() == "true"
+        opts.data[value_name] = value_boolean
+    return apply_bool
+
+def float_applier(value_name: str, min_range: float = -1, max_range: float = -1):
+    """
+    Returns a function that applies the given value to the given value_name in opts.data.
+    A min_range or max_range of -1 disables that bound check.
+    """
+    def validate(value_name: str, value: str):
+        value = float(value)
+        # validate value
+        if min_range != -1:
+            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
+        if max_range != -1:
+            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
+    def apply_float(p, x, xs):
+        validate(value_name, x)
+        opts.data[value_name] = float(x)
+    return apply_float
+
+def add_axis_options():
+    extra_axis_options = [
+        xyz_grid.AxisOption("[DeepCache] Enabled", str, bool_applier("deepcache_enable"), choices=xyz_grid.boolean_choice(reverse=True)),
+        xyz_grid.AxisOption("[DeepCache] Cache Resnet level", int, int_applier("deepcache_cache_resnet_level", 0, 10)),
+        xyz_grid.AxisOption("[DeepCache] Cache enable step percentage", float, float_applier("deepcache_cache_enable_step_percentage", 0, 1)),
+        xyz_grid.AxisOption("[DeepCache] Cache refresh rate", int, int_applier("deepcache_full_run_step_rate", 0, 1000)),
+        xyz_grid.AxisOption("[DeepCache] HR Reuse", str, bool_applier("deepcache_hr_reuse"), choices=xyz_grid.boolean_choice(reverse=True)),
+    ]
+    set_a = {opt.label for opt in xyz_grid.axis_options}
+    set_b = {opt.label for opt in extra_axis_options}
+    if set_a.intersection(set_b):
+        return
+
+    xyz_grid.axis_options.extend(extra_axis_options)
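Note: a quick sketch of how xyz_grid consumes these appliers; the p and xs arguments are accepted to satisfy the AxisOption interface but are unused by these appliers, and the value arrives as a string from the grid UI:

    # Illustrative only, outside the patch:
    apply = int_applier("deepcache_full_run_step_rate", 0, 1000)
    apply(None, "5", None)   # asserts 0 <= 5 <= 1000, then sets opts.data["deepcache_full_run_step_rate"] = 5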
diff --git a/extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py b/extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py
new file mode 100644
index 000000000..f0b01247f
--- /dev/null
+++ b/extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py
@@ -0,0 +1,62 @@
+"""
+Patched forward_timestep_embed function that dispatches to the block types
+imported below, including the video and SGM (SDXL) variants when available.
+@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
+"""
+from ldm.modules.attention import SpatialTransformer
+try:
+    from ldm.modules.attention import SpatialVideoTransformer
+except (ImportError, ModuleNotFoundError):
+    SpatialVideoTransformer = None
+from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
+try:
+    from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
+except (ImportError, ModuleNotFoundError):
+    VideoResBlock = None
+
+# SD XL modules from generative-models repo
+from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
+try:
+    from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
+except (ImportError, ModuleNotFoundError):
+    SpatialVideoTransformerSGM = None
+from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
+try:
+    from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
+except (ImportError, ModuleNotFoundError):
+    VideoResBlockSGM = None
+
+import torch.nn.functional as F
+
+# isinstance() raises TypeError when its class tuple contains None, so drop
+# any optional classes whose import failed before building the dispatch tuples.
+video_res_blocks = tuple(cls for cls in (VideoResBlock, VideoResBlockSGM) if cls is not None)
+spatial_video_transformers = tuple(cls for cls in (SpatialVideoTransformer, SpatialVideoTransformerSGM) if cls is not None)
+
+def forward_timestep_embed(ts: TimestepEmbedSequential, x, emb, context=None, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
+    for layer in ts:
+        if video_res_blocks and isinstance(layer, video_res_blocks):
+            x = layer(x, emb, num_video_frames, image_only_indicator)
+        elif isinstance(layer, (TimestepBlock, TimestepBlockSGM)):
+            x = layer(x, emb)
+        elif spatial_video_transformers and isinstance(layer, spatial_video_transformers):
+            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
+        elif isinstance(layer, (SpatialTransformer, SpatialTransformerSGM)):
+            x = layer(x, context)
+        elif isinstance(layer, (Upsample, UpsampleSGM)):
+            x = forward_upsample(layer, x, output_shape=output_shape)
+        else:
+            x = layer(x)
+    return x
+
+def forward_upsample(upsample: Upsample, x, output_shape=None):
+    assert x.shape[1] == upsample.channels
+    if upsample.dims == 3:
+        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
+        if output_shape is not None:
+            shape[1] = output_shape[3]
+            shape[2] = output_shape[4]
+    else:
+        shape = [x.shape[2] * 2, x.shape[3] * 2]
+        if output_shape is not None:
+            shape[0] = output_shape[2]
+            shape[1] = output_shape[3]
+
+    x = F.interpolate(x, size=shape, mode="nearest")
+    if upsample.use_conv:
+        x = upsample.conv(x)
+    return x
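
Note: a minimal end-to-end sketch of the session API this patch adds, assuming a loaded webui model; run_sampling() is hypothetical and stands in for any sampler loop that calls unet.forward once per step:

    from modules import shared
    from deepcache import DeepCacheSession, DeepCacheParams

    session = DeepCacheSession()
    unet = shared.sd_model.model.diffusion_model  # same handle the script hooks
    session.deepcache_hook_model(unet, DeepCacheParams(
        cache_in_level=0,      # block level whose hidden state is cached
        cache_enable_step=8,   # the first 8 steps always run the full model
        full_run_step_rate=5,  # refresh the cache with a full pass every 5th step
    ))
    try:
        run_sampling()         # hypothetical sampler loop
    finally:
        session.report()       # prints cache hit/miss statistics
        session.detach()       # restores the original unet.forward and clears the cache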