stable-diffusion-webui/modules/sd_vae_taesd.py

149 lines
5.0 KiB
Python
Raw Normal View History

2023-05-14 12:42:44 +08:00
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn
from modules import devices, paths_internal, shared
2023-05-14 12:42:44 +08:00
sd_vae_taesd_models = {}
2023-05-14 12:42:44 +08:00
def conv(n_in, n_out, **kwargs):
    """Create a 3x3 ``nn.Conv2d`` with padding 1.

    At stride 1 the spatial size is preserved. Extra keyword arguments
    (e.g. ``stride``, ``bias``) are forwarded to ``nn.Conv2d``.
    """
    kernel = 3
    return nn.Conv2d(n_in, n_out, kernel, padding=kernel // 2, **kwargs)
class Clamp(nn.Module):
    """Soft clamp that smoothly squashes inputs into (-3, 3) using a scaled tanh."""

    @staticmethod
    def forward(x):
        limit = 3
        return limit * torch.tanh(x / limit)
class Block(nn.Module):
    """Residual block: three 3x3 convs (ReLU between them) plus a skip path, fused by ReLU.

    The skip path is a bias-free 1x1 conv when the channel counts differ,
    and an identity otherwise.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        layers = [
            conv(n_in, n_out), nn.ReLU(),
            conv(n_out, n_out), nn.ReLU(),
            conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*layers)
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.fuse = nn.ReLU()

    def forward(self, x):
        residual = self.skip(x)
        return self.fuse(self.conv(x) + residual)
2024-06-16 13:04:31 +08:00
def decoder(latent_channels=4):
    """Build the TAESD decoder network (latents -> 3-channel image).

    Structure: a soft clamp and an input conv, then three stages of
    (3 residual blocks, 2x upsample, bias-free conv), then a final block
    and an output conv to 3 channels.
    """
    layers = [Clamp(), conv(latent_channels, 64), nn.ReLU()]
    for _stage in range(3):
        layers += [
            Block(64, 64), Block(64, 64), Block(64, 64),
            nn.Upsample(scale_factor=2),
            conv(64, 64, bias=False),
        ]
    layers += [Block(64, 64), conv(64, 3)]
    return nn.Sequential(*layers)
2024-06-16 13:04:31 +08:00
def encoder(latent_channels=4):
    """Build the TAESD encoder network (3-channel image -> latents).

    Structure: an input conv and one residual block, then three stages of
    (stride-2 bias-free conv, 3 residual blocks), then an output conv to
    ``latent_channels`` channels.
    """
    layers = [conv(3, 64), Block(64, 64)]
    for _stage in range(3):
        layers.append(conv(64, 64, stride=2, bias=False))
        layers.extend(Block(64, 64) for _ in range(3))
    layers.append(conv(64, latent_channels))
    return nn.Sequential(*layers)
class TAESDDecoder(nn.Module):
    """Module wrapping a pretrained TAESD decoder loaded from a state-dict checkpoint."""

    # Not referenced in this file; presumably consumed by callers when mapping
    # latents to/from the decoder's value range — confirm against call sites.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
        """Build the decoder network and load its pretrained weights.

        The module is NOT moved to any device here; callers do that.

        Args:
            decoder_path: path to a state-dict checkpoint for the decoder.
            latent_channels: latent channel count. When None it is inferred
                from the path: 16 if it contains "taesd3" or "taef1", else 4.
        """
        super().__init__()
        if latent_channels is None:
            latent_channels = 16 if any(typ in str(decoder_path) for typ in ("taesd3", "taef1")) else 4
        self.decoder = decoder(latent_channels)
        # Checkpoints for this class are typically downloaded from the network.
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered file cannot execute code. A state_dict needs nothing more.
        state_dict = torch.load(
            decoder_path,
            map_location='cpu' if devices.device.type != 'cuda' else None,
            weights_only=True,
        )
        self.decoder.load_state_dict(state_dict)
2023-08-04 13:38:52 +08:00
class TAESDEncoder(nn.Module):
    """Module wrapping a pretrained TAESD encoder loaded from a state-dict checkpoint."""

    # Not referenced in this file; presumably consumed by callers when mapping
    # latents to/from the encoder's value range — confirm against call sites.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
        """Build the encoder network and load its pretrained weights.

        The module is NOT moved to any device here; callers do that.

        Args:
            encoder_path: path to a state-dict checkpoint for the encoder.
            latent_channels: latent channel count. When None it is inferred
                from the path: 16 if it contains "taesd3" or "taef1", else 4.
        """
        super().__init__()
        if latent_channels is None:
            latent_channels = 16 if any(typ in str(encoder_path) for typ in ("taesd3", "taef1")) else 4
        self.encoder = encoder(latent_channels)
        # Checkpoints for this class are typically downloaded from the network.
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered file cannot execute code. A state_dict needs nothing more.
        state_dict = torch.load(
            encoder_path,
            map_location='cpu' if devices.device.type != 'cuda' else None,
            weights_only=True,
        )
        self.encoder.load_state_dict(state_dict)
2023-05-14 12:42:44 +08:00
def download_model(model_path, model_url):
    """Fetch ``model_url`` into ``model_path`` unless a file already exists there.

    Parent directories are created as needed.
    """
    if os.path.exists(model_path):
        return
    target_dir = os.path.dirname(model_path)
    os.makedirs(target_dir, exist_ok=True)
    print(f'Downloading TAESD model to: {model_path}')
    torch.hub.download_url_to_file(model_url, model_path)
2023-08-04 13:38:52 +08:00
def decoder_model():
    """Return the TAESD decoder network matching the active ``shared.sd_model``.

    The variant is chosen by checking the model flags in order: is_sd3,
    is_flux1, is_sdxl, otherwise SD1.x. On first use the checkpoint is
    downloaded into ``models_path/VAE-taesd``. The wrapper is then put in
    eval mode, moved to ``devices.device``/``devices.dtype``, and cached in
    ``sd_vae_taesd_models``.

    Raises:
        FileNotFoundError: if the checkpoint is still missing after the
            download attempt.
    """
    sd_model = shared.sd_model
    if sd_model.is_sd3:
        model_name = "taesd3_decoder.pth"
    elif sd_model.is_flux1:
        model_name = "taef1_decoder.pth"
    elif sd_model.is_sdxl:
        model_name = "taesdxl_decoder.pth"
    else:
        model_name = "taesd_decoder.pth"

    cached = sd_vae_taesd_models.get(model_name)
    if cached is not None:
        return cached.decoder

    model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
    model_url = 'https://github.com/madebyollin/taesd/raw/main/' + model_name
    download_model(model_path, model_url)
    if not os.path.exists(model_path):
        raise FileNotFoundError('TAESD model not found')

    cached = TAESDDecoder(model_path)
    cached.eval()
    cached.to(devices.device, devices.dtype)
    sd_vae_taesd_models[model_name] = cached
    return cached.decoder
2023-08-04 13:38:52 +08:00
def encoder_model():
    """Return the TAESD encoder network matching the active ``shared.sd_model``.

    The variant is chosen by checking the model flags in order: is_sd3,
    is_flux1, is_sdxl, otherwise SD1.x. On first use the checkpoint is
    downloaded into ``models_path/VAE-taesd``. The wrapper is then put in
    eval mode, moved to ``devices.device``/``devices.dtype``, and cached in
    ``sd_vae_taesd_models``.

    Raises:
        FileNotFoundError: if the checkpoint is still missing after the
            download attempt.
    """
    sd_model = shared.sd_model
    if sd_model.is_sd3:
        model_name = "taesd3_encoder.pth"
    elif sd_model.is_flux1:
        model_name = "taef1_encoder.pth"
    elif sd_model.is_sdxl:
        model_name = "taesdxl_encoder.pth"
    else:
        model_name = "taesd_encoder.pth"

    cached = sd_vae_taesd_models.get(model_name)
    if cached is not None:
        return cached.encoder

    model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
    model_url = 'https://github.com/madebyollin/taesd/raw/main/' + model_name
    download_model(model_path, model_url)
    if not os.path.exists(model_path):
        raise FileNotFoundError('TAESD model not found')

    cached = TAESDEncoder(model_path)
    cached.eval()
    cached.to(devices.device, devices.dtype)
    sd_vae_taesd_models[model_name] = cached
    return cached.encoder