2022-10-14 13:00:38 +08:00
import collections
2022-09-17 17:05:04 +08:00
import os . path
import sys
2022-11-01 15:01:49 +08:00
import gc
2023-05-02 14:08:00 +08:00
import threading
2022-09-17 17:05:04 +08:00
import torch
2022-10-28 11:49:39 +08:00
import re
2022-11-27 19:46:40 +08:00
import safetensors . torch
2022-09-17 17:05:04 +08:00
from omegaconf import OmegaConf
2022-12-09 08:14:35 +08:00
from os import mkdir
from urllib import request
import ldm . modules . midas as midas
2022-09-17 17:05:04 +08:00
from ldm . util import instantiate_from_config
2023-08-01 12:08:11 +08:00
from modules import paths , shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes , sd_models_config , sd_unet , sd_models_xl , cache
2023-01-27 16:28:12 +08:00
from modules . timer import Timer
2023-04-04 15:26:44 +08:00
import tomesd
2022-09-28 00:01:13 +08:00
# Name of the checkpoint sub-directory inside the models directory.
model_dir = "Stable-diffusion"
# Absolute path of the default directory that is searched for checkpoints.
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# Maps checkpoint title -> CheckpointInfo (filled by CheckpointInfo.register).
checkpoints_list = {}
# Maps every id/alias of a checkpoint -> CheckpointInfo (filled by CheckpointInfo.register).
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
# Cache of state dicts keyed by CheckpointInfo; the oldest entries are evicted first.
checkpoints_loaded = collections.OrderedDict()
2022-09-17 17:05:04 +08:00
2023-01-14 14:56:59 +08:00
class CheckpointInfo:
    """Names, hashes, metadata and lookup aliases of a single checkpoint file."""

    def __init__(self, filename):
        """Gather everything known about the checkpoint stored at `filename`."""
        self.filename = filename
        abspath = os.path.abspath(filename)

        extension = os.path.splitext(filename)[1]
        self.is_safetensors = extension.lower() == ".safetensors"

        # The name is the absolute path with the containing checkpoint directory removed;
        # files outside the known directories are named after their basename.
        ckpt_dir = shared.cmd_opts.ckpt_dir
        if ckpt_dir is not None and abspath.startswith(ckpt_dir):
            name = abspath.replace(ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        # Drop the path separator left over at the front after removing the directory.
        if name.startswith(("\\", "/")):
            name = name[1:]

        def read_metadata():
            # The thumbnail is taken out of the metadata and kept on the instance instead.
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]

        flattened_name = name.replace("/", "_").replace("\\", "_")
        self.model_name = os.path.splitext(flattened_name)[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        # Titles carry the short hash only once it is known.
        if self.shorthash is None:
            self.title = name
            self.short_title = self.name_for_extra
        else:
            self.title = f'{name} [{self.shorthash}]'
            self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

    def register(self):
        """Publish this checkpoint under its title and under each of its ids."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """Compute the sha256, refresh titles/ids and re-register.

        Returns the 10-character short hash, or None when no sha256 was obtained.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids.extend([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'])

        # The title changes once the hash is known, so the old title entry is removed first.
        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        self.register()

        return self.shorthash
2022-09-17 17:05:04 +08:00
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    # Best effort only: any failure while importing or configuring is deliberately ignored.
    pass
2022-10-03 02:09:10 +08:00
def setup_model():
    """Create the checkpoint directory if needed and install the midas download wrapper."""
    os.makedirs(model_path, exist_ok=True)
    enable_midas_autodownload()
2022-09-30 08:59:36 +08:00
2023-07-30 18:48:27 +08:00
def checkpoint_tiles(use_short=False):
    """Return the titles of all registered checkpoints (short titles when `use_short` is true)."""
    infos = checkpoints_list.values()
    if use_short:
        return [info.short_title for info in infos]
    return [info.title for info in infos]
2022-09-29 05:59:44 +08:00
2022-09-17 17:05:04 +08:00
def list_models():
    """Rebuild `checkpoints_list` and `checkpoint_aliases` from the files on disk.

    The checkpoint given with --ckpt (if it exists) is registered first and made the
    selected checkpoint; every file returned by `modelloader.load_models` is registered
    afterwards. A default model URL is only handed to the loader when downloading is
    allowed, --ckpt still equals `shared.sd_model_file`, and that file does not exist.
    """
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt

    def cmd_ckpt_exists():
        # os.path.exists(None) raises TypeError, so a missing --ckpt counts as "does not exist".
        # Evaluated on demand (not cached) because load_models() below may create files.
        return cmd_ckpt is not None and os.path.exists(cmd_ckpt)

    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or cmd_ckpt_exists():
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    if cmd_ckpt_exists():
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
2022-10-09 04:26:48 +08:00
2023-07-30 18:48:27 +08:00
# Matches a trailing "[...]" group (plus surrounding whitespace) at the end of a string.
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    """Find the registered checkpoint that best matches `search_string`.

    Lookup order: exact alias, then the shortest title containing the string, then the
    shortest title containing the string with its trailing "[...]" part removed.
    Returns None when nothing matches.
    """
    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_title_containing(needle):
        # min() keeps the first of equally short titles, like a stable sort followed by [0].
        candidates = [info for info in checkpoints_list.values() if needle in info.title]
        return min(candidates, key=lambda info: len(info.title), default=None)

    by_substring = shortest_title_containing(search_string)
    if by_substring is not None:
        return by_substring

    search_string_without_checksum = re_strip_checksum.sub('', search_string)
    return shortest_title_containing(search_string_without_checksum)
2022-09-17 17:05:04 +08:00
2022-09-30 16:42:40 +08:00
2022-09-17 17:05:04 +08:00
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions

    Hashes the 0x10000 bytes found at offset 0x100000 and returns the first 8 hex
    digits of the sha256 digest, or 'NOFILE' when the file does not exist.
    """
    import hashlib

    try:
        with open(filename, "rb") as stream:
            stream.seek(0x100000)
            sample = stream.read(0x10000)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(sample).hexdigest()[:8]
def select_checkpoint():
    """Return the CheckpointInfo for the checkpoint selected in the options.

    Falls back to the first registered checkpoint (with a warning on stderr) when the
    selected one is unknown.

    Raises `FileNotFoundError` if no checkpoints are found.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if not checkpoints_list:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        # Start the closing sentence on its own line so it is not glued to the last path.
        error_message += "\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    # No exact match: use the first registered checkpoint as a fallback.
    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
2023-03-23 13:28:08 +08:00
# Key prefixes that are rewritten (old prefix -> new prefix) when a state dict is loaded.
checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    """Return `k` with any prefix listed in `checkpoint_dict_replacements` rewritten."""
    for old_prefix, new_prefix in checkpoint_dict_replacements.items():
        if not k.startswith(old_prefix):
            continue
        k = new_prefix + k[len(old_prefix):]

    return k
2022-10-09 15:23:31 +08:00
def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a nested "state_dict" entry (if any) and rename its keys in place.

    The returned object is the (possibly unwrapped) input dict itself, emptied and
    refilled with the transformed keys.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    renamed = ((transform_checkpoint_dict_key(key), value) for key, value in pl_sd.items())
    converted = {key: value for key, value in renamed if key is not None}

    # Reuse the original dict object instead of returning a new one.
    pl_sd.clear()
    pl_sd.update(converted)

    return pl_sd
2022-10-09 15:23:31 +08:00
2023-03-14 14:10:26 +08:00
def read_metadata_from_safetensors(filename):
    """Read the "__metadata__" section from the JSON header of a safetensors file.

    The file starts with an 8-byte little-endian header length followed by the JSON
    header. String values that look like JSON objects are decoded when possible and
    kept as plain strings otherwise. Fails an assertion if the header does not look
    like safetensors.
    """
    import json

    with open(filename, mode="rb") as file:
        header_size = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        assert header_size > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
        header = json.loads(json_start + file.read(header_size - 2))

    def decode(value):
        # Only strings starting with "{" are treated as embedded JSON.
        if isinstance(value, str) and value[0:1] == '{':
            try:
                return json.loads(value)
            except Exception:
                pass
        return value

    return {key: decode(value) for key, value in header.get("__metadata__", {}).items()}
2022-11-27 20:51:29 +08:00
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint file and return its state dict with transformed keys.

    `.safetensors` files are read with safetensors (memory-mapped unless
    `shared.opts.disable_mmap_load_safetensors` is set); every other extension goes
    through `torch.load`. `map_location` takes precedence over
    `shared.weight_load_location`. When `print_global_state` is true and the checkpoint
    has a "global_step" entry, it is printed.
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # Read the whole file through a context manager so the handle is always closed.
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        # NOTE(review): torch.load unpickles the file; only safe for trusted checkpoints
        # unless unpickling is restricted elsewhere in the project — confirm.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd
2023-01-27 16:28:12 +08:00
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for `checkpoint_info`, served from `checkpoints_loaded` when cached."""
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        weights = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")
        return weights

    # use checkpoint cache
    print(f"Loading weights [{sd_model_hash}] from cache")
    return checkpoints_loaded[checkpoint_info]
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Apply the weights of `checkpoint_info` to `model` and finish its setup.

    `state_dict` may be None, in which case it is fetched with
    `get_checkpoint_state_dict`. Besides loading the weights this records the
    checkpoint in the options, tags the model with its architecture flags, optionally
    caches the state dict, applies channels_last / half precision according to the
    command line options, and loads the VAE resolved for the checkpoint.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    # Architecture flags; cond_stage_model is only inspected for non-SDXL models.
    model.is_sdxl = hasattr(model, 'conditioner')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    caching_enabled = shared.opts.sd_checkpoint_cache > 0
    if caching_enabled:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict

    del state_dict

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
        saved_vae = model.first_stage_model
        saved_depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and saved_depth_model:
            model.depth_model = None

        model.half()

        # Put back whatever was detached before the conversion.
        model.first_stage_model = saved_vae
        if saved_depth_model:
            model.depth_model = saved_depth_model

        timer.record("apply half()")

    if model.is_sdxl and not shared.cmd_opts.no_half:
        devices.dtype_unet = torch.float16
    else:
        devices.dtype_unet = model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    # Drop any previously loaded VAE state, then load the one resolved for this checkpoint.
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
2022-11-02 13:51:46 +08:00
2022-09-17 17:05:04 +08:00
2022-12-09 08:14:35 +08:00
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    Calling this function more than once is safe: the original loader is only
    captured the first time, so the wrapper never ends up calling itself.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # Capture the original loader only once; on a repeated call midas.api.load_model is
    # already a wrapper, and storing it as load_model_inner would cause endless recursion.
    if not hasattr(midas.api, 'load_model_inner'):
        midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # makedirs also creates missing parents and tolerates a concurrently created directory.
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
2023-01-04 17:35:07 +08:00
2023-01-27 16:28:12 +08:00
def repair_config(sd_config):
    """Adjust a loaded model config in place to match the current command line options."""
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if hasattr(params, 'unet_config'):
        if shared.cmd_opts.no_half:
            params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling:
            params.unet_config.params.use_fp16 = True

    # Fall back to plain attention in the VAE when xformers is requested but unavailable.
    ddconfig = params.first_stage_config.params.ddconfig
    if getattr(ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_aug_params = params.noise_aug_config.params
        noise_aug_params.clip_stats_path = noise_aug_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
2023-01-27 16:28:12 +08:00
2023-02-05 16:20:47 +08:00
# State-dict keys whose presence is used by load_model() to decide that a checkpoint
# already contains text-encoder (CLIP) weights.
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
2023-02-05 16:20:47 +08:00
2023-05-02 14:08:00 +08:00
class SdModelData:
    """Holds the active model, the list of loaded models, and a lock guarding the first load."""

    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        """Return the active model, triggering `load_model()` on first use.

        If loading fails the error is displayed and None is returned.
        """
        if self.was_loaded_at_least_once or self.sd_model is not None:
            return self.sd_model

        with self.lock:
            # Re-check under the lock: another thread may have finished loading meanwhile.
            if self.sd_model is not None or self.was_loaded_at_least_once:
                return self.sd_model

            try:
                load_model()
            except Exception as e:
                errors.display(e, "loading stable diffusion model", full_traceback=True)
                print("", file=sys.stderr)
                print("Stable diffusion model failed to load", file=sys.stderr)
                self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        """Make `v` the active model and move it to the front of `loaded_sd_models`."""
        self.sd_model = v

        if v in self.loaded_sd_models:
            self.loaded_sd_models.remove(v)

        if v is not None:
            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()
2023-07-13 04:52:43 +08:00
def get_empty_cond(sd_model):
    """Return the conditioning that `sd_model` produces for an empty prompt."""
    from modules import extra_networks, processing

    # Activate extra networks with an empty set on a fresh txt2img processing object first.
    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    empty_prompt = [""]
    if not hasattr(sd_model, 'conditioner'):
        return sd_model.cond_stage_model(empty_prompt)

    conditioning = sd_model.get_learned_conditioning(empty_prompt)
    return conditioning['crossattn']
2023-08-01 05:24:48 +08:00
def send_model_to_cpu(m):
    """Move model `m` to the CPU (via lowvram when --lowvram/--medvram is set), then run torch_gc."""
    from modules import lowvram

    low_memory_mode = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
    if not low_memory_mode:
        m.to(devices.cpu)
    else:
        lowvram.send_everything_to_cpu()

    devices.torch_gc()
def send_model_to_device(m):
    """Place model `m` on `shared.device`, or set it up for low VRAM when --lowvram/--medvram is set."""
    from modules import lowvram

    low_memory_mode = shared.cmd_opts.lowvram or shared.cmd_opts.medvram
    if not low_memory_mode:
        m.to(shared.device)
    else:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
def send_model_to_trash(m):
    """Move model `m` to the "meta" device and run torch_gc afterwards."""
    m.to(device="meta")
    devices.torch_gc()
2023-05-02 14:08:00 +08:00
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """Create a new model for a checkpoint, load its weights and make it the active model.

    `checkpoint_info` defaults to the result of `select_checkpoint()`. When
    `already_loaded_state_dict` is given it is used instead of reading the checkpoint.
    The currently active model, if any, is discarded first. Returns the new model,
    which is also stored through `model_data.set_sd_model`.
    """
    from modules import sd_hijack

    if not checkpoint_info:
        checkpoint_info = select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
    else:
        state_dict = already_loaded_state_dict

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    clip_weight_keys = (sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight)
    clip_is_included_into_sd = any(key in state_dict for key in clip_weight_keys)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    # First attempt: build the model with initialization disabled; any failure is
    # reported and followed by a second attempt without DisableInitialization.
    sd_model = None
    disable_clip = clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=disable_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model)
    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)
    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    # Reload embeddings after model load as they may or may not fit the model
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)
    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
2023-08-01 05:24:48 +08:00
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks whether the checkpoint described by checkpoint_info is already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that model (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns a model that can be used to load weights from checkpoint_info's file:
    either a freshly loaded one (when there is room under the limit) or the last entry of the loaded list.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    already_loaded = None
    # Iterate from the back so that deleting entry i does not shift the entries still to be visited.
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            # Remove exactly the model that is being trashed. A bare pop() would remove the
            # last entry, which is a different model whenever the wanted one sits behind it.
            del model_data.loaded_sd_models[i]
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

        if shared.opts.sd_checkpoints_keep_in_cpu:
            send_model_to_cpu(sd_model)
            timer.record("send model to cpu")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded)
        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        # Clear the active model first so load_model() does not trash it.
        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None
2022-11-02 13:51:46 +08:00
def reload_model_weights(sd_model=None, info=None):
    """Switch the active model to the checkpoint `info` (default: the selected checkpoint).

    Returns immediately when the model already uses that checkpoint, reuses an already
    loaded model when possible, builds a new model when the config differs, and
    otherwise loads the new weights into the existing model. If loading the weights
    fails, the previous checkpoint's weights are reloaded and the error is re-raised.
    """
    from modules import devices, sd_hijack

    checkpoint_info = info or select_checkpoint()
    timer = Timer()

    sd_model = sd_model or model_data.sd_model

    # sd_model is None when the previous model load failed.
    current_checkpoint_info = None
    if sd_model is not None:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    timer.record("find config")

    # Without a reusable model, or with a different config, a new model has to be built.
    needs_new_model = sd_model is None or checkpoint_config != sd_model.used_config
    if needs_new_model:
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        # Runs on success and on failure so the model is left hijacked either way.
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)

    return sd_model
2023-03-09 12:56:19 +08:00
2023-05-02 14:08:00 +08:00
2023-03-09 12:56:19 +08:00
def unload_model_weights(sd_model=None, info=None):
    """Move the active model to the CPU, undo its hijack and drop it from `model_data`.

    Returns None when a model was unloaded, otherwise the `sd_model` argument unchanged.
    """
    from modules import devices, sd_hijack

    timer = Timer()

    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)

        # Drop the references before collecting so the memory can actually be released.
        model_data.sd_model = None
        sd_model = None

        gc.collect()
        devices.torch_gc()

    print(f"Unloaded weights {timer.summary()}.")

    return sd_model
2023-05-18 01:22:38 +08:00
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.

    Does nothing when the requested ratio is already applied; otherwise removes an
    existing patch, applies a new one for ratios above zero, and remembers the ratio
    on the model.
    """
    applied_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
    if applied_ratio == token_merging_ratio:
        return

    if applied_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        patch_options = dict(
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        tomesd.apply_patch(sd_model, **patch_options)

    sd_model.applied_token_merged_ratio = token_merging_ratio