2022-10-14 13:00:38 +08:00
import collections
2022-09-17 17:05:04 +08:00
import os . path
import sys
2022-11-01 15:01:49 +08:00
import gc
2023-05-02 14:08:00 +08:00
import threading
2022-09-17 17:05:04 +08:00
import torch
2022-10-28 11:49:39 +08:00
import re
2022-11-27 19:46:40 +08:00
import safetensors . torch
2022-09-17 17:05:04 +08:00
from omegaconf import OmegaConf
2022-12-09 08:14:35 +08:00
from os import mkdir
from urllib import request
import ldm . modules . midas as midas
2022-09-17 17:05:04 +08:00
from ldm . util import instantiate_from_config
2023-05-27 20:47:33 +08:00
from modules import paths , shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes , sd_models_config , sd_unet
2023-01-27 16:28:12 +08:00
from modules . sd_hijack_inpainting import do_inpainting_hijack
from modules . timer import Timer
2023-04-04 15:26:44 +08:00
import tomesd
2022-09-28 00:01:13 +08:00
# Name of the sub-directory of paths.models_path that holds Stable Diffusion checkpoints.
model_dir = "Stable-diffusion"
# Absolute path of the default checkpoint directory.
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# Registry of known checkpoints keyed by CheckpointInfo.title (filled via CheckpointInfo.register()).
checkpoints_list = {}
# Lookup table from every identifier in CheckpointInfo.ids (hashes, names, titles) to its CheckpointInfo.
# NOTE(review): the misspelled name is presumably kept because other modules import it — confirm before renaming.
checkpoint_alisases = {}
# In-memory cache of state dicts keyed by CheckpointInfo; insertion-ordered so that
# load_model_weights can evict the oldest entry first with popitem(last=False).
checkpoints_loaded = collections.OrderedDict()
2022-09-17 17:05:04 +08:00
2023-01-14 14:56:59 +08:00
class CheckpointInfo:
    """Metadata for one checkpoint file on disk.

    Derives a display name relative to the checkpoint directory, the legacy
    short hash, the cached sha256 (if available) and the list of identifiers
    under which the checkpoint can be looked up in ``checkpoint_alisases``.
    For ``.safetensors`` files the embedded metadata is read as well.
    """

    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        # --ckpt-dir may be given as a relative path; normalise it before the prefix
        # check, otherwise it can never match the absolute path computed above.
        ckpt_dir = shared.cmd_opts.ckpt_dir
        abs_ckpt_dir = os.path.abspath(ckpt_dir) if ckpt_dir is not None else None

        # Strip only the leading directory (str.replace would remove every occurrence).
        if abs_ckpt_dir is not None and abspath.startswith(abs_ckpt_dir):
            name = abspath[len(abs_ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        # Only the cached sha256 is used here; the full hash is computed lazily by calculate_shorthash().
        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

        self.metadata = {}

        _, ext = os.path.splitext(self.filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading checkpoint metadata: {filename}")

    def register(self):
        """Add this checkpoint to ``checkpoints_list`` and map all its ids in ``checkpoint_alisases``."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_alisases[alias] = self

    def calculate_shorthash(self):
        """Compute the full sha256 and re-register the checkpoint under its hashed title.

        Returns:
            The 10-character short hash, or None if no sha256 could be obtained.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

        # The title changes once the hash is known, so drop the entry under the old title.
        # Use a default: an instance that was never registered, or whose entry was cleared
        # by a list_models() refresh, must not raise KeyError here.
        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()

        return self.shorthash
2022-09-17 17:05:04 +08:00
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    # NOTE(review): CLIPModel is imported but not used here (hence noqa: F401); presumably
    # imported for its side effect of loading transformers early — TODO confirm.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    # Best-effort: any failure to import or configure transformers logging is ignored.
    pass
2022-10-03 02:09:10 +08:00
def setup_model():
    """Ensure the checkpoint directory exists and install the MiDaS auto-download hook."""
    # exist_ok=True replaces the separate os.path.exists() check, which could race with
    # another process creating the directory in between.
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
2022-09-30 08:59:36 +08:00
2023-01-14 14:56:59 +08:00
def checkpoint_tiles():
    """Return the titles of all registered checkpoints in natural sort order.

    Runs of digits compare numerically and everything else case-insensitively,
    so ``model2`` sorts before ``model10``.
    """

    def natural_key(title):
        pieces = re.split('([0-9]+)', title)
        return [int(piece) if piece.isdigit() else piece.lower() for piece in pieces]

    titles = [info.title for info in checkpoints_list.values()]
    titles.sort(key=natural_key)
    return titles
2022-09-29 05:59:44 +08:00
2022-09-17 17:05:04 +08:00
def list_models():
    """Rebuild ``checkpoints_list`` and ``checkpoint_alisases`` from disk.

    Scans the default model directory (and --ckpt-dir) via modelloader, which may
    download the default SD 1.5 checkpoint. If the file given with --ckpt exists,
    it is registered first and selected in ``shared.opts``.
    """
    checkpoints_list.clear()
    checkpoint_alisases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    # When cmd_ckpt is None, `cmd_ckpt != shared.sd_model_file` short-circuits before os.path.exists.
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    # os.path.exists(None) raises TypeError, so guard the None case explicitly
    # (the elif below shows None is an expected value for --ckpt).
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
2022-10-09 04:26:48 +08:00
2023-01-14 14:56:59 +08:00
def get_closet_checkpoint_match(search_string):
    """Resolve a user-supplied checkpoint reference to a registered CheckpointInfo.

    An exact alias (hash, name, title, ...) wins. Otherwise the checkpoint with the
    shortest title containing ``search_string`` is returned (the earliest one on a
    tie). Returns None when nothing matches.
    """
    exact = checkpoint_alisases.get(search_string)
    if exact is not None:
        return exact

    candidates = [info for info in checkpoints_list.values() if search_string in info.title]
    if not candidates:
        return None

    # min() keeps the first of several equally short titles, just like a stable sort + [0]
    return min(candidates, key=lambda info: len(info.title))
2022-09-17 17:05:04 +08:00
2022-09-30 16:42:40 +08:00
2022-09-17 17:05:04 +08:00
def model_hash(filename):
    """Return the legacy 8-hex-digit checkpoint hash.

    Only 64 KiB starting at offset 1 MiB are hashed, so this is cheap but prone to
    collisions. Returns ``'NOFILE'`` if the file does not exist.
    """
    import hashlib

    offset, length = 0x100000, 0x10000
    try:
        with open(filename, "rb") as fh:
            fh.seek(offset)
            sample = fh.read(length)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(sample).hexdigest()[:8]
def select_checkpoint():
    """Return the CheckpointInfo for the configured checkpoint, or a fallback.

    ``shared.opts.sd_model_checkpoint`` is looked up among the known aliases. If it is
    unknown, the first registered checkpoint is returned and a warning is printed.

    Raises:
        FileNotFoundError: if no checkpoints are registered at all.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if not checkpoints_list:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        # Separate the advice from the last listed path; it used to be appended directly
        # onto the path, e.g. ".../Stable-diffusionCan't run ...".
        error_message += "\n\nCan't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info
2023-03-23 13:28:08 +08:00
# Key-prefix rewrites applied to every state-dict key by transform_checkpoint_dict_key:
# each left-hand prefix is replaced by the right-hand one, which inserts a "text_model."
# level under cond_stage_model.transformer.
# NOTE(review): presumably this adapts checkpoints saved with an older CLIP module layout — confirm.
checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
    """Rewrite a state-dict key using ``checkpoint_dict_replacements``.

    Every entry is tried in order, and each matching prefix is swapped for its
    replacement. Keys matching nothing are returned unchanged.
    """
    key = k
    for old_prefix, new_prefix in checkpoint_dict_replacements.items():
        if not key.startswith(old_prefix):
            continue
        key = new_prefix + key[len(old_prefix):]
    return key
2022-10-09 15:23:31 +08:00
def get_state_dict_from_checkpoint(pl_sd):
    """Normalise a loaded checkpoint into a flat state dict.

    If the checkpoint wraps its weights in a ``"state_dict"`` entry, that inner dict
    is used. Keys are rewritten with ``transform_checkpoint_dict_key``, and any key it
    maps to None is dropped. The result is written back into the same dict object,
    which is returned.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    renamed = {}
    for key, value in pl_sd.items():
        target = transform_checkpoint_dict_key(key)
        if target is None:
            continue
        renamed[target] = value

    pl_sd.clear()
    pl_sd.update(renamed)
    return pl_sd
2022-10-09 15:23:31 +08:00
2023-03-14 14:10:26 +08:00
def read_metadata_from_safetensors(filename):
    """Read the ``__metadata__`` section of a ``.safetensors`` file's JSON header.

    String values that start with ``{`` are additionally decoded as JSON when
    possible; all other values are returned as stored. A header without
    ``__metadata__`` yields an empty dict.

    Raises:
        AssertionError: if the file does not start with a safetensors header.
    """
    import json

    with open(filename, mode="rb") as file:
        # Header layout: an 8-byte little-endian length, followed by that many bytes of JSON.
        metadata_len = int.from_bytes(file.read(8), "little")
        json_start = file.read(2)

        # Raised explicitly rather than via `assert` so the check survives `python -O`.
        # AssertionError is kept for compatibility with callers that may rely on it.
        if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
            raise AssertionError(f"{filename} is not a safetensors file")

        json_data = json_start + file.read(metadata_len - 2)

    json_obj = json.loads(json_data)

    res = {}
    for k, v in json_obj.get("__metadata__", {}).items():
        res[k] = v
        if isinstance(v, str) and v.startswith('{'):
            try:
                res[k] = json.loads(v)
            except Exception:
                # best-effort: keep the raw string when it is not valid JSON
                pass

    return res
2022-11-27 20:51:29 +08:00
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint file and return its normalised state dict.

    ``.safetensors`` files are loaded with safetensors onto ``map_location``,
    falling back to ``shared.weight_load_location`` and then the optimal device
    name. Any other extension goes through ``torch.load``.
    """
    extension = os.path.splitext(checkpoint_file)[1].lower()

    if extension != ".safetensors":
        # NOTE(review): torch.load uses pickle; make sure a restricted unpickler is in effect
        # or the file is trusted — neither is visible from this module.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
    else:
        target_device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=target_device)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return get_state_dict_from_checkpoint(pl_sd)
2023-01-27 16:28:12 +08:00
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for ``checkpoint_info``, preferring ``checkpoints_loaded``.

    The checkpoint's short hash is always (re)computed first. Each step is recorded
    on ``timer``.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        loaded = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")
        return loaded

    # use checkpoint cache
    print(f"Loading weights [{sd_model_hash}] from cache")
    return checkpoints_loaded[checkpoint_info]
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Load checkpoint weights into an already constructed model and finish its setup.

    Besides applying the weights, this optionally caches them, applies channels_last
    and half precision, sets the global dtypes in ``devices``, records the checkpoint
    on the model and in ``shared.opts``, and (re)loads the VAE.

    Args:
        model: model instance to populate (its ``first_stage_model``,
            ``model.diffusion_model``, ``logvar`` and optional ``depth_model`` are used).
        checkpoint_info: checkpoint whose weights are applied.
        state_dict: weights to apply, or None to fetch them via
            ``get_checkpoint_state_dict`` (cache or disk).
        timer: Timer on which each step's duration is recorded.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    # Selected checkpoint is recorded in the options before the weights are actually applied.
    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    # strict=False: keys missing from, or unexpected in, the checkpoint are tolerated.
    model.load_state_dict(state_dict, strict=False)
    del state_dict  # drop this function's reference to the weights
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        # NOTE(review): .copy() is a shallow dict copy taken before half(); confirm the cached
        # tensors are not expected to track later in-place changes to the model.
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        model.half()
        # reattach the submodules that were detached above (no-op if they were not)
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        timer.record("apply half()")

    # Global dtypes used elsewhere; the UNet dtype is read back from the model itself.
    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
    devices.dtype_unet = model.model.diffusion_model.dtype
    # True only with --upcast-sampling when both the global and the UNet dtype are float16.
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    # (oldest entries are evicted first; with a limit of 0 the cache is emptied entirely)
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    model.logvar = model.logvar.to(devices.device)  # fix for training

    # Discard any previously loaded VAE state, then load the VAE that applies to this checkpoint.
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
2022-11-02 13:51:46 +08:00
2022-09-17 17:05:04 +08:00
2022-12-09 08:14:35 +08:00
def enable_midas_autodownload():
    """Give ``ldm.modules.midas.api.load_model`` automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded, it
    calls ``midas.api.load_model`` to load the associated MiDaS depth model. This
    function points the MiDaS weight paths into ``models/midas`` and wraps
    ``load_model`` so that missing weight files are downloaded there first.

    Safe to call more than once: the original loader is captured only on the first call.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # On a repeated call, midas.api.load_model is already our wrapper; capturing it again
    # as load_model_inner would make the wrapper call itself forever.
    if not hasattr(midas.api, "load_model_inner"):
        midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # makedirs also creates missing parents, unlike os.mkdir
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            # Download under a temporary name and move it into place only when complete,
            # so an interrupted download is not later mistaken for a valid weight file.
            tmp_path = f"{path}.tmp"
            request.urlretrieve(midas_urls[model_type], tmp_path)
            os.replace(tmp_path, path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
2023-01-04 17:35:07 +08:00
2023-01-27 16:28:12 +08:00
def repair_config(sd_config):
    """Patch a loaded model config in place to match the current runtime options.

    Ensures ``use_ema`` exists and sets the UNet fp16 flag from --no-half or
    --upcast-sampling. Falls back from xformers VAE attention when xformers is
    unavailable, and redirects UnCLIP's karlo statistics path into the models directory.
    """
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if shared.cmd_opts.no_half:
        params.unet_config.params.use_fp16 = False
    elif shared.cmd_opts.upcast_sampling:
        params.unet_config.params.use_fp16 = True

    ddconfig = params.first_stage_config.params.ddconfig
    wants_xformers = getattr(ddconfig, "attn_type", None) == "vanilla-xformers"
    if wants_xformers and not shared.xformers_available:
        ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        noise_params = params.noise_aug_config.params
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_params.clip_stats_path = noise_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
2023-01-27 16:28:12 +08:00
2023-02-05 16:20:47 +08:00
sd1_clip_weight = ' cond_stage_model.transformer.text_model.embeddings.token_embedding.weight '
sd2_clip_weight = ' cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight '
2023-05-02 14:08:00 +08:00
class SdModelData:
    """Holder for the current Stable Diffusion model with a lazy first load.

    ``get_sd_model`` calls ``load_model()`` on first access; ``lock`` keeps two
    threads from starting that first load at the same time through this method.
    """

    def __init__(self):
        self.sd_model = None
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        """Return the current model, loading it if none has ever been loaded.

        Once ``was_loaded_at_least_once`` is set, no automatic load is attempted again.
        A failed load is reported on stderr and leaves ``sd_model`` as None.
        """
        if self.was_loaded_at_least_once or self.sd_model is not None:
            return self.sd_model

        with self.lock:
            # another thread may have finished loading while this one waited for the lock
            if self.sd_model is not None or self.was_loaded_at_least_once:
                return self.sd_model

            try:
                load_model()
            except Exception as e:
                errors.display(e, "loading stable diffusion model", full_traceback=True)
                print("", file=sys.stderr)
                print("Stable diffusion model failed to load", file=sys.stderr)
                self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        """Replace the current model with ``v``."""
        self.sd_model = v
# Module-wide model holder; load_model, reload_model_weights and unload_model_weights
# read and replace model_data.sd_model.
model_data = SdModelData()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """Build a Stable Diffusion model from a checkpoint and make it the current model.

    Args:
        checkpoint_info: checkpoint to load; defaults to ``select_checkpoint()``.
        already_loaded_state_dict: weights already read from disk (e.g. by
            reload_model_weights). When None, they are fetched via
            ``get_checkpoint_state_dict``.

    Returns:
        The loaded model, which is also stored in ``model_data.sd_model``.
    """
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    # release the previous model before building a new one
    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()

    timer = Timer()

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception as e:
        # Report why the fast path failed instead of silently swallowing the error;
        # the slow construction below is still attempted.
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    # Cache the text-encoder output for an empty prompt on the model.
    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
2022-11-02 13:51:46 +08:00
def reload_model_weights(sd_model=None, info=None):
    """Switch the current model to another checkpoint, reusing the model object when possible.

    If the new checkpoint needs the same config as the current model, only the
    weights are swapped. Otherwise, or if no model is loaded, a new model is built
    with ``load_model``. If swapping the weights fails, the previous checkpoint's
    weights are restored and the error is re-raised.

    Args:
        sd_model: model to update; defaults to ``model_data.sd_model``.
        info: CheckpointInfo to switch to; defaults to ``select_checkpoint()``.

    Returns:
        The updated model, or None if the requested checkpoint is already loaded.
    """
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

        # Teardown of the existing model: only valid when there is one, since
        # sd_model.to(...) and undo_hijack(sd_model) would fail on None.
        sd_unet.apply_unet("None")

        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:
            sd_model.to(devices.cpu)

        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        # A different config requires a freshly constructed model.
        del sd_model
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        # Runs on success and on failure: re-apply the hijack and bring the model back.
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    return sd_model
2023-03-09 12:56:19 +08:00
2023-05-02 14:08:00 +08:00
2023-03-09 12:56:19 +08:00
def unload_model_weights(sd_model=None, info=None):
    """Move the current model to the CPU, drop it and free cached memory.

    Always operates on ``model_data.sd_model``; ``info`` is unused. Returns None after
    unloading a model, or the ``sd_model`` argument unchanged if no model was loaded.
    """
    from modules import devices, sd_hijack
    timer = Timer()

    current = model_data.sd_model
    if current:
        current.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(current)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()
        torch.cuda.empty_cache()

    print(f"Unloaded weights {timer.summary()}.")

    return sd_model
2023-05-18 01:22:38 +08:00
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.

    The ratio currently applied is remembered on ``sd_model.applied_token_merged_ratio``.
    Nothing happens if it already equals ``token_merging_ratio``. Otherwise an existing
    patch is removed, and a new one is applied when the ratio is positive.
    """
    previous_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
    if previous_ratio == token_merging_ratio:
        return

    if previous_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        patch_options = dict(
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        tomesd.apply_patch(sd_model, **patch_options)

    sd_model.applied_token_merged_ratio = token_merging_ratio