AUTOMATIC1111 committed 2024-06-16 08:13:23 +03:00
commit 79de09c3df (parent 5b2a60b8e2)
2 changed files with 16 additions and 13 deletions

Changed file 1 of 2

@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
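The change above only swaps a map/lambda for a list comprehension; the reshape into per-head views and the attention call are unchanged. A minimal standalone sketch of that reshape with dummy shapes (batch 2, 8 heads, 16 tokens, width 512; the values are illustrative, not from the file):

import torch

b, heads, seq, dim = 2, 8, 16, 512
q, k, v = (torch.randn(b, seq, dim) for _ in range(3))
dim_head = dim // heads
# Same reshape as the new comprehension: (b, seq, dim) -> (b, heads, seq, dim_head)
q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.transpose(1, 2).reshape(b, -1, heads * dim_head).shape)  # torch.Size([2, 16, 512])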
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
         if intermediate_output < 0:
             intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
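Renaming l to layer avoids the ambiguous single-letter name that linters such as flake8 flag (E741) without changing behaviour; the loop still captures the output of layer number intermediate_output, with negative values counting from the end. A toy illustration of that index normalisation (list contents made up, not from the file):

layers = ["emb", "block0", "block1", "block2"]
intermediate_output = -1
if intermediate_output < 0:
    intermediate_output = len(layers) + intermediate_output
print(intermediate_output)  # 3, i.e. the last layer, same as layers[-1]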
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
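token_weight_pairs[0] is presumably a list of (token, weight) pairs for the first prompt, so the new comprehension keeps just the token ids, exactly as the old list(map(...)) did. A toy example with made-up values:

token_weight_pairs = [[(49406, 1.0), (1125, 1.2), (49407, 1.0)]]
tokens = [a[0] for a in token_weight_pairs[0]]
print(tokens)  # [49406, 1125, 49407]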
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
            param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
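The two special_tokens hunks together replace a mutable default argument with a None sentinel plus an in-body fallback. Python evaluates default values once, when the function is defined, so a dict or list default is shared by every call that omits the argument; the sentinel pattern gives each call its own fresh value. A small demonstration with hypothetical functions, not taken from the file:

def bad(seen=[]):
    seen.append("token")   # mutates the one shared default list
    return seen

def good(seen=None):
    seen = seen if seen is not None else []   # fresh list per call
    seen.append("token")
    return seen

print(bad())   # ['token']
print(bad())   # ['token', 'token']  state leaked from the first call
print(good())  # ['token']
print(good())  # ['token']  independent on every call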
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)

Changed file 2 of 2

@@ -1,6 +1,8 @@
 ### Impls of the SD3 core diffusion model and VAE
-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
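In AttnBlock the comprehension flattens each (b, c, h, w) feature map into a single-head sequence of h*w spatial positions before scaled_dot_product_attention, then folds the result back into a feature map. A quick shape check with dummy sizes (illustrative only, not the model's real dimensions):

import torch
import einops

b, c, h, w = 1, 512, 8, 8
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))
# (b, c, h, w) -> (b, 1, h*w, c): one attention head over h*w spatial positions
q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
print(hidden.shape)  # torch.Size([1, 512, 8, 8])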
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
            attn = torch.nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
            block = torch.nn.ModuleList()
            block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            up = torch.nn.Module()