stable-diffusion-webui/modules/sd_vae_consistency.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

35 lines
1000 B
Python
Raw Normal View History

2023-11-07 10:52:29 +08:00
"""
Consistency Decoder
Improved decoding for Stable Diffusion VAEs.
https://github.com/openai/consistencydecoder
"""
import os
import torch
import torch.nn as nn
from modules import devices, paths_internal, shared
from consistencydecoder import ConsistencyDecoder
# Module-level cache for the single ConsistencyDecoder instance. It stays None
# until decoder_model() constructs the decoder on first use.
sd_vae_consistency_models = None
# Directory handed to ConsistencyDecoder as its second positional argument:
# the 'consistencydecoder' subfolder of the configured models path.
# NOTE(review): presumably where the decoder weights are downloaded/cached --
# confirm against the consistencydecoder package.
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
def decoder_model():
    """Return the shared ConsistencyDecoder, building it on first use.

    The decoder is cached in the module-level ``sd_vae_consistency_models``.
    When a cached instance already exists, its ``ckpt`` is sent to
    ``devices.device`` before it is returned (``unload()`` sends it to
    ``'cpu'``); otherwise a new decoder is constructed from
    ``devices.device`` and ``model_path`` and stored in the cache.

    Returns:
        The cached (or newly created) ``ConsistencyDecoder`` instance.

    Raises:
        NotImplementedError: if ``shared.sd_model`` has a truthy ``is_sdxl``
            attribute. This check runs on every call, before the cache is
            consulted.
    """
    global sd_vae_consistency_models

    # A missing ``is_sdxl`` attribute is treated as "not SDXL".
    if getattr(shared.sd_model, 'is_sdxl', False):
        raise NotImplementedError("SDXL is not supported for consistency decoder")

    decoder = sd_vae_consistency_models
    if decoder is None:
        # First use: create the decoder and remember it for later calls.
        decoder = ConsistencyDecoder(devices.device, model_path)
        sd_vae_consistency_models = decoder
    else:
        # Reuse the cached decoder, sending its checkpoint to the active device.
        decoder.ckpt.to(devices.device)
    return decoder
def unload():
    """Send the cached decoder's checkpoint to the CPU, if a decoder exists.

    Does nothing when no decoder has been created yet. The cached instance
    itself is kept in ``sd_vae_consistency_models``; only ``ckpt.to('cpu')``
    is called on it.
    """
    decoder = sd_vae_consistency_models
    if decoder is None:
        return
    decoder.ckpt.to('cpu')