mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00

linter

parent 5b2a60b8e2
commit 79de09c3df
@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)
 
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
 
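The only change in this hunk is stylistic: the map(lambda ...) over (q, k, v) becomes a list comprehension. A minimal sketch, not from the commit and with made-up shapes, showing that the two forms produce identical tensors for the multi-head reshape:

import torch

b, seq, heads, dim_head = 2, 5, 4, 8
q = torch.randn(b, seq, heads * dim_head)

# old style: map over a tuple, forced into a list
via_map = list(map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q,)))[0]
# new style: list comprehension, as in the linted code
via_comprehension = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q,)][0]

assert torch.equal(via_map, via_comprehension)  # same result; only the style changes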
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
             if intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
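Together with the signature change in the previous hunk, this replaces a mutable dict default with the None-plus-fallback idiom, so every SDClipModel instance gets its own special_tokens dict instead of sharing one created at function-definition time. A small standalone sketch of the pitfall being avoided (hypothetical function names, not from the repo):

def risky(tokens={"pad": 49407}):
    tokens["extra"] = 0          # mutates the single dict shared by every call
    return tokens

def safe(tokens=None):
    tokens = tokens if tokens is not None else {"pad": 49407}   # fresh dict each call
    tokens["extra"] = 0
    return tokens

risky()
assert "extra" in risky.__defaults__[0]   # the shared default has been polluted
safe()
assert safe.__defaults__ == (None,)       # nothing shared, nothing to pollute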
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)

@@ -1,6 +1,8 @@
 ### Impls of the SD3 core diffusion model and VAE
 
-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
 
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
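As in the attention() hunk earlier, only the map(lambda ...) is swapped for a comprehension; the einops.rearrange call itself is untouched. A sketch with assumed shapes of what that rearrange does: it flattens the spatial grid into a token axis with a single attention head, and the round trip back is lossless:

import torch
import einops

x = torch.randn(2, 16, 8, 8)                               # b, c, h, w
tokens = einops.rearrange(x, "b c h w -> b 1 (h w) c")     # b, 1 head, h*w tokens, c
assert tokens.shape == (2, 1, 64, 16)

back = einops.rearrange(tokens, "b 1 (h w) c -> b c h w", h=8, w=8)
assert torch.equal(back, x)                                # pure reshape/permute, no data loss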
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
             attn = torch.nn.ModuleList()
             block_in = ch*in_ch_mult[i_level]
             block_out = ch*ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             down = torch.nn.Module()
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = torch.nn.ModuleList()
             block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             up = torch.nn.Module()
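The last two hunks rename the unused loop index i_block to _, the conventional name for a value that is deliberately ignored; linters that warn about unused loop variables stop flagging it and the constructed blocks are unchanged. A trivial sketch of the convention (hypothetical example, not from the repo):

import torch

blocks = torch.nn.ModuleList()
for _ in range(3):          # the index is never read; "_" signals that on purpose
    blocks.append(torch.nn.Identity())
assert len(blocks) == 3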