mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-05-06 20:09:06 +08:00
maybe... some invalid interrupt method? move to paper implementation fix descriptions, KeyError handle sgm for XL fix ruff, change default for out_block Implement Deepcache Optimization
63 lines
2.7 KiB
Python
63 lines
2.7 KiB
Python
"""
|
|
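Patched forward_timestep_embed function that passes each layer the extra arguments it needs (timestep embedding, cross-attention context, video-frame inputs, and an explicit upsample output_shape), for both the ldm and sgm (SD XL) UNet modules.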
|
|
@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
|
|
"""
|
|
from ldm.modules.attention import SpatialTransformer
|
|
try:
|
|
from ldm.modules.attention import SpatialVideoTransformer
|
|
except (ImportError, ModuleNotFoundError):
|
|
SpatialVideoTransformer = None
|
|
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
|
|
try:
|
|
from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
|
|
except (ImportError, ModuleNotFoundError):
|
|
VideoResBlock = None
|
|
|
|
# SD XL modules from generative-models repo
|
|
from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
|
|
try:
|
|
from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
|
|
except (ImportError, ModuleNotFoundError):
|
|
SpatialVideoTransformerSGM = None
|
|
from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
|
|
try:
|
|
from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
|
|
except (ImportError, ModuleNotFoundError):
|
|
VideoResBlockSGM = None
|
|
|
|
import torch.nn.functional as F
|
|
|
|
def _available_types(*candidates):
    """Return the candidate classes that were actually imported, as a tuple.

    The optional imports at the top of this module fall back to ``None`` when a
    class does not exist in the installed ldm/sgm version. ``isinstance`` raises
    ``TypeError`` if its class tuple contains ``None``, so those placeholders are
    dropped here. An empty result is safe: ``isinstance(obj, ())`` is False.
    """
    return tuple(cls for cls in candidates if cls is not None)


# Class tuples used by forward_timestep_embed, resolved once at import time.
# Each pairs the ldm class with its sgm (SD XL) counterpart, minus any that
# failed to import, so the sgm variant is still matched when the ldm one is
# missing (and vice versa).
_VIDEO_RES_BLOCK_TYPES = _available_types(VideoResBlock, VideoResBlockSGM)
_TIMESTEP_BLOCK_TYPES = _available_types(TimestepBlock, TimestepBlockSGM)
_VIDEO_TRANSFORMER_TYPES = _available_types(SpatialVideoTransformer, SpatialVideoTransformerSGM)
_SPATIAL_TRANSFORMER_TYPES = _available_types(SpatialTransformer, SpatialTransformerSGM)
_UPSAMPLE_TYPES = _available_types(Upsample, UpsampleSGM)


def forward_timestep_embed(ts: "TimestepEmbedSequential", x, emb, context=None, output_shape=None,
                           time_context=None, num_video_frames=None, image_only_indicator=None):
    """Run ``x`` through every layer of ``ts``, passing each layer the arguments it takes.

    Replacement for ``TimestepEmbedSequential.forward`` that recognizes both the
    ldm and the sgm (SD XL) module classes.

    Args:
        ts: container whose layers are applied in iteration order.
        x: activation threaded through the layers.
        emb: timestep embedding, given to timestep blocks and video residual blocks.
        context: conditioning given to spatial and video transformers.
        output_shape: target shape forwarded to ``forward_upsample`` for upsample layers.
        time_context: extra input forwarded to video transformers.
        num_video_frames: forwarded to video residual blocks and video transformers.
        image_only_indicator: forwarded to video residual blocks and video transformers.

    Returns:
        The result of the last layer (``x`` itself if ``ts`` is empty).
    """
    for layer in ts:
        # Check order is kept from the original: the video variants are tested
        # before the generic TimestepBlock / SpatialTransformer checks.
        # NOTE(review): presumably because the video classes subclass them.
        if isinstance(layer, _VIDEO_RES_BLOCK_TYPES):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, _TIMESTEP_BLOCK_TYPES):
            x = layer(x, emb)
        elif isinstance(layer, _VIDEO_TRANSFORMER_TYPES):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
        elif isinstance(layer, _SPATIAL_TRANSFORMER_TYPES):
            x = layer(x, context)
        elif isinstance(layer, _UPSAMPLE_TYPES):
            x = forward_upsample(layer, x, output_shape=output_shape)
        else:
            x = layer(x)
    return x
|
|
|
|
def forward_upsample(self: "Upsample", x, output_shape=None):
    """Nearest-neighbour 2x upsampling with an optional explicit output size.

    Replacement for ``Upsample.forward``. ``self`` is the upsample module; this
    reads its ``channels``, ``dims`` and ``use_conv`` attributes, and ``conv``
    when ``use_conv`` is true.

    Each of the ``self.dims`` spatial axes of ``x`` (the axes after N and C) is
    doubled. The exception is ``dims == 3``: there the first spatial axis is
    kept at its input size and only the last two are doubled. Any ``dims``
    (1, 2 or 3) is handled.

    Args:
        x: input tensor laid out as (N, C, *spatial); ``x.shape[1]`` must equal
            ``self.channels``.
        output_shape: optional full target shape, indexed like ``x.shape``.
            When given, its entries replace the doubled size of every upsampled
            axis. The kept first axis for ``dims == 3`` is never overridden.
            NOTE(review): presumably used to match a skip connection whose size
            is not exactly twice the input — confirm with callers.

    Returns:
        The interpolated tensor, passed through ``self.conv`` if ``self.use_conv``.

    Raises:
        AssertionError: if the channel count of ``x`` does not match ``self.channels``.
    """
    assert x.shape[1] == self.channels
    shape = list(x.shape[2:2 + self.dims])
    # For 3-D data the first spatial axis is left at its input size.
    first_scaled = 1 if self.dims == 3 else 0
    for axis in range(first_scaled, len(shape)):
        if output_shape is not None:
            # output_shape includes N and C, hence the +2 offset.
            shape[axis] = output_shape[axis + 2]
        else:
            shape[axis] *= 2
    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
|