This commit is contained in:
AUTOMATIC1111 2024-06-24 09:06:10 +03:00
parent 34b4443cc3
commit a65dd315ad

View File

@ -11,6 +11,18 @@ from transformers import CLIPTokenizer, T5TokenizerFast
#################################################################################################
class AutocastLinear(nn.Linear):
"""Same as usual linear layer, but casts its weights to whatever the parameter type is.
This is different from torch.autocast in a way that float16 layer processing float32 input
will return float16 with autocast on, and float32 with this. T5 seems to be fucked
if you do it in full float16 (returning almost all zeros in the final output).
"""
def forward(self, x):
return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)
def attention(q, k, v, heads, mask=None):
"""Convenience wrapper around a basic attention operation"""
b, _, dim_head = q.shape
@ -27,9 +39,9 @@ class Mlp(nn.Module):
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.fc1 = AutocastLinear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)
self.act = act_layer
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
self.fc2 = AutocastLinear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
@ -297,7 +309,6 @@ class T5XXLModel(SDClipModel):
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
def __init__(self):
@ -319,9 +330,9 @@ class T5LayerNorm(torch.nn.Module):
class T5DenseGatedActDense(torch.nn.Module):
def __init__(self, model_dim, ff_dim, dtype, device):
super().__init__()
self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wi_1 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
def forward(self, x):
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
@ -348,10 +359,10 @@ class T5Attention(torch.nn.Module):
def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):
super().__init__()
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)
self.num_heads = num_heads
self.relative_attention_bias = None
if relative_attention_bias:
@ -421,11 +432,16 @@ class T5Attention(torch.nn.Module):
q = self.q(x)
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
mask = past_bias
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask)
else:
mask = None
out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)
return self.o(out), past_bias