aria1th f166868df6 DeepCache Implementation Mark 2
maybe... some invalid interrupt method?

move to paper implementation

fix descriptions, KeyError

handle sgm for XL

fix ruff, change default for out_block

Implement Deepcache Optimization
2023-12-08 01:50:04 +09:00

63 lines
2.7 KiB
Python

"""
Patched forward_timestep_embed function to support the following:
@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
"""
from ldm.modules.attention import SpatialTransformer
try:
from ldm.modules.attention import SpatialVideoTransformer
except (ImportError, ModuleNotFoundError):
SpatialVideoTransformer = None
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
try:
from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
except (ImportError, ModuleNotFoundError):
VideoResBlock = None
# SD XL modules from generative-models repo
from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
try:
from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
except (ImportError, ModuleNotFoundError):
SpatialVideoTransformerSGM = None
from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
try:
from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
except (ImportError, ModuleNotFoundError):
VideoResBlockSGM = None
import torch.nn.functional as F
def forward_timestep_embed(
    ts: TimestepEmbedSequential,
    x,
    emb,
    context=None,
    output_shape=None,
    time_context=None,
    num_video_frames=None,
    image_only_indicator=None,
):
    """Run ``x`` through every layer of ``ts``, giving each layer the arguments it accepts.

    Layers are dispatched by type, checked in this order (ldm and sgm
    variants are treated alike):

    - VideoResBlock: ``layer(x, emb, num_video_frames, image_only_indicator)``
    - TimestepBlock: ``layer(x, emb)``
    - SpatialVideoTransformer:
      ``layer(x, context, time_context, num_video_frames, image_only_indicator)``
    - SpatialTransformer: ``layer(x, context)``
    - Upsample: ``forward_upsample(layer, x, output_shape=output_shape)``
    - anything else: ``layer(x)``

    The video classes are checked before their general counterparts. They
    presumably subclass them, so the order matters.

    Returns:
        The output of the last layer, or ``x`` unchanged if ``ts`` is empty.
    """
    # The optional video classes are None when their import failed. Filter
    # them out so that:
    # - isinstance never receives None (a tuple containing None raises
    #   TypeError for any layer not matched by an earlier entry);
    # - an sgm-only class is still recognised when the ldm one is missing.
    # isinstance(obj, ()) is simply False, so an all-missing pair is harmless.
    video_res_blocks = tuple(
        cls for cls in (VideoResBlock, VideoResBlockSGM) if cls is not None
    )
    video_transformers = tuple(
        cls
        for cls in (SpatialVideoTransformer, SpatialVideoTransformerSGM)
        if cls is not None
    )
    for layer in ts:
        if isinstance(layer, video_res_blocks):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, (TimestepBlock, TimestepBlockSGM)):
            x = layer(x, emb)
        elif isinstance(layer, video_transformers):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
        elif isinstance(layer, (SpatialTransformer, SpatialTransformerSGM)):
            x = layer(x, context)
        elif isinstance(layer, (Upsample, UpsampleSGM)):
            x = forward_upsample(layer, x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
def forward_upsample(self: "Upsample", x, output_shape=None):
    """Nearest-neighbour upsample ``x`` for an ldm/sgm ``Upsample`` layer.

    Mirrors the layer's own upsampling, but lets the caller pin the output
    spatial size through ``output_shape``.

    Args:
        self: The Upsample module. The function reads ``channels``, ``dims``
            and ``use_conv``, plus ``conv`` when ``use_conv`` is truthy and the
            optional ``third_up`` flag when ``dims == 3``.
        x: Input tensor whose dim 1 is the channel axis (must equal
            ``self.channels``). Every dim after that is resized by
            ``F.interpolate``.
        output_shape: Optional full target shape with the same indexing as
            ``x``. When given, its spatial entries replace the doubled sizes.
            For ``dims == 3``, only the two innermost axes are taken from it.

    Returns:
        The resized tensor, passed through ``self.conv`` when ``self.use_conv``
        is truthy.

    Raises:
        AssertionError: If ``x.shape[1] != self.channels``.
    """
    assert x.shape[1] == self.channels
    if self.dims == 3:
        # sgm's Upsample can also double the first spatial axis (third_up).
        # Other Upsample classes presumably lack the attribute, hence getattr.
        first_axis_factor = 2 if getattr(self, "third_up", False) else 1
        shape = [x.shape[2] * first_axis_factor, x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        # Double every spatial axis. This covers 1D inputs as well as 2D ones,
        # instead of assuming exactly two spatial axes.
        shape = [size * 2 for size in x.shape[2:]]
        if output_shape is not None:
            shape = [output_shape[2 + axis] for axis in range(len(shape))]
    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x