mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 02:45:05 +08:00
linter
This commit is contained in:
parent
5b2a60b8e2
commit
79de09c3df
@ -1,6 +1,7 @@
|
||||
### This file contains impls for underlying related models (CLIP, T5, etc)
|
||||
|
||||
import torch, math
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
|
||||
@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask)
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x, mask)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
@ -215,7 +216,7 @@ class SD3Tokenizer:
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
|
||||
tokens = [a[0] for a in token_weight_pairs[0]]
|
||||
out, pooled = self([tokens])
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = ["last", "pooled", "hidden"]
|
||||
def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
|
||||
special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.transformer = model_class(textmodel_json_config, dtype, device)
|
||||
@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
param.requires_grad = False
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.special_tokens = special_tokens
|
||||
self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
|
||||
intermediate = None
|
||||
x = self.embed_tokens(input_ids)
|
||||
past_bias = None
|
||||
for i, l in enumerate(self.block):
|
||||
x, past_bias = l(x, past_bias)
|
||||
for i, layer in enumerate(self.block):
|
||||
x, past_bias = layer(x, past_bias)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
x = self.final_layer_norm(x)
|
||||
|
@ -1,6 +1,8 @@
|
||||
### Impls of the SD3 core diffusion model and VAE
|
||||
|
||||
import torch, math, einops
|
||||
import torch
|
||||
import math
|
||||
import einops
|
||||
from modules.models.sd3.mmdit import MMDiT
|
||||
from PIL import Image
|
||||
|
||||
@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
|
||||
k = self.k(hidden)
|
||||
v = self.v(hidden)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
|
||||
q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
|
||||
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
|
||||
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
hidden = self.proj_out(hidden)
|
||||
@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
|
||||
attn = torch.nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(num_res_blocks):
|
||||
for _ in range(num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
||||
block_in = block_out
|
||||
down = torch.nn.Module()
|
||||
@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = torch.nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
|
||||
block_in = block_out
|
||||
up = torch.nn.Module()
|
||||
|
Loading…
Reference in New Issue
Block a user