mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-05-06 11:59:06 +08:00
Merge 873aea82ea1b58c2600dbe6ffa299ad46ab5b4c5 into 374bb6cc384d2a19422c0b07d69de0a41d1f3f4d
This commit is contained in:
commit e9442a62c9

180 extensions-builtin/deepcache/deepcache.py (Normal file)
@@ -0,0 +1,180 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
from collections import defaultdict
from logging import getLogger

import torch

from ldm.modules.diffusionmodules.openaimodel import timestep_embedding
from scripts.forward_timestep_embed_patch import forward_timestep_embed


@dataclass
class DeepCacheParams:
    cache_in_level: int = 0
    cache_enable_step: int = 0
    full_run_step_rate: int = 5
    # cache_latents_cpu: bool = False
    # cache_latents_hires: bool = False


class DeepCacheSession:
    """
    Session for DeepCache, which holds cache data and provides functions for hooking the model.
    """
    def __init__(self) -> None:
        self.CACHE_LAST = {"timestep": {0}}
        self.stored_forward = None
        self.unet_reference = None
        self.cache_success_count = 0
        self.cache_fail_count = 0
        self.fail_reasons = defaultdict(int)
        self.success_reasons = defaultdict(int)
        self.enumerated_timestep = {"value": -1}

    def log_skip(self, reason: str = 'disabled_by_default'):
        self.fail_reasons[reason] += 1
        self.cache_fail_count += 1

    def report(self):
        # report cache hit/miss statistics
        total = self.cache_success_count + self.cache_fail_count
        if total == 0:
            return
        logger = getLogger()
        level = logger.getEffectiveLevel()
        logger.log(level, "DeepCache information:")
        for reason, count in self.fail_reasons.items():
            logger.log(level, f" {reason}: {count}")
        for reason, count in self.success_reasons.items():
            logger.log(level, f" {reason}: {count}")

    def deepcache_hook_model(self, unet, params: DeepCacheParams):
        """
        Hooks the given unet model to use DeepCache.
        """
        caching_level = params.cache_in_level
        # caching level 0 = no caching, idx for resnet layers
        cache_enable_step = params.cache_enable_step
        full_run_step_rate = params.full_run_step_rate  # 5 means the full model runs every 5 steps
        if full_run_step_rate < 1:
            print(f"DeepCache was enabled by the user but deactivated itself because full_run_step_rate is {full_run_step_rate} (< 1)")
            return  # disabled
        if getattr(unet, '_deepcache_hooked', False):
            return  # already hooked
        CACHE_LAST = self.CACHE_LAST
        self.stored_forward = unet.forward
        self.enumerated_timestep["value"] = -1
        valid_caching_in_level = min(caching_level, len(unet.input_blocks) - 1)
        valid_caching_out_level = min(valid_caching_in_level, len(unet.output_blocks) - 1)
        # clamp to the deepest valid level if out of range
        caching_level = valid_caching_out_level
        valid_cache_timestep_range = 50  # out of 1000 total timesteps

        def put_cache(h: torch.Tensor, timestep: int, real_timestep: float):
            """
            Registers cache
            """
            CACHE_LAST["timestep"].add(timestep)
            assert h is not None, "Cannot cache None"
            # maybe move to cpu and load later for low vram?
            CACHE_LAST["last"] = h
            CACHE_LAST[f"timestep_{timestep}"] = h
            CACHE_LAST["real_timestep"] = real_timestep

        def get_cache(current_timestep: int, real_timestep: float) -> Optional[torch.Tensor]:
            """
            Returns the cached tensor for the given timestep and cache key.
            """
            if current_timestep < cache_enable_step:
                self.fail_reasons['disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif full_run_step_rate < 1:
                self.fail_reasons['full_run_step_rate_disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif current_timestep % full_run_step_rate == 0:
                if f"timestep_{current_timestep}" in CACHE_LAST:
                    self.cache_success_count += 1
                    self.success_reasons['cached_exact'] += 1
                    CACHE_LAST["last"] = CACHE_LAST[f"timestep_{current_timestep}"]  # update last
                    return CACHE_LAST[f"timestep_{current_timestep}"]
                self.fail_reasons['full_run_step_rate_division'] += 1
                self.cache_fail_count += 1
                return None
            elif CACHE_LAST.get("real_timestep", 0) + valid_cache_timestep_range < real_timestep:
                self.fail_reasons['cache_outdated'] += 1
                self.cache_fail_count += 1
                return None
            # check if cache exists
            if "last" in CACHE_LAST:
                self.success_reasons['cached_last'] += 1
                self.cache_success_count += 1
                return CACHE_LAST["last"]
            self.fail_reasons['not_cached'] += 1
            self.cache_fail_count += 1
            return None

        def hijacked_unet_forward(x, timesteps=None, context=None, y=None, **kwargs):
            cache_cond = lambda: self.enumerated_timestep["value"] % full_run_step_rate == 0 or self.enumerated_timestep["value"] > cache_enable_step
            use_cache_cond = lambda: self.enumerated_timestep["value"] > cache_enable_step and self.enumerated_timestep["value"] % full_run_step_rate != 0
            nonlocal CACHE_LAST
            assert (y is not None) == (
                hasattr(unet, 'num_classes') and unet.num_classes is not None  # v2 or xl
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)
            if hasattr(unet, 'num_classes') and unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)
            real_timestep = timesteps[0].item()
            h = x.type(unet.dtype)
            cached_h = get_cache(self.enumerated_timestep["value"], real_timestep)
            for idx, module in enumerate(unet.input_blocks):
                self.log_skip('run_before_cache_input_block')
                h = forward_timestep_embed(module, h, emb, context)
                hs.append(h)
                if cached_h is not None and use_cache_cond() and idx == caching_level:
                    break
            if not use_cache_cond():
                self.log_skip('run_before_cache_middle_block')
                h = forward_timestep_embed(unet.middle_block, h, emb, context)
            relative_cache_level = len(unet.output_blocks) - caching_level - 1
            for idx, module in enumerate(unet.output_blocks):
                if cached_h is not None and use_cache_cond() and idx == relative_cache_level:
                    # use cache
                    h = cached_h
                elif cache_cond() and idx == relative_cache_level:
                    # put cache
                    put_cache(h, self.enumerated_timestep["value"], real_timestep)
                elif cached_h is not None and use_cache_cond() and idx < relative_cache_level:
                    # skip: this block's output is covered by the cache
                    continue
                hsp = hs.pop()
                h = torch.cat([h, hsp], dim=1)
                del hsp
                if len(hs) > 0:
                    output_shape = hs[-1].shape
                else:
                    output_shape = None
                h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
            h = h.type(x.dtype)
            self.enumerated_timestep["value"] += 1
            if unet.predict_codebook_ids:
                return unet.id_predictor(h)
            else:
                return unet.out(h)

        unet.forward = hijacked_unet_forward
        unet._deepcache_hooked = True
        self.unet_reference = unet

    def detach(self):
        if self.unet_reference is None:
            return
        if not getattr(self.unet_reference, '_deepcache_hooked', False):
            return
        # detach
        self.unet_reference.forward = self.stored_forward
        self.unet_reference._deepcache_hooked = False
        self.unet_reference = None
        self.stored_forward = None
        self.CACHE_LAST.clear()
        self.cache_fail_count = self.cache_success_count = 0
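For orientation, a minimal sketch of the session lifecycle implemented above; the `unet` object is an assumption (an already-loaded LDM UNet, e.g. `shared.sd_model.model.diffusion_model` in webui terms), and the parameter values are illustrative only:

    session = DeepCacheSession()
    params = DeepCacheParams(
        cache_in_level=0,      # input/output block level whose features are cached
        cache_enable_step=8,   # start reusing the cache after this sampling step
        full_run_step_rate=5,  # run the full model (refreshing the cache) every 5 steps
    )
    session.deepcache_hook_model(unet, params)  # swaps in the caching forward()
    # ... run sampling as usual; cached output-block features are reused ...
    session.report()  # logs cache hit/miss counters
    session.detach()  # restores the original forward() and clears the cache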
78 extensions-builtin/deepcache/scripts/deepcache_script.py (Normal file)

@@ -0,0 +1,78 @@
from modules import scripts, script_callbacks, shared, processing
from deepcache import DeepCacheSession, DeepCacheParams
from scripts.deepcache_xyz import add_axis_options


class ScriptDeepCache(scripts.Script):

    name = "DeepCache"
    session: DeepCacheSession = None

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def get_deepcache_params(self, steps: int, enable_step_at: int = None) -> DeepCacheParams:
        return DeepCacheParams(
            cache_in_level=shared.opts.deepcache_cache_resnet_level,
            cache_enable_step=int(shared.opts.deepcache_cache_enable_step_percentage * steps) if enable_step_at is None else enable_step_at,
            full_run_step_rate=shared.opts.deepcache_full_run_step_rate,
        )

    def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache process")
        self.detach_deepcache()
        if shared.opts.deepcache_enable:
            self.configure_deepcache(self.get_deepcache_params(p.steps))

    def before_hr(self, p: processing.StableDiffusionProcessing, *args):
        print("DeepCache before_hr")
        if self.session is not None:
            self.session.enumerated_timestep["value"] = -1  # reset enumerated timestep
        if not shared.opts.deepcache_hr_reuse:
            self.detach_deepcache()
            if shared.opts.deepcache_enable:
                hr_steps = getattr(p, 'hr_second_pass_steps', 0) or p.steps  # use second-pass steps if available
                enable_step = int(shared.opts.deepcache_cache_enable_step_percentage_hr * hr_steps)
                self.configure_deepcache(self.get_deepcache_params(hr_steps, enable_step_at=enable_step))

    def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache postprocess")
        self.detach_deepcache()

    def configure_deepcache(self, params: DeepCacheParams):
        if self.session is None:
            self.session = DeepCacheSession()
        self.session.deepcache_hook_model(
            shared.sd_model.model.diffusion_model,  # unet model
            params
        )

    def detach_deepcache(self):
        print("Detaching DeepCache")
        if self.session is None:
            return
        self.session.report()
        self.session.detach()
        self.session = None


def on_ui_settings():
    import gradio as gr
    options = {
        "deepcache_explanation": shared.OptionHTML("""
    <a href='https://github.com/horseee/DeepCache'>DeepCache</a> optimizes generation by caching the output of the model's mid-blocks, which is known to carry high-level features, and reusing it in subsequent forward passes.
    """),
        "deepcache_enable": shared.OptionInfo(False, "Enable DeepCache").info("may noticeably change details of the generated picture"),
        "deepcache_cache_resnet_level": shared.OptionInfo(0, "Cache Resnet level", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("deeper = fewer layers cached"),
        "deepcache_cache_enable_step_percentage": shared.OptionInfo(0.4, "DeepCache is enabled after this percentage of steps", gr.Slider, {"minimum": 0, "maximum": 1}).info("percentage of initial steps for which DeepCache is disabled"),
        "deepcache_full_run_step_rate": shared.OptionInfo(5, "Refresh caches when the step is divisible by this number", gr.Slider, {"minimum": 0, "maximum": 1000, "step": 1}).info("5 = refresh caches every 5 steps"),
        "deepcache_hr_reuse": shared.OptionInfo(False, "Reuse for HR").info("reuses cache information for HR generation"),
        "deepcache_cache_enable_step_percentage_hr": shared.OptionInfo(0.0, "DeepCache is enabled after this percentage of steps for HR", gr.Slider, {"minimum": 0, "maximum": 1}).info("percentage of initial HR steps for which DeepCache is disabled"),
    }

    for name, opt in options.items():
        opt.section = ('deepcache', "DeepCache")
        shared.opts.add_option(name, opt)


script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_ui(add_axis_options)
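As a quick sanity check on the parameter arithmetic in get_deepcache_params above (the step counts are illustrative):

    steps = 20
    percentage = 0.4                          # deepcache_cache_enable_step_percentage
    cache_enable_step = int(percentage * steps)
    assert cache_enable_step == 8             # caching starts being used at step 8

    hr_second_pass_steps = 0                  # unset/zero falls back to p.steps
    hr_steps = hr_second_pass_steps or steps
    assert hr_steps == 20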
64 extensions-builtin/deepcache/scripts/deepcache_xyz.py (Normal file)

@@ -0,0 +1,64 @@
from modules import scripts
from modules.shared import opts

xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module


def int_applier(value_name: str, min_range: int = -1, max_range: int = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = int(value)
        # validate that the value is within the allowed range (-1 disables the bound)
        if min_range != -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if max_range != -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_int(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = int(x)
    return apply_int


def bool_applier(value_name: str):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
    def apply_bool(p, x, xs):
        validate(value_name, x)
        value_boolean = x.lower() == "true"
        opts.data[value_name] = value_boolean
    return apply_bool


def float_applier(value_name: str, min_range: float = -1, max_range: float = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = float(value)
        # validate that the value is within the allowed range (-1 disables the bound)
        if min_range != -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if max_range != -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_float(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = float(x)
    return apply_float


def add_axis_options():
    extra_axis_options = [
        xyz_grid.AxisOption("[DeepCache] Enabled", str, bool_applier("deepcache_enable"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] Cache Resnet level", int, int_applier("deepcache_cache_resnet_level", 0, 10)),
        xyz_grid.AxisOption("[DeepCache] Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage", 0, 1)),
        xyz_grid.AxisOption("[DeepCache] Cache Refresh Rate", int, int_applier("deepcache_full_run_step_rate", 0, 1000)),
        xyz_grid.AxisOption("[DeepCache] HR Reuse", str, bool_applier("deepcache_hr_reuse"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] HR Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage_hr", 0, 1)),
    ]
    # skip registration if any of these labels are already present
    set_a = {opt.label for opt in xyz_grid.axis_options}
    set_b = {opt.label for opt in extra_axis_options}
    if set_a.intersection(set_b):
        return

    xyz_grid.axis_options.extend(extra_axis_options)
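Since the appliers are plain closures over an option name, they can be exercised directly; a small sketch (the string arguments stand in for the axis values xyz_grid passes, and the `p`/`xs` arguments are unused by these appliers):

    apply_rate = int_applier("deepcache_full_run_step_rate", 0, 1000)
    apply_rate(None, "5", None)       # sets opts.data["deepcache_full_run_step_rate"] = 5

    apply_enable = bool_applier("deepcache_enable")
    apply_enable(None, "True", None)  # sets opts.data["deepcache_enable"] = True
    # out-of-range or malformed values trip the assert inside validate()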
62 extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py (Normal file)

@@ -0,0 +1,62 @@
"""
Patched forward_timestep_embed function with support for video and SD XL (SGM) block types.
@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
"""
from ldm.modules.attention import SpatialTransformer
try:
    from ldm.modules.attention import SpatialVideoTransformer
except (ImportError, ModuleNotFoundError):
    SpatialVideoTransformer = None
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
try:
    from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
except (ImportError, ModuleNotFoundError):
    VideoResBlock = None

# SD XL modules from the generative-models repo
from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
try:
    from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
except (ImportError, ModuleNotFoundError):
    SpatialVideoTransformerSGM = None
from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
try:
    from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
except (ImportError, ModuleNotFoundError):
    VideoResBlockSGM = None

import torch.nn.functional as F

# isinstance() rejects None inside a class tuple, so drop any block types that failed to import
_video_resblocks = tuple(t for t in (VideoResBlock, VideoResBlockSGM) if t is not None)
_spatial_video_transformers = tuple(t for t in (SpatialVideoTransformer, SpatialVideoTransformerSGM) if t is not None)


def forward_timestep_embed(ts: TimestepEmbedSequential, x, emb, context=None, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    for layer in ts:
        if _video_resblocks and isinstance(layer, _video_resblocks):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, (TimestepBlock, TimestepBlockSGM)):
            x = layer(x, emb)
        elif _spatial_video_transformers and isinstance(layer, _spatial_video_transformers):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
        elif isinstance(layer, (SpatialTransformer, SpatialTransformerSGM)):
            x = layer(x, context)
        elif isinstance(layer, (Upsample, UpsampleSGM)):
            x = forward_upsample(layer, x, output_shape=output_shape)
        else:
            x = layer(x)
    return x


def forward_upsample(self: Upsample, x, output_shape=None):
    assert x.shape[1] == self.channels
    if self.dims == 3:
        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        shape = [x.shape[2] * 2, x.shape[3] * 2]
        if output_shape is not None:
            shape[0] = output_shape[2]
            shape[1] = output_shape[3]

    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
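To make the shape handling in forward_upsample concrete, a worked example with illustrative shapes: for a 4-D input the spatial dimensions are doubled unless output_shape pins them, which is what keeps skip-connection tensors aligned when DeepCache skips blocks:

    x_shape = (1, 320, 32, 32)                # (N, C, H, W)
    shape = [x_shape[2] * 2, x_shape[3] * 2]  # default nearest-neighbor target: [64, 64]

    output_shape = (1, 320, 63, 63)           # e.g. an odd-sized skip connection
    shape[0] = output_shape[2]                # H -> 63
    shape[1] = output_shape[3]                # W -> 63
    # F.interpolate(x, size=shape, mode="nearest") then matches the skip tensor exactly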