commit 79de09c3df
parent 5b2a60b8e2
Author: AUTOMATIC1111
Date:   2024-06-16 08:13:23 +03:00

2 changed files with 16 additions and 13 deletions


@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
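
Both forms of the head split are equivalent; the list comprehension just drops the lambda. A minimal standalone sketch of the equivalence (the tensor sizes below are made up for illustration, not taken from the repo):

    import torch

    b, n, heads, dim_head = 2, 5, 4, 8
    q = k = v = torch.randn(b, n, heads * dim_head)

    # old style: map + lambda, new style: list comprehension -- same tensors out
    old = list(map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)))
    new = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]

    assert all(torch.equal(a, c) for a, c in zip(old, new))
    assert new[0].shape == (b, heads, n, dim_head)  # heads moved in front of the token axis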
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
             if intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
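
The only change here is the loop variable name: single-letter `l` is easy to misread as `1` and is commonly flagged by linters (e.g. flake8's E741); behaviour is unchanged. A minimal sketch of the intermediate-capture pattern the loop implements, using toy identity layers rather than the repo's real CLIP layers:

    import torch
    from torch import nn

    layers = nn.ModuleList([nn.Identity() for _ in range(4)])
    intermediate_output = -2                  # e.g. "give me the penultimate layer's output"
    if intermediate_output < 0:
        intermediate_output = len(layers) + intermediate_output

    x = torch.randn(1, 3)
    intermediate = None
    for i, layer in enumerate(layers):
        x = layer(x)
        if i == intermediate_output:
            intermediate = x.clone()          # snapshot before later layers overwrite x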
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
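
These two hunks together replace a mutable default argument with the usual `None` sentinel. A dict used as a default is created once at function-definition time and shared by every call that relies on the default, so mutations can leak state between calls. A small self-contained illustration of the pitfall and the sentinel pattern (the helper names here are hypothetical, not from the repo):

    # Hypothetical helper showing why a dict default is risky: the same dict object
    # is reused for every call that relies on the default.
    def add_token_buggy(name, registry={}):
        registry[name] = len(registry)
        return registry

    print(add_token_buggy("start"))  # {'start': 0}
    print(add_token_buggy("end"))    # {'start': 0, 'end': 1} -- state leaked from the first call

    # Sentinel pattern used in the diff: default to None, build a fresh dict inside.
    def add_token(name, registry=None):
        registry = registry if registry is not None else {}
        registry[name] = len(registry)
        return registry

    print(add_token("end"))          # {'end': 0} -- independent per call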
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)


@@ -1,6 +1,8 @@
 ### Impls of the SD3 core diffusion model and VAE
-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
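
Same map-to-comprehension change as in the text encoder, here around the VAE's single-head self-attention. A toy-sized sketch of the reshape involved (the channel and spatial sizes are illustrative only):

    import torch
    import einops

    b, c, h, w = 1, 16, 8, 8
    q = k = v = torch.randn(b, c, h, w)

    # Flatten the spatial grid into a token axis with a single attention head.
    q_, k_, v_ = [einops.rearrange(t, "b c h w -> b 1 (h w) c").contiguous() for t in (q, k, v)]
    out = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_)  # default scale is dim ** -0.5
    out = einops.rearrange(out, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
    assert out.shape == (b, c, h, w)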
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
             attn = torch.nn.ModuleList()
             block_in = ch*in_ch_mult[i_level]
             block_out = ch*ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             down = torch.nn.Module()
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = torch.nn.ModuleList()
             block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             up = torch.nn.Module()
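
In both VAE loops the index `i_block` was never read, so it is renamed to `_`, the conventional name for a throwaway loop variable; the iteration count and the appended modules are unchanged. A minimal sketch, with `torch.nn.Identity` standing in for the repo's ResnetBlock:

    import torch

    num_res_blocks = 2
    block = torch.nn.ModuleList()
    for _ in range(num_res_blocks):        # the loop variable is never read
        block.append(torch.nn.Identity())  # stand-in for ResnetBlock(...)
    assert len(block) == num_res_blocks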