mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-30 10:13:03 +08:00
Add consistency decoder
This commit is contained in:
parent
9c1c0da026
commit
64fd916334
@ -3,7 +3,7 @@ from collections import namedtuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
|
|||||||
return steps, t_enc
|
return steps, t_enc
|
||||||
|
|
||||||
|
|
||||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
|
||||||
|
|
||||||
|
|
||||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||||
@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
|||||||
elif approximation == 3:
|
elif approximation == 3:
|
||||||
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||||
x_sample = x_sample * 2 - 1
|
x_sample = x_sample * 2 - 1
|
||||||
|
elif approximation == 4:
|
||||||
|
with devices.autocast(), torch.no_grad():
|
||||||
|
x_sample = sd_vae_consistency.decoder_model()(
|
||||||
|
sample.to(devices.device, devices.dtype)/0.18215,
|
||||||
|
schedule=[1.0]
|
||||||
|
)
|
||||||
|
sd_vae_consistency.unload()
|
||||||
else:
|
else:
|
||||||
if model is None:
|
if model is None:
|
||||||
model = shared.sd_model
|
model = shared.sd_model
|
||||||
|
35
modules/sd_vae_consistency.py
Normal file
35
modules/sd_vae_consistency.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Consistency Decoder
|
||||||
|
Improved decoding for stable diffusion vaes.
|
||||||
|
|
||||||
|
https://github.com/openai/consistencydecoder
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from modules import devices, paths_internal, shared
|
||||||
|
from consistencydecoder import ConsistencyDecoder
|
||||||
|
|
||||||
|
|
||||||
|
sd_vae_consistency_models = None
|
||||||
|
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
|
||||||
|
|
||||||
|
|
||||||
|
def decoder_model():
|
||||||
|
global sd_vae_consistency_models
|
||||||
|
if getattr(shared.sd_model, 'is_sdxl', False):
|
||||||
|
raise NotImplementedError("SDXL is not supported for consistency decoder")
|
||||||
|
if sd_vae_consistency_models is not None:
|
||||||
|
sd_vae_consistency_models.ckpt.to(devices.device)
|
||||||
|
return sd_vae_consistency_models
|
||||||
|
|
||||||
|
loaded_model = ConsistencyDecoder(devices.device, model_path)
|
||||||
|
sd_vae_consistency_models = loaded_model
|
||||||
|
return loaded_model
|
||||||
|
|
||||||
|
|
||||||
|
def unload():
|
||||||
|
global sd_vae_consistency_models
|
||||||
|
if sd_vae_consistency_models is not None:
|
||||||
|
sd_vae_consistency_models.ckpt.to('cpu')
|
@ -172,7 +172,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
|||||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||||
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('img2img', "img2img"), {
|
options_templates.update(options_section(('img2img', "img2img"), {
|
||||||
|
@ -32,3 +32,5 @@ torch
|
|||||||
torchdiffeq
|
torchdiffeq
|
||||||
torchsde
|
torchsde
|
||||||
transformers==4.30.2
|
transformers==4.30.2
|
||||||
|
|
||||||
|
git+https://github.com/openai/consistencydecoder.git
|
||||||
|
@ -30,3 +30,4 @@ torchdiffeq==0.2.3
|
|||||||
torchsde==0.2.6
|
torchsde==0.2.6
|
||||||
transformers==4.30.2
|
transformers==4.30.2
|
||||||
httpx==0.24.1
|
httpx==0.24.1
|
||||||
|
git+https://github.com/openai/consistencydecoder.git
|
||||||
|
Loading…
Reference in New Issue
Block a user