diff --git a/modules/models/sd3/other_impls.py b/modules/models/sd3/other_impls.py
index 2c76e1cb6..cd10edc8d 100644
--- a/modules/models/sd3/other_impls.py
+++ b/modules/models/sd3/other_impls.py
@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)
 
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
 
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
             if intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)
diff --git a/modules/models/sd3/sd3_impls.py b/modules/models/sd3/sd3_impls.py
index 91dad66d0..e2f6cad5b 100644
--- a/modules/models/sd3/sd3_impls.py
+++ b/modules/models/sd3/sd3_impls.py
@@ -1,6 +1,8 @@
 ### Impls of the SD3 core diffusion model and VAE
 
-import torch, math, einops
+import torch
+import math
+import einops
 from modules.models.sd3.mmdit import MMDiT
 from PIL import Image
 
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
         k = self.k(hidden)
         v = self.v(hidden)
         b, c, h, w = q.shape
-        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+        q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
         hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
         hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
         hidden = self.proj_out(hidden)
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
             attn = torch.nn.ModuleList()
             block_in = ch*in_ch_mult[i_level]
             block_out = ch*ch_mult[i_level]
-            for i_block in range(num_res_blocks):
+            for _ in range(num_res_blocks):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             down = torch.nn.Module()
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
         for i_level in reversed(range(self.num_resolutions)):
             block = torch.nn.ModuleList()
             block_out = ch * ch_mult[i_level]
-            for i_block in range(self.num_res_blocks + 1):
+            for _ in range(self.num_res_blocks + 1):
                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                 block_in = block_out
             up = torch.nn.Module()