Merge 873aea82ea1b58c2600dbe6ffa299ad46ab5b4c5 into 374bb6cc384d2a19422c0b07d69de0a41d1f3f4d

AngelBottomless 2025-03-15 15:38:02 +05:30 committed by GitHub
commit e9442a62c9
4 changed files with 384 additions and 0 deletions


@@ -0,0 +1,180 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from collections import defaultdict
import torch
from ldm.modules.diffusionmodules.openaimodel import timestep_embedding
from scripts.forward_timestep_embed_patch import forward_timestep_embed
from logging import getLogger


@dataclass
class DeepCacheParams:
    cache_in_level: int = 0
    cache_enable_step: int = 0
    full_run_step_rate: int = 5
    # cache_latents_cpu: bool = False
    # cache_latents_hires: bool = False


class DeepCacheSession:
    """
    Session for DeepCache, which holds cache data and provides functions for hooking the model.
    """
    def __init__(self) -> None:
        self.CACHE_LAST = {"timestep": {0}}
        self.stored_forward = None
        self.unet_reference = None
        self.cache_success_count = 0
        self.cache_fail_count = 0
        self.fail_reasons = defaultdict(int)
        self.success_reasons = defaultdict(int)
        self.enumerated_timestep = {"value": -1}

    def log_skip(self, reason: str = 'disabled_by_default'):
        self.fail_reasons[reason] += 1
        self.cache_fail_count += 1

    def report(self):
        # report cache success rate
        total = self.cache_success_count + self.cache_fail_count
        if total == 0:
            return
        logger = getLogger()
        level = logger.getEffectiveLevel()
        logger.log(level, "DeepCache Information :")
        for fail_reasons, count in self.fail_reasons.items():
            logger.log(level, f"    {fail_reasons}: {count}")
        for success_reasons, count in self.success_reasons.items():
            logger.log(level, f"    {success_reasons}: {count}")

    def deepcache_hook_model(self, unet, params: DeepCacheParams):
        """
        Hooks the given unet model to use DeepCache.
        """
        caching_level = params.cache_in_level
        # caching level 0 = no caching, idx for resnet layers
        cache_enable_step = params.cache_enable_step
        full_run_step_rate = params.full_run_step_rate  # '5' means run full model every 5 steps
        if full_run_step_rate < 1:
            print(f"DeepCache disabled due to full_run_step_rate {full_run_step_rate} < 1 but enabled by user")
            return  # disabled
        if getattr(unet, '_deepcache_hooked', False):
            return  # already hooked
        CACHE_LAST = self.CACHE_LAST
        self.stored_forward = unet.forward
        self.enumerated_timestep["value"] = -1
        valid_caching_in_level = min(caching_level, len(unet.input_blocks) - 1)
        valid_caching_out_level = min(valid_caching_in_level, len(unet.output_blocks) - 1)
        # set to max if invalid
        caching_level = valid_caching_out_level
        valid_cache_timestep_range = 50  # total 1000, 50

        def put_cache(h: torch.Tensor, timestep: int, real_timestep: float):
            """
            Registers cache
            """
            CACHE_LAST["timestep"].add(timestep)
            assert h is not None, "Cannot cache None"
            # maybe move to cpu and load later for low vram?
            CACHE_LAST["last"] = h
            CACHE_LAST[f"timestep_{timestep}"] = h
            CACHE_LAST["real_timestep"] = real_timestep

        def get_cache(current_timestep: int, real_timestep: float) -> Optional[torch.Tensor]:
            """
            Returns the cached tensor for the given timestep and cache key.
            """
            if current_timestep < cache_enable_step:
                self.fail_reasons['disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif full_run_step_rate < 1:
                self.fail_reasons['full_run_step_rate_disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif current_timestep % full_run_step_rate == 0:
                if f"timestep_{current_timestep}" in CACHE_LAST:
                    self.cache_success_count += 1
                    self.success_reasons['cached_exact'] += 1
                    CACHE_LAST["last"] = CACHE_LAST[f"timestep_{current_timestep}"]  # update last
                    return CACHE_LAST[f"timestep_{current_timestep}"]
                self.fail_reasons['full_run_step_rate_division'] += 1
                self.cache_fail_count += 1
                return None
            elif CACHE_LAST.get("real_timestep", 0) + valid_cache_timestep_range < real_timestep:
                self.fail_reasons['cache_outdated'] += 1
                self.cache_fail_count += 1
                return None
            # check if cache exists
            if "last" in CACHE_LAST:
                self.success_reasons['cached_last'] += 1
                self.cache_success_count += 1
                return CACHE_LAST["last"]
            self.fail_reasons['not_cached'] += 1
            self.cache_fail_count += 1
            return None

        def hijacked_unet_forward(x, timesteps=None, context=None, y=None, **kwargs):
            cache_cond = lambda: self.enumerated_timestep["value"] % full_run_step_rate == 0 or self.enumerated_timestep["value"] > cache_enable_step
            use_cache_cond = lambda: self.enumerated_timestep["value"] > cache_enable_step and self.enumerated_timestep["value"] % full_run_step_rate != 0
            nonlocal CACHE_LAST
            assert (y is not None) == (
                hasattr(unet, 'num_classes') and unet.num_classes is not None  # v2 or xl
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)
            if hasattr(unet, 'num_classes') and unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)
            real_timestep = timesteps[0].item()
            h = x.type(unet.dtype)
            cached_h = get_cache(self.enumerated_timestep["value"], real_timestep)
            for id, module in enumerate(unet.input_blocks):
                self.log_skip('run_before_cache_input_block')
                h = forward_timestep_embed(module, h, emb, context)
                hs.append(h)
                if cached_h is not None and use_cache_cond() and id == caching_level:
                    break
            if not use_cache_cond():
                self.log_skip('run_before_cache_middle_block')
                h = forward_timestep_embed(unet.middle_block, h, emb, context)
            relative_cache_level = len(unet.output_blocks) - caching_level - 1
            for idx, module in enumerate(unet.output_blocks):
                if cached_h is not None and use_cache_cond() and idx == relative_cache_level:
                    # use cache
                    h = cached_h
                elif cache_cond() and idx == relative_cache_level:
                    # put cache
                    put_cache(h, self.enumerated_timestep["value"], real_timestep)
                elif cached_h is not None and use_cache_cond() and idx < relative_cache_level:
                    # skip, h is already cached
                    continue
                hsp = hs.pop()
                h = torch.cat([h, hsp], dim=1)
                del hsp
                if len(hs) > 0:
                    output_shape = hs[-1].shape
                else:
                    output_shape = None
                h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
            h = h.type(x.dtype)
            self.enumerated_timestep["value"] += 1
            if unet.predict_codebook_ids:
                return unet.id_predictor(h)
            else:
                return unet.out(h)

        unet.forward = hijacked_unet_forward
        unet._deepcache_hooked = True
        self.unet_reference = unet

    def detach(self):
        if self.unet_reference is None:
            return
        if not getattr(self.unet_reference, '_deepcache_hooked', False):
            return
        # detach
        self.unet_reference.forward = self.stored_forward
        self.unet_reference._deepcache_hooked = False
        self.unet_reference = None
        self.stored_forward = None
        self.CACHE_LAST.clear()
        self.cache_fail_count = self.cache_success_count = 0
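
A minimal usage sketch of the session above, for orientation (illustrative only: `unet` stands for the webui's loaded ldm/sgm diffusion model, and the parameter values are arbitrary):

# Hypothetical usage, not part of the extension files in this commit.
session = DeepCacheSession()
params = DeepCacheParams(cache_in_level=0, cache_enable_step=8, full_run_step_rate=5)
session.deepcache_hook_model(unet, params)  # swaps in hijacked_unet_forward
# ... run sampling as usual: on steps divisible by full_run_step_rate the full UNet runs and the
# feature at the cache level is stored; on later steps past cache_enable_step it is reused ...
session.report()   # logs per-reason success/fail counters
session.detach()   # restores the original unet.forward and clears CACHE_LAST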


@@ -0,0 +1,78 @@
from modules import scripts, script_callbacks, shared, processing
from deepcache import DeepCacheSession, DeepCacheParams
from scripts.deepcache_xyz import add_axis_options


class ScriptDeepCache(scripts.Script):

    name = "DeepCache"
    session: DeepCacheSession = None

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def get_deepcache_params(self, steps: int, enable_step_at: int = None) -> DeepCacheParams:
        return DeepCacheParams(
            cache_in_level=shared.opts.deepcache_cache_resnet_level,
            cache_enable_step=int(shared.opts.deepcache_cache_enable_step_percentage * steps) if enable_step_at is None else enable_step_at,
            full_run_step_rate=shared.opts.deepcache_full_run_step_rate,
        )

    def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache process")
        self.detach_deepcache()
        if shared.opts.deepcache_enable:
            self.configure_deepcache(self.get_deepcache_params(p.steps))

    def before_hr(self, p: processing.StableDiffusionProcessing, *args):
        print("DeepCache before_hr")
        if self.session is not None:
            self.session.enumerated_timestep["value"] = -1  # reset enumerated timestep
        if not shared.opts.deepcache_hr_reuse:
            self.detach_deepcache()
            if shared.opts.deepcache_enable:
                hr_steps = getattr(p, 'hr_second_pass_steps', 0) or p.steps
                enable_step = int(shared.opts.deepcache_cache_enable_step_percentage_hr * hr_steps)
                self.configure_deepcache(self.get_deepcache_params(hr_steps, enable_step_at=enable_step))  # use second pass steps if available

    def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache postprocess")
        self.detach_deepcache()

    def configure_deepcache(self, params: DeepCacheParams):
        if self.session is None:
            self.session = DeepCacheSession()
        self.session.deepcache_hook_model(
            shared.sd_model.model.diffusion_model,  # unet_model
            params
        )

    def detach_deepcache(self):
        print("Detaching DeepCache")
        if self.session is None:
            return
        self.session.report()
        self.session.detach()
        self.session = None


def on_ui_settings():
    import gradio as gr
    options = {
        "deepcache_explanation": shared.OptionHTML("""
            <a href='https://github.com/horseee/DeepCache'>DeepCache</a> optimizes by caching the results of mid-blocks, which are known to carry high-level features, and reusing them in the next forward pass.
            """),
        "deepcache_enable": shared.OptionInfo(False, "Enable DeepCache").info("noticeable change in details of the generated picture"),
        "deepcache_cache_resnet_level": shared.OptionInfo(0, "Cache Resnet level", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("Deeper = fewer layers cached"),
        "deepcache_cache_enable_step_percentage": shared.OptionInfo(0.4, "DeepCache is enabled after this percentage of steps", gr.Slider, {"minimum": 0, "maximum": 1}).info("Percentage of initial steps to disable deepcache"),
        "deepcache_full_run_step_rate": shared.OptionInfo(5, "Refresh caches when the step is divisible by this number", gr.Slider, {"minimum": 0, "maximum": 1000, "step": 1}).info("5 = refresh caches every 5 steps"),
        "deepcache_hr_reuse": shared.OptionInfo(False, "Reuse for HR").info("Reuses cache information for HR generation"),
        "deepcache_cache_enable_step_percentage_hr": shared.OptionInfo(0.0, "DeepCache is enabled after this percentage of steps for HR", gr.Slider, {"minimum": 0, "maximum": 1}).info("Percentage of initial steps to disable deepcache for HR generation"),
    }
    for name, opt in options.items():
        opt.section = ('deepcache', "DeepCache")
        shared.opts.add_option(name, opt)


script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_ui(add_axis_options)
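
A worked example of the percentage-to-step conversion done in get_deepcache_params (numbers chosen for illustration only):

# With 20 sampling steps and deepcache_cache_enable_step_percentage = 0.4:
steps = 20
cache_enable_step = int(0.4 * steps)  # -> 8, so cache reuse only starts after the 8th enumerated step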


@@ -0,0 +1,64 @@
from modules import scripts
from modules.shared import opts

xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module


def int_applier(value_name: str, min_range: int = -1, max_range: int = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = int(value)
        # validate value
        if not min_range == -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if not max_range == -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_int(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = int(x)
    return apply_int


def bool_applier(value_name: str):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
    def apply_bool(p, x, xs):
        validate(value_name, x)
        value_boolean = x.lower() == "true"
        opts.data[value_name] = value_boolean
    return apply_bool


def float_applier(value_name: str, min_range: float = -1, max_range: float = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = float(value)
        # validate value
        if not min_range == -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if not max_range == -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_float(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = float(x)
    return apply_float


def add_axis_options():
    extra_axis_options = [
        xyz_grid.AxisOption("[DeepCache] Enabled", str, bool_applier("deepcache_enable"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] Cache Resnet level", int, int_applier("deepcache_cache_resnet_level", 0, 10)),
        xyz_grid.AxisOption("[DeepCache] Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage", 0, 1)),
        xyz_grid.AxisOption("[DeepCache] Cache Refresh Rate", int, int_applier("deepcache_full_run_step_rate", 0, 1000)),
        xyz_grid.AxisOption("[DeepCache] HR Reuse", str, bool_applier("deepcache_hr_reuse"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] HR Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage_hr", 0, 1)),
    ]
    set_a = {opt.label for opt in xyz_grid.axis_options}
    set_b = {opt.label for opt in extra_axis_options}
    if set_a.intersection(set_b):
        return
    xyz_grid.axis_options.extend(extra_axis_options)
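
A small sketch of how one of these appliers behaves (illustrative; the (p, x, xs) signature follows the xyz_grid convention, and the axis value is assumed to arrive as a string here):

# Hypothetical call, mirroring what the grid does for an axis value of "3":
apply = int_applier("deepcache_cache_resnet_level", 0, 10)
apply(None, "3", None)  # validate() checks 0 <= 3 <= 10, then opts.data["deepcache_cache_resnet_level"] = 3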


@@ -0,0 +1,62 @@
"""
Patched forward_timestep_embed function to support the following:
@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
"""
from ldm.modules.attention import SpatialTransformer
try:
    from ldm.modules.attention import SpatialVideoTransformer
except (ImportError, ModuleNotFoundError):
    SpatialVideoTransformer = None
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
try:
    from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
except (ImportError, ModuleNotFoundError):
    VideoResBlock = None
# SD XL modules from generative-models repo
from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
try:
    from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
except (ImportError, ModuleNotFoundError):
    SpatialVideoTransformerSGM = None
from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
try:
    from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
except (ImportError, ModuleNotFoundError):
    VideoResBlockSGM = None
import torch.nn.functional as F


def forward_timestep_embed(ts: TimestepEmbedSequential, x, emb, context=None, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    for layer in ts:
        if VideoResBlock and isinstance(layer, (VideoResBlock, VideoResBlockSGM)):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, (TimestepBlock, TimestepBlockSGM)):
            x = layer(x, emb)
        elif SpatialVideoTransformer and isinstance(layer, (SpatialVideoTransformer, SpatialVideoTransformerSGM)):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
        elif isinstance(layer, (SpatialTransformer, SpatialTransformerSGM)):
            x = layer(x, context)
        elif isinstance(layer, (Upsample, UpsampleSGM)):
            x = forward_upsample(layer, x, output_shape=output_shape)
        else:
            x = layer(x)
    return x


def forward_upsample(self: Upsample, x, output_shape=None):
    assert x.shape[1] == self.channels
    if self.dims == 3:
        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        shape = [x.shape[2] * 2, x.shape[3] * 2]
        if output_shape is not None:
            shape[0] = output_shape[2]
            shape[1] = output_shape[3]
    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
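
A worked example of the shape handling in forward_upsample for the 2D case (shapes are illustrative): without output_shape both spatial dimensions are doubled; with output_shape the target size is taken from the next skip tensor, so the torch.cat in the hijacked forward does not hit a size mismatch.

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 16, 24)                              # toy feature map
doubled = F.interpolate(x, size=(32, 48), mode="nearest")  # default path: [16, 24] -> [32, 48]
matched = F.interpolate(x, size=(33, 48), mode="nearest")  # output_shape path: match an odd-sized skip tensor
print(doubled.shape, matched.shape)                        # [1, 4, 32, 48] and [1, 4, 33, 48]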