2022-10-14 13:00:38 +08:00
import collections
2022-09-17 17:05:04 +08:00
import os . path
import sys
2022-11-01 15:01:49 +08:00
import gc
2022-09-17 17:05:04 +08:00
from collections import namedtuple
import torch
2022-10-28 11:49:39 +08:00
import re
2022-11-27 19:46:40 +08:00
import safetensors . torch
2022-09-17 17:05:04 +08:00
from omegaconf import OmegaConf
2022-12-09 08:14:35 +08:00
from os import mkdir
from urllib import request
import ldm . modules . midas as midas
2022-09-17 17:05:04 +08:00
from ldm . util import instantiate_from_config
2022-10-30 22:54:31 +08:00
from modules import shared , modelloader , devices , script_callbacks , sd_vae
2022-09-28 00:01:13 +08:00
from modules . paths import models_path
2022-10-20 04:47:45 +08:00
from modules . sd_hijack_inpainting import do_inpainting_hijack , should_hijack_inpainting
2022-09-28 00:01:13 +08:00
# Sub-directory of models_path that holds Stable Diffusion checkpoints.
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

# One record per discovered checkpoint file:
#   filename   - path of the checkpoint on disk
#   title      - display title, "<name> [<short hash>]"
#   hash       - short hash produced by model_hash()
#   model_name - name with path separators replaced by "_" and the extension removed
CheckpointInfo = namedtuple("CheckpointInfo", "filename title hash model_name")

# title -> CheckpointInfo; cleared and refilled by list_models().
checkpoints_list = dict()

# CheckpointInfo -> cached state dict, kept in insertion order so that
# load_model_weights() can evict the oldest entries first.
checkpoints_loaded = collections.OrderedDict()

# Any failure here is deliberately ignored; startup continues with the
# library's default verbosity.
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel
    logging.set_verbosity_error()
except Exception:
    pass
2022-10-03 02:09:10 +08:00
def setup_model():
    """Create the checkpoint directory if needed, index checkpoints and enable MiDaS auto-download."""
    # exist_ok=True avoids the race between a separate os.path.exists() test and
    # the creation call (another process may create the directory in between).
    os.makedirs(model_path, exist_ok=True)

    list_models()
    enable_midas_autodownload()
2022-09-30 08:59:36 +08:00
2022-10-28 11:49:39 +08:00
def checkpoint_tiles():
    """Return the titles of all known checkpoints in natural ("human") sort order.

    Each title is split on runs of ASCII digits. Digit runs compare as integers
    and the text between them compares case-insensitively, so "model2" sorts
    before "model10".
    """
    def natural_key(title):
        # re.split with a capturing group yields [text, digits, text, digits, ..., text],
        # so odd positions are exactly the matched ASCII digit runs. Choosing by position
        # instead of str.isdigit() (which also accepts characters such as '²' that int()
        # rejects) keeps every key position a single type and never raises.
        parts = re.split('([0-9]+)', title)
        return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]

    return sorted((info.title for info in checkpoints_list.values()), key=natural_key)
2022-09-29 05:59:44 +08:00
2023-01-04 17:47:42 +08:00
def find_checkpoint_config(info):
    """Return the config YAML path to use for the checkpoint described by *info*.

    A ``.yaml`` file next to the checkpoint (same path, extension swapped) wins;
    otherwise the ``--config`` command-line value is returned.
    """
    stem, _ext = os.path.splitext(info.filename)
    candidate = stem + ".yaml"
    return candidate if os.path.exists(candidate) else shared.cmd_opts.config
2022-09-17 17:05:04 +08:00
def list_models():
    """Rebuild ``checkpoints_list`` from the checkpoint files found on disk.

    ``modelloader.load_models`` scans ``model_path`` and ``--ckpt-dir`` for
    ``.ckpt`` and ``.safetensors`` files. When the ``--ckpt`` file exists, it is
    registered first and its title is stored as the ``sd_model_checkpoint``
    option.
    """
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])

    def modeltitle(path, shorthash):
        """Return ``(title, short_name)`` for the checkpoint at *path*.

        The name is the path relative to ``--ckpt-dir`` or ``model_path`` when the
        absolute path starts with one of them, otherwise just the file name.
        """
        abspath = os.path.abspath(path)
        ckpt_dir = shared.cmd_opts.ckpt_dir

        # Slice off only the leading prefix: str.replace() would also delete any
        # later occurrence of the same directory string inside the path.
        if ckpt_dir is not None and abspath.startswith(ckpt_dir):
            name = abspath[len(ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(path)

        if name.startswith(("\\", "/")):
            name = name[1:]

        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

        return f'{name} [{shorthash}]', shortname

    cmd_ckpt = shared.cmd_opts.ckpt
    # os.path.exists(None) raises TypeError, so the None check must come first.
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
2022-09-30 16:42:40 +08:00
2022-09-17 17:05:04 +08:00
2022-09-29 05:30:09 +08:00
def get_closet_checkpoint_match(searchString):
    """Return the checkpoint whose title contains *searchString*, preferring the shortest title.

    Among equally short titles the first in ``checkpoints_list`` order wins.
    Returns None when no title contains the search string.
    """
    matches = (info for info in checkpoints_list.values() if searchString in info.title)
    return min(matches, key=lambda info: len(info.title), default=None)
2022-09-17 17:05:04 +08:00
2022-09-30 16:42:40 +08:00
2022-09-17 17:05:04 +08:00
def model_hash(filename):
    """Return an 8-hex-digit fingerprint of the file at *filename*.

    Only up to 64 KiB starting at offset 1 MiB are hashed (SHA-256), so files
    shorter than 1 MiB all hash the empty byte string. A missing file yields
    the string 'NOFILE'.
    """
    import hashlib

    try:
        with open(filename, "rb") as file:
            file.seek(0x100000)
            chunk = file.read(0x10000)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(chunk).hexdigest()[:8]
def select_checkpoint():
    """Return the CheckpointInfo for the configured ``sd_model_checkpoint`` option.

    If the configured title is unknown (or the option is not a string), the
    first entry of ``checkpoints_list`` is used instead and a warning goes to
    stderr. If no checkpoints are known at all, the searched locations are
    printed to stderr and the process exits with status 1.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    # The stored option is not guaranteed to be a string. A list value used to
    # crash startup here with "TypeError: unhashable type: 'list'" (reported when
    # the saved checkpoint had been deleted). Any non-string value is treated as
    # "not found" so the fallback below applies.
    checkpoint_info = checkpoints_list.get(model_checkpoint) if isinstance(model_checkpoint, str) else None
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
        if shared.cmd_opts.ckpt is not None:
            print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
        print(f" - directory {model_path}", file=sys.stderr)
        if shared.cmd_opts.ckpt_dir is not None:
            print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
        print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
        # sys.exit rather than the site-module exit() helper, which is not
        # guaranteed to exist (e.g. python -S or frozen builds).
        sys.exit(1)

    checkpoint_info = next(iter(checkpoints_list.values()))
    # An empty/unset option is the normal "nothing chosen yet" case; only warn
    # when a specific checkpoint was requested but is missing.
    if model_checkpoint:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
2022-10-19 13:42:22 +08:00
# Old state-dict key prefix -> new prefix. Older checkpoints store the CLIP text
# encoder weights directly under "cond_stage_model.transformer."; these are
# remapped under "cond_stage_model.transformer.text_model.".
chckpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    """Return state-dict key *k* with any prefix in ``chckpoint_dict_replacements`` rewritten.

    Every entry is tried in order against the (possibly already rewritten) key.
    Keys matching no prefix are returned unchanged.
    """
    key = k
    for old_prefix, new_prefix in chckpoint_dict_replacements.items():
        if not key.startswith(old_prefix):
            continue
        key = new_prefix + key[len(old_prefix):]
    return key
2022-10-09 15:23:31 +08:00
def get_state_dict_from_checkpoint(pl_sd):
    """Normalize a loaded checkpoint mapping into a flat model state dict.

    If *pl_sd* wraps its weights under a "state_dict" key, that key is popped
    and the inner mapping is used. Any further "state_dict" entry inside the
    inner mapping is dropped. Every remaining key is passed through
    transform_checkpoint_dict_key. The chosen mapping is rewritten in place and
    returned (the same object, not a copy).
    """
    inner = pl_sd.pop("state_dict", pl_sd)
    inner.pop("state_dict", None)

    renamed = {}
    for key, value in inner.items():
        target = transform_checkpoint_dict_key(key)
        if target is None:
            continue
        renamed[target] = value

    # Refill the existing container rather than returning `renamed`.
    inner.clear()
    inner.update(renamed)
    return inner
2022-10-09 15:23:31 +08:00
2022-11-27 20:51:29 +08:00
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load *checkpoint_file* and return its weights as a flat state dict.

    A ``.safetensors`` extension (case-insensitive) is read with safetensors;
    anything else goes through ``torch.load``. The target location is
    *map_location*, else ``shared.weight_load_location``. For safetensors only,
    when that resolves to None, the CUDA device is used if available and
    "cpu" otherwise. With *print_global_state*, a top-level "global_step" entry
    is printed.
    """
    extension = os.path.splitext(checkpoint_file)[1].lower()

    if extension != ".safetensors":
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code — only load checkpoints from trusted sources.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
    else:
        device = map_location or shared.weight_load_location
        if device is None:
            device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return get_state_dict_from_checkpoint(pl_sd)
2022-10-31 16:19:34 +08:00
def load_model_weights(model, checkpoint_info, vae_file="auto"):
    """Load the weights of *checkpoint_info* into *model* and finish model setup.

    Weights come from the in-memory checkpoint cache when
    ``opts.sd_checkpoint_cache`` is positive and the checkpoint is cached.
    Otherwise they are read from disk (non-strictly) and cached if caching is
    enabled. The model is then converted to channels-last and/or half precision
    per the command-line flags, ``devices.dtype``/``devices.dtype_vae`` are set,
    the cache is trimmed, checkpoint metadata is attached to the model, and the
    VAE is resolved and loaded.
    """
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash

    cache_enabled = shared.opts.sd_checkpoint_cache > 0

    if cache_enabled and checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")
        # Mark the entry as most recently used; without this the eviction below
        # is plain FIFO and can drop a checkpoint that was just used.
        checkpoints_loaded.move_to_end(checkpoint_info)
        model.load_state_dict(checkpoints_loaded[checkpoint_info])
    else:
        # load from file
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

        sd = read_state_dict(checkpoint_file)
        model.load_state_dict(sd, strict=False)
        del sd

        if cache_enabled:
            # cache newly loaded model
            # NOTE(review): state_dict() tensors share storage with the model's
            # parameters. The cached entry only becomes independent if a later
            # half()/.to(device) re-allocates them; e.g. with --no-half on CPU, a
            # later load may overwrite this entry in place. Confirm and consider
            # cloning the tensors.
            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None

        try:
            model.half()
        finally:
            # Re-attach the VAE even if half() fails, so the model is never left without it.
            model.first_stage_model = vae

    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16

    model.first_stage_model.to(devices.dtype_vae)

    # clean up cache if limit is reached
    if cache_enabled:
        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1:  # we need to count the current model
            checkpoints_loaded.popitem(last=False)  # evict least recently used

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info

    model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
    sd_vae.load_vae(model, vae_file)
2022-09-17 17:05:04 +08:00
2022-12-09 08:14:35 +08:00
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    Safe to call more than once: the unwrapped loader is captured only on the
    first call, so the wrapper never ends up wrapping (and calling) itself.
    """

    midas_path = os.path.join(models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # On a repeated call midas.api.load_model is already our wrapper. Capturing it
    # again would make the wrapper call itself recursively, so keep the original.
    if not hasattr(midas.api, "load_model_inner"):
        midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            # Download under a temporary name and move into place only once complete,
            # so an interrupted download never leaves a truncated file at `path`
            # that the exists() check above would later accept.
            tmp_path = path + ".tmp"
            request.urlretrieve(midas_urls[model_type], tmp_path)
            os.replace(tmp_path, path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
2023-01-04 17:35:07 +08:00
2022-10-21 07:01:27 +08:00
def _prepare_sd_config(checkpoint_info, checkpoint_config):
    """Load *checkpoint_config* with OmegaConf and apply the webui's adjustments for *checkpoint_info*."""
    sd_config = OmegaConf.load(checkpoint_config)

    if should_hijack_inpainting(checkpoint_info):
        # Hardcoded config for now...
        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
        sd_config.model.params.conditioning_key = "hybrid"
        sd_config.model.params.unet_config.params.in_channels = 9
        sd_config.model.params.finetune_keys = None

    # Configs that do not mention use_ema get it explicitly disabled.
    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    do_inpainting_hijack()

    if shared.cmd_opts.no_half:
        sd_config.model.params.unet_config.params.use_fp16 = False

    return sd_config


def load_model(checkpoint_info=None):
    """Build the model for *checkpoint_info* and install it as ``shared.sd_model``.

    When *checkpoint_info* is not given, select_checkpoint() picks one. Any
    currently loaded model is un-hijacked and released first. After the weights
    are loaded, the model is moved to ``shared.device`` (or set up for
    --lowvram/--medvram), hijacked and put in eval mode. Textual-inversion
    embeddings are then force-reloaded and the model-loaded callbacks run.

    Returns the new model.
    """
    from modules import lowvram, sd_hijack

    if not checkpoint_info:
        checkpoint_info = select_checkpoint()
    checkpoint_config = find_checkpoint_config(checkpoint_info)

    if checkpoint_config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_config}")

    # Drop the previous model before allocating a new one.
    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        gc.collect()
        devices.torch_gc()

    sd_config = _prepare_sd_config(checkpoint_info, checkpoint_config)
    sd_model = instantiate_from_config(sd_config.model)
    load_model_weights(sd_model, checkpoint_info)

    low_memory = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
    if low_memory:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    sd_hijack.model_hijack.hijack(sd_model)
    sd_model.eval()
    shared.sd_model = sd_model

    # Reload embeddings after model load as they may or may not fit the model
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    script_callbacks.model_loaded_callback(sd_model)

    print("Model loaded.")
    return sd_model
2022-11-02 13:51:46 +08:00
def reload_model_weights(sd_model=None, info=None):
    """Switch *sd_model* (default: ``shared.sd_model``) to checkpoint *info* and return the model.

    *info* defaults to select_checkpoint(). The cases are:

    - No model is loaded: one is built with load_model().
    - The checkpoint file is already loaded: the model is returned unchanged.
    - The new checkpoint needs a different config or a different inpainting
      setting: the cache is cleared and the model is rebuilt with load_model();
      the new ``shared.sd_model`` is returned.
    - Otherwise: weights are swapped in place. If that fails, the previous
      checkpoint's weights are restored and the original exception is re-raised.
    """
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = shared.sd_model

    # Nothing to swap weights into (e.g. the initial load failed): build a model
    # instead of failing with AttributeError on sd_model.sd_checkpoint_info below.
    if sd_model is None:
        load_model(checkpoint_info)
        return shared.sd_model

    current_checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_config = find_checkpoint_config(current_checkpoint_info)

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        # Already loaded; return the model like every other path does.
        return sd_model

    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(current_checkpoint_info):
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info)
        return shared.sd_model

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    try:
        load_model_weights(sd_model, checkpoint_info)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info)
        raise
    finally:
        # Runs on success and on failure, so the model is always re-hijacked and back on its device.
        sd_hijack.model_hijack.hijack(sd_model)
        script_callbacks.model_loaded_callback(sd_model)

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)

    print("Weights loaded.")
    return sd_model