2022-10-14 13:00:38 +08:00
import collections
2024-04-06 13:53:21 +08:00
import os
2022-09-17 17:05:04 +08:00
import sys
2023-05-02 14:08:00 +08:00
import threading
2022-09-17 17:05:04 +08:00
import torch
2022-10-28 11:49:39 +08:00
import re
2022-11-27 19:46:40 +08:00
import safetensors . torch
2023-09-16 00:59:44 +08:00
from omegaconf import OmegaConf , ListConfig
2022-12-09 08:14:35 +08:00
from urllib import request
import ldm . modules . midas as midas
2022-09-17 17:05:04 +08:00
from ldm . util import instantiate_from_config
2023-09-30 14:11:31 +08:00
from modules import paths , shared , modelloader , devices , script_callbacks , sd_vae , sd_disable_initialization , errors , hashes , sd_models_config , sd_unet , sd_models_xl , cache , extra_networks , processing , lowvram , sd_hijack , patches
2023-01-27 16:28:12 +08:00
from modules . timer import Timer
2024-02-27 12:43:27 +08:00
from modules . shared import opts
2023-04-04 15:26:44 +08:00
import tomesd
2023-09-16 00:59:44 +08:00
import numpy as np
2022-09-28 00:01:13 +08:00
# Name of the sub-directory of `paths.models_path` that holds the checkpoints.
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# CheckpointInfo objects keyed by their title (filled by CheckpointInfo.register).
checkpoints_list = {}
# Every id of a checkpoint (see CheckpointInfo.ids) mapped to its CheckpointInfo.
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
# State dicts cached per CheckpointInfo; kept ordered so the oldest entry is evicted first.
checkpoints_loaded = collections.OrderedDict()
2022-09-17 17:05:04 +08:00
2023-01-14 14:56:59 +08:00
2023-08-30 13:54:31 +08:00
def replace_key(d, key, new_key, value):
    """Store `value` under `new_key` in `d`, taking over the position of `key`.

    When `key` is not present, `new_key` is simply set (appended if new) and
    the order of the other entries is untouched. Otherwise the dictionary is
    rebuilt in place so that `new_key` sits where `key` used to be and `key`
    itself is gone. The same dictionary object is returned.
    """
    previous_order = list(d)
    d[new_key] = value

    if key not in previous_order:
        return d

    reordered = {}
    for existing in previous_order:
        slot = new_key if existing == key else existing
        reordered[slot] = d[slot]

    # mutate in place: callers hold references to this very dictionary
    d.clear()
    d.update(reordered)
    return d
2023-01-14 14:56:59 +08:00
class CheckpointInfo:
    """Describes one checkpoint file: its names, hashes, title and lookup ids."""

    def __init__(self, filename):
        """Collect names, metadata and hashes for the checkpoint at `filename`.

        The display name is the path relative to `--ckpt-dir` or to
        `model_path` when the file lives under one of them, otherwise the
        bare file name.
        """
        self.filename = filename
        abspath = os.path.abspath(filename)
        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None

        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        # Strip only the leading directory; str.replace() would also remove
        # any later occurrence of the same text inside the path.
        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
            name = abspath[len(abs_ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        # read_metadata() below only runs when the metadata cache has no entry
        # for this file, so give the attribute a value up front.
        self.modelspec_thumbnail = None

        def read_metadata():
            metadata = read_metadata_from_safetensors(filename)
            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)

            return metadata

        self.metadata = {}
        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading metadata for {filename}")

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
        if self.shorthash:
            self.ids += self._shorthash_ids()

    def _shorthash_ids(self):
        """Return the ids that only exist once the sha256 of the file is known."""
        return [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']

    def register(self):
        """Publish this checkpoint in `checkpoints_list` and under all of its ids in `checkpoint_aliases`."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_aliases[alias] = self

    def calculate_shorthash(self):
        """Compute the sha256 of the file and refresh title, ids and registration.

        Returns the first 10 characters of the sha256, or None when
        `hashes.sha256` returns None.
        """
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        shorthash = self.sha256[0:10]
        if self.shorthash == shorthash:
            # nothing changed, title and ids are already up to date
            return self.shorthash

        self.shorthash = shorthash

        if self.shorthash not in self.ids:
            self.ids += self._shorthash_ids()

        old_title = self.title
        self.title = f'{self.name} [{self.shorthash}]'
        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

        # keep the position of the entry in checkpoints_list while renaming it
        replace_key(checkpoints_list, old_title, self.title, self)
        self.register()

        return self.shorthash
2022-09-17 17:05:04 +08:00
# Best-effort tweak at import time: any failure here (for example transformers
# not being importable) is deliberately ignored.
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    pass
2022-10-03 02:09:10 +08:00
def setup_model():
    """called once at startup to do various one-time tasks related to SD models"""

    # make sure the checkpoint directory exists before anything looks into it
    os.makedirs(model_path, exist_ok=True)

    for one_time_task in (enable_midas_autodownload, patch_given_betas):
        one_time_task()
2022-09-30 08:59:36 +08:00
2023-07-30 18:48:27 +08:00
def checkpoint_tiles(use_short=False):
    """Return the titles of all registered checkpoints (short titles when `use_short` is truthy)."""
    attribute = 'short_title' if use_short else 'title'
    return [getattr(info, attribute) for info in checkpoints_list.values()]
2022-09-29 05:59:44 +08:00
2022-09-17 17:05:04 +08:00
def list_models():
    """Rebuild `checkpoints_list` and `checkpoint_aliases` from scratch.

    Registers the checkpoint given with --ckpt (when that file exists) and
    every file returned by `modelloader.load_models`. A download URL for the
    default model is only passed to the loader when downloading is not
    disabled, --ckpt equals `shared.sd_model_file` and that file is missing.
    """
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt

    def cmd_ckpt_exists():
        # os.path.exists(None) raises TypeError, so guard against a missing --ckpt value.
        # Evaluated on demand: the loader call below may create files.
        return cmd_ckpt is not None and os.path.exists(cmd_ckpt)

    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or cmd_ckpt_exists():
        model_url = None
        expected_sha256 = None
    else:
        model_url = f"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'

    model_list = modelloader.load_models(
        model_path=model_path,
        model_url=model_url,
        command_path=shared.cmd_opts.ckpt_dir,
        ext_filter=[".ckpt", ".safetensors"],
        download_name="v1-5-pruned-emaonly.safetensors",
        ext_blacklist=[".vae.ckpt", ".vae.safetensors"],
        hash_prefix=expected_sha256,
    )

    if cmd_ckpt_exists():
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        # the checkpoint from the command line becomes the selected one
        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    # NOTE(review): this compares against shared.default_sd_model_file while the
    # check above uses shared.sd_model_file - confirm both names are intended.
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
2022-10-09 04:26:48 +08:00
2023-07-30 18:48:27 +08:00
# matches a trailing "[...]" part (with surrounding whitespace) at the end of a string
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")


def get_closet_checkpoint_match(search_string):
    """Find the registered checkpoint that best matches `search_string`.

    Lookup order: exact alias, then the shortest title containing the search
    string, then the same title search with a trailing "[...]" part removed
    from the search string. Returns None when nothing matches or the search
    string is empty.
    """
    if not search_string:
        return None

    exact = checkpoint_aliases.get(search_string, None)
    if exact is not None:
        return exact

    def shortest_title_containing(needle):
        candidates = [info for info in checkpoints_list.values() if needle in info.title]
        if not candidates:
            return None
        # min() keeps the first of equally short titles, like a stable sort would
        return min(candidates, key=lambda info: len(info.title))

    match = shortest_title_containing(search_string)
    if match is not None:
        return match

    return shortest_title_containing(re_strip_checksum.sub('', search_string))
2022-09-17 17:05:04 +08:00
2022-09-30 16:42:40 +08:00
2022-09-17 17:05:04 +08:00
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    import hashlib

    try:
        with open(filename, "rb") as file:
            # only the 0x10000 bytes starting at offset 0x100000 are hashed
            file.seek(0x100000)
            digest = hashlib.sha256(file.read(0x10000)).hexdigest()
    except FileNotFoundError:
        return 'NOFILE'

    return digest[:8]
def select_checkpoint():
    """Return the CheckpointInfo for the checkpoint selected in the options.

    Falls back to the first registered checkpoint when the selected one is
    unknown. Raises `FileNotFoundError` if no checkpoints are found.
    """
    wanted = shared.opts.sd_model_checkpoint

    selected = checkpoint_aliases.get(wanted, None)
    if selected is not None:
        return selected

    if not checkpoints_list:
        lines = ["No checkpoints found. When searching for checkpoints, looked at:"]
        if shared.cmd_opts.ckpt is not None:
            lines.append(f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}")
        lines.append(f"\n - directory {model_path}")
        if shared.cmd_opts.ckpt_dir is not None:
            lines.append(f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}")
        lines.append("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations.")
        raise FileNotFoundError("".join(lines))

    fallback = next(iter(checkpoints_list.values()))
    if wanted is not None:
        print(f"Checkpoint {wanted} not found; loading fallback {fallback.title}", file=sys.stderr)

    return fallback
2023-12-02 11:58:05 +08:00
# Key prefixes of SD1 checkpoints that get a 'text_model.' segment inserted
# after 'cond_stage_model.transformer.'.
checkpoint_dict_replacements_sd1 = {
    f'cond_stage_model.transformer.{part}.': f'cond_stage_model.transformer.text_model.{part}.'
    for part in ('embeddings', 'encoder', 'final_layer_norm')
}

# Converts SD 2.1 Turbo from SGM to LDM format.
checkpoint_dict_replacements_sd2_turbo = {'conditioner.embedders.0.': 'cond_stage_model.'}
2022-10-19 13:42:22 +08:00
2023-12-02 11:58:05 +08:00
def transform_checkpoint_dict_key(k, replacements):
    """Rewrite the prefix of state-dict key `k` according to `replacements`.

    `replacements` maps an old prefix to its new prefix. The entries are
    applied one after another, so a later entry sees the result of an
    earlier one.
    """
    renamed = k
    for old_prefix, new_prefix in replacements.items():
        if not renamed.startswith(old_prefix):
            continue
        renamed = new_prefix + renamed[len(old_prefix):]

    return renamed
2022-10-09 15:23:31 +08:00
def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a loaded checkpoint and rename its keys in place.

    If the checkpoint has a "state_dict" entry, that inner dictionary is
    used; otherwise the dictionary itself is. The dictionary that is
    returned is emptied and refilled with the renamed keys.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    turbo_marker = 'conditioner.embedders.0.model.ln_final.weight'
    is_sd2_turbo = turbo_marker in pl_sd and pl_sd[turbo_marker].size()[0] == 1024

    # the same replacement table applies to every key of one checkpoint
    replacements = checkpoint_dict_replacements_sd2_turbo if is_sd2_turbo else checkpoint_dict_replacements_sd1

    renamed = {}
    for old_key, tensor in pl_sd.items():
        new_key = transform_checkpoint_dict_key(old_key, replacements)
        if new_key is None:
            continue
        renamed[new_key] = tensor

    pl_sd.clear()
    pl_sd.update(renamed)

    return pl_sd
2022-10-09 15:23:31 +08:00
2023-03-14 14:10:26 +08:00
def read_metadata_from_safetensors(filename):
    """Read the "__metadata__" section from the header of a safetensors file.

    The file starts with an 8-byte little-endian header length followed by a
    JSON header. String values that look like JSON objects are decoded when
    possible. Problems while parsing the header are reported and whatever
    was collected so far is returned.

    Raises AssertionError when the file does not start like a safetensors file.
    """
    import json

    with open(filename, mode="rb") as file:
        header_size = int.from_bytes(file.read(8), "little")
        header_start = file.read(2)

        assert header_size > 2 and header_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"

        collected = {}

        try:
            header = json.loads(header_start + file.read(header_size - 2))
            for field, raw in header.get("__metadata__", {}).items():
                collected[field] = raw
                if not (isinstance(raw, str) and raw.startswith('{')):
                    continue
                try:
                    collected[field] = json.loads(raw)
                except Exception:
                    # not valid JSON after all: keep the plain string
                    pass
        except Exception:
            errors.report(f"Error reading metadata from file: {filename}", exc_info=True)

        return collected
2022-11-27 20:51:29 +08:00
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Load a checkpoint file and return its state dict with renamed keys.

    `.safetensors` files are read with safetensors, either via `load_file`
    or, when the `disable_mmap_load_safetensors` option is set, by reading
    the whole file into memory first. Every other extension goes through
    `torch.load`.

    `map_location` overrides `shared.weight_load_location` as the device the
    weights are loaded to. When `print_global_state` is set and the
    checkpoint has a "global_step" entry, it is printed.
    """
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # close the file as soon as its content is in memory
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        # NOTE(review): torch.load unpickles the file, which can execute code from
        # an untrusted checkpoint - confirm a restricted unpickler is in effect.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return get_state_dict_from_checkpoint(pl_sd)
2023-01-27 16:28:12 +08:00
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for `checkpoint_info`, from the cache when possible.

    A cache hit also marks the entry as the most recently used one. The
    elapsed time of each step is reported to `timer`.
    """
    shorthash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{shorthash}] from {checkpoint_info.filename}")
        loaded = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")
        return loaded

    print(f"Loading weights [{shorthash}] from cache")
    # a used entry becomes the newest one, so it is evicted last
    checkpoints_loaded.move_to_end(checkpoint_info)
    return checkpoints_loaded[checkpoint_info]
2023-08-06 22:01:07 +08:00
class SkipWritingToConfig:
    """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""

    # class-wide switch read by load_model_weights
    skip = False
    # value of `skip` at the time this instance was entered
    previous = None

    def __enter__(self):
        flag_holder = SkipWritingToConfig
        self.previous = flag_holder.skip
        flag_holder.skip = True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # restore rather than reset, so nested uses keep the flag set
        setattr(SkipWritingToConfig, "skip", self.previous)
2023-11-19 15:50:06 +08:00
def check_fp8(model):
    """Tell whether `model` should be stored in fp8 according to the `fp8_storage` option.

    Returns None when there is no model, otherwise a bool. On the "mps"
    device the answer is always False.
    """
    if model is None:
        return None

    if devices.get_optimal_device_name() == "mps":
        return False

    if shared.opts.fp8_storage == "Enable":
        return True

    if getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
        return True

    return False
2023-01-27 16:28:12 +08:00
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Apply the weights of a checkpoint to `model` and prepare it for use.

    Steps, in order: compute the checkpoint hash, load `state_dict` (from
    cache or disk when None is passed), set the `is_sdxl`/`is_sd2`/`is_sd1`/
    `is_ssd`/`is_sdxl_inpaint` flags, load the weights, convert the model to
    float or half, apply the alpha schedule override, optionally convert
    Conv2d/Linear modules to fp8, move the VAE to `devices.dtype_vae`, trim
    the state-dict cache, record checkpoint information on the model and
    finally load the VAE resolved for this checkpoint.

    Args:
        model: the model object that receives the weights.
        checkpoint_info: the checkpoint the weights belong to.
        state_dict: already loaded weights, or None to load them here.
        timer: receives a `record(label)` call after each step.

    Nothing is returned; `model`, `devices`, `shared.opts.data` and
    `checkpoints_loaded` are modified.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if devices.fp8:
        # prevent model to load state dict in fp8
        model.half()

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    # model family flags, derived from attributes of the model and keys of the state dict
    model.is_sdxl = hasattr(model, 'conditioner')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2
    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    if model.is_ssd:
        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict.copy()

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    # drop the local reference; a cached copy (if any) stays in checkpoints_loaded
    del state_dict

    # Set is_sdxl_inpaint flag.
    # Checks Unet structure to detect inpaint model. The inpaint model's
    # checkpoint state_dict does not contain the key
    # 'diffusion_model.input_blocks.0.0.weight'.
    # The lookup below is done on the state dict of model.model; the flag is
    # set for SDXL models whose weight under that key has shape[1] == 9.
    diffusion_model_input = model.model.state_dict().get(
        'diffusion_model.input_blocks.0.0.weight'
    )
    model.is_sdxl_inpaint = (
        model.is_sdxl and
        diffusion_model_input is not None and
        diffusion_model_input.shape[1] == 9
    )

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if shared.cmd_opts.no_half:
        model.float()
        model.alphas_cumprod_original = model.alphas_cumprod
        devices.dtype_unet = torch.float32
        assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
        timer.record("apply float()")
    else:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        # alphas_cumprod is detached while half() runs and put back unchanged afterwards
        alphas_cumprod = model.alphas_cumprod
        model.alphas_cumprod = None
        model.half()
        model.alphas_cumprod = alphas_cumprod
        model.alphas_cumprod_original = alphas_cumprod
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        devices.dtype_unet = torch.float16
        timer.record("apply half()")

    apply_alpha_schedule_override(model)

    # remove fp16 copies that an earlier fp8 conversion may have attached to the modules
    for module in model.modules():
        if hasattr(module, 'fp16_weight'):
            del module.fp16_weight
        if hasattr(module, 'fp16_bias'):
            del module.fp16_bias

    if check_fp8(model):
        devices.fp8 = True
        # the VAE is detached during the conversion so that it is not part of model.modules()
        first_stage = model.first_stage_model
        model.first_stage_model = None
        for module in model.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                if shared.opts.cache_fp16_weight:
                    # keep half precision copies on the CPU before converting to fp8
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
        devices.fp8 = False

    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    # forget the previously loaded VAE, then load the one resolved for this checkpoint
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
2022-11-02 13:51:46 +08:00
2022-09-17 17:05:04 +08:00
2022-12-09 08:14:35 +08:00
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    The original loader stays reachable as midas.api.load_model_inner.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # makedirs also creates missing parents and tolerates a directory
            # that another thread or process created in the meantime
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            # Download next to the target and move it into place afterwards, so
            # an interrupted download is never mistaken for a complete model.
            partial_path = f"{path}.part"
            request.urlretrieve(midas_urls[model_type], partial_path)
            os.replace(partial_path, path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper
2023-01-04 17:35:07 +08:00
2023-09-16 00:59:44 +08:00
def patch_given_betas():
    """Patch DDPM.register_schedule so that it accepts an Omegaconf list as its first argument after self."""
    import ldm.models.diffusion.ddpm

    ddpm_class = ldm.models.diffusion.ddpm.DDPM

    def patched_register_schedule(*args, **kwargs):
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        if isinstance(args[1], ListConfig):
            args = args[:1] + (np.array(args[1]),) + args[2:]

        original_register_schedule(*args, **kwargs)

    # patches.patch hands back the function that was replaced
    original_register_schedule = patches.patch(__name__, ddpm_class, 'register_schedule', patched_register_schedule)
2023-09-16 00:59:44 +08:00
2023-01-27 16:28:12 +08:00
def repair_config(sd_config):
    """Adjust a loaded model config in place to match command line options and local paths."""
    params = sd_config.model.params

    if not hasattr(params, "use_ema"):
        params.use_ema = False

    if hasattr(params, 'unet_config'):
        if shared.cmd_opts.no_half:
            params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
            params.unet_config.params.use_fp16 = True

    # fall back to plain attention in the VAE when xformers is not available
    ddconfig = params.first_stage_config.params.ddconfig
    if getattr(ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(params, "noise_aug_config") and hasattr(params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        noise_aug_params = params.noise_aug_config.params
        noise_aug_params.clip_stats_path = noise_aug_params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)

    # Do not use checkpoint for inference.
    # This helps prevent extra performance overhead on checking parameters.
    # The perf overhead is about 100ms/it on 4090 for SDXL.
    for section in ("network_config", "unet_config"):
        if hasattr(params, section):
            getattr(params, section).params.use_checkpoint = False
2024-05-17 04:39:02 +08:00
2024-03-02 11:54:11 +08:00
def rescale_zero_terminal_snr_abar(alphas_cumprod):
    """Return a rescaled copy of `alphas_cumprod` whose last entry is (almost) zero.

    Works on the square roots: they are shifted so the last one becomes
    zero, then scaled so the first one keeps its value, and squared again.
    The last entry is finally set to a tiny positive constant. The input
    tensor is not modified.
    """
    sqrt_abar = alphas_cumprod.sqrt()

    # remember the end points before anything is changed
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # shift so the last timestep is zero, then scale the first one back to its old value
    shifted = sqrt_abar - last
    rescaled = shifted * (first / (first - last))

    # undo the square root
    abar = rescaled ** 2
    abar[-1] = 4.8973451890853435e-08
    return abar
2024-02-27 12:43:27 +08:00
def apply_alpha_schedule_override(sd_model, p=None):
    """
    Applies an override to the alpha schedule of the model according to settings.
    - downcasts the alpha schedule to half precision
    - rescales the alpha schedule to have zero terminal SNR

    Always starts from `alphas_cumprod_original`. When `p` is given, the
    applied settings are also written to `p.extra_generation_params`.
    Models without both alpha schedule attributes are left alone.
    """
    if not (hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original')):
        return

    record_params = p is not None

    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)

    if opts.use_downcasted_alpha_bar:
        if record_params:
            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)

    if opts.sd_noise_schedule == "Zero Terminal SNR":
        if record_params:
            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
2023-01-27 16:28:12 +08:00
2023-02-05 16:20:47 +08:00
# State-dict keys that load_model looks for to decide whether a checkpoint
# includes CLIP weights (see clip_is_included_into_sd); going by the names,
# one key per model family.
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
2023-02-05 16:20:47 +08:00
2023-05-02 14:08:00 +08:00
class SdModelData:
    """Holds the currently selected model and the list of models that are loaded."""

    def __init__(self):
        self.sd_model = None
        # most recently selected model first
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
        """Return the current model, loading one first if none was ever loaded.

        A failed load is reported on stderr and None is returned.
        """
        if self.was_loaded_at_least_once or self.sd_model is not None:
            return self.sd_model

        with self.lock:
            # another thread may have finished loading while this one waited for the lock
            if self.sd_model is not None or self.was_loaded_at_least_once:
                return self.sd_model

            try:
                load_model()
            except Exception as e:
                errors.display(e, "loading stable diffusion model", full_traceback=True)
                print("", file=sys.stderr)
                print("Stable diffusion model failed to load", file=sys.stderr)
                self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v, already_loaded=False):
        """Make `v` the current model and move it to the front of `loaded_sd_models`.

        With `already_loaded`, the VAE bookkeeping in `sd_vae` is also
        switched to the values stored on `v`.
        """
        self.sd_model = v

        if already_loaded:
            sd_vae.base_vae = getattr(v, "base_vae", None)
            sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
            sd_vae.checkpoint_info = v.sd_checkpoint_info

        if v in self.loaded_sd_models:
            self.loaded_sd_models.remove(v)

        if v is None:
            return

        self.loaded_sd_models.insert(0, v)


model_data = SdModelData()
2023-07-13 04:52:43 +08:00
def get_empty_cond(sd_model):
    """Return the conditioning that `sd_model` produces for an empty prompt."""
    # activate extra networks with an empty set, using a throwaway processing object
    processing_stub = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(processing_stub, {})

    if not hasattr(sd_model, 'conditioner'):
        return sd_model.cond_stage_model([""])

    conditioning = sd_model.get_learned_conditioning([""])
    return conditioning['crossattn']
2023-08-01 05:24:48 +08:00
def send_model_to_cpu(m):
    """
    Moves model m to devices.cpu and runs devices.torch_gc().

    Models flagged with `lowvram` are handled by lowvram.send_everything_to_cpu() instead of m.to().
    m may be None, in which case only devices.torch_gc() runs.
    """
    if m is not None and m.lowvram:
        lowvram.send_everything_to_cpu()
    elif m is not None:
        m.to(devices.cpu)

    devices.torch_gc()
2023-08-22 23:49:08 +08:00
def model_target_device(m):
    """
    Returns the device that the weights of model m should be loaded to:
    devices.cpu when lowvram.is_needed(m) says so, devices.device otherwise.
    """
    return devices.cpu if lowvram.is_needed(m) else devices.device
2023-08-01 05:24:48 +08:00
def send_model_to_device(m):
    """
    Runs lowvram.apply() on model m and, unless the model is flagged with `lowvram`
    afterwards, moves it to shared.device.
    """
    lowvram.apply(m)

    if m.lowvram:
        return

    m.to(shared.device)
def send_model_to_trash(m):
    """
    Disposes of model m: moves it to the "meta" device and runs devices.torch_gc().

    NOTE(review): the meta device is presumably used because it holds no tensor data, so the
    weights can be released - confirm against the torch documentation.
    """
    m.to(device="meta")
    devices.torch_gc()
2023-05-02 14:08:00 +08:00
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    """
    Creates a new model for a checkpoint, loads its weights and makes it the active model.

    checkpoint_info: checkpoint to load; when falsy, select_checkpoint() picks one.
    already_loaded_state_dict: state dict to use instead of calling get_checkpoint_state_dict().

    The model currently in model_data.sd_model, if any, is sent to trash first.
    Returns the new model, which is also installed through model_data.set_sd_model().
    If the quick creation attempt fails, the error is displayed and creation is retried without
    DisableInitialization; an exception from that retry or from any later step propagates.
    """
    from modules import sd_hijack

    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    # membership test on the keys themselves; the previous form tested the truthiness of the key names
    clip_weight_keys = [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight]
    clip_is_included_into_sd = any(key in state_dict for key in clip_weight_keys)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        # quick path: CLIP is not initialized when the checkpoint brings its own CLIP weights
        # or when downloading CLIP is disabled on the command line
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
            with sd_disable_initialization.InitializeOnMeta():
                sd_model = instantiate_from_config(sd_config.model)

    except Exception as e:
        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

        # slow path: same creation without DisableInitialization; errors are not caught here
        with sd_disable_initialization.InitializeOnMeta():
            sd_model = instantiate_from_config(sd_config.model)

    # remembered so that reload_model_weights() can tell whether another checkpoint needs a different model
    sd_model.used_config = checkpoint_config

    timer.record("create model")

    if shared.cmd_opts.no_half:
        weight_dtype_conversion = None
    else:
        # NOTE(review): presumably maps state dict key prefixes to the dtype used while loading,
        # with None meaning "keep as is" and '' being the fallback - confirm in LoadStateDictOnMeta
        weight_dtype_conversion = {
            'first_stage_model': None,
            'alphas_cumprod': None,
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    # cache the empty-prompt conditioning on the model
    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model
2023-08-01 05:24:48 +08:00
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Looks for the checkpoint described by checkpoint_info among the models in model_data.loaded_sd_models.

    If sd_model already holds that checkpoint, it is returned right away.
    If another loaded model holds it, that model is sent to the device, made the active model and returned
    (the current model is sent to CPU beforehand when the sd_checkpoints_keep_in_cpu option is set).
    If it is not loaded and there is room under sd_checkpoints_limit, it is loaded with load_model() and returned.
    Otherwise a loaded model is taken out of the list and returned so that weights from checkpoint_info's file
    can be loaded into it; if there is no such model, None is returned.

    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    already_loaded = None
    models = model_data.loaded_sd_models

    # walk from the end of the list towards the front so that deleting an entry does not shift the ones still to be visited
    for index in range(len(models) - 1, -1, -1):
        candidate = models[index]

        if candidate.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = candidate
            continue

        if not (len(models) > shared.opts.sd_checkpoints_limit > 0):
            continue

        print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {candidate.sd_checkpoint_info.title}")
        del models[index]
        send_model_to_trash(candidate)
        timer.record("send model to trash")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
            shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        sd_vae.reload_vae_weights(already_loaded)
        return model_data.sd_model

    if shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        # cleared first so that load_model() does not send the current model to trash
        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model

    if len(model_data.loaded_sd_models) > 0:
        # no room for one more model: take the last one in the list and let the caller load new weights into it
        recycled = model_data.loaded_sd_models.pop()
        model_data.sd_model = recycled

        sd_vae.base_vae = getattr(recycled, "base_vae", None)
        sd_vae.loaded_vae_file = getattr(recycled, "loaded_vae_file", None)
        sd_vae.checkpoint_info = recycled.sd_checkpoint_info

        print(f"Reusing loaded model {recycled.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return recycled

    return None
2023-11-19 15:50:06 +08:00
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
    """
    Switches to the checkpoint given by info (or the one picked by select_checkpoint()) and returns the resulting model.

    sd_model: model to work on; model_data.sd_model is used when it is not given.
    info: checkpoint to switch to.
    forced_reload: load the weights again even if the model already holds this checkpoint.
    A reload is also forced when check_fp8(sd_model) differs from devices.fp8.

    The model is taken from the already loaded ones when possible (reuse_model_from_already_loaded()).
    Otherwise weights are loaded into an existing model if its config matches the checkpoint's config,
    and a new model is created with load_model() if it does not or if there is no model at all.

    If loading the weights fails, the previously active checkpoint is loaded back into the model when
    there was one, and the original exception is re-raised in either case.
    """
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if check_fp8(sd_model) != devices.fp8:
            # load from state dict again to prevent extra numerical errors
            forced_reload = True
        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        # nothing to load into, or the existing model was built from a different config: create a new one
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        if current_checkpoint_info is not None:
            print("Failed to load checkpoint, restoring previous")
            load_model_weights(sd_model, current_checkpoint_info, None, timer)
        else:
            # the model came from reuse_model_from_already_loaded() while no model was active,
            # so there is nothing to restore; do not hide the original error behind a second one
            print("Failed to load checkpoint, no previous checkpoint to restore")
        raise
    finally:
        # runs on failure too, so the model is hijacked again and back on its device
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)
    sd_unet.apply_unet()

    return sd_model
2023-03-09 12:56:19 +08:00
2023-05-02 14:08:00 +08:00
2023-03-09 12:56:19 +08:00
def unload_model_weights(sd_model=None, info=None):
    """
    Passes sd_model, or shared.sd_model when sd_model is falsy, to send_model_to_cpu().

    info is accepted but not used.
    Returns the sd_model argument as it was given, which is None when the argument was omitted.
    """
    target = sd_model or shared.sd_model
    send_model_to_cpu(target)

    return sd_model
2023-05-18 01:22:38 +08:00
def apply_token_merging(sd_model, token_merging_ratio):
    """
    Brings the tomesd patch on sd_model in line with token_merging_ratio.

    The ratio that is currently applied is kept on the model as `applied_token_merged_ratio`
    (treated as 0 when absent). Nothing happens if it already equals token_merging_ratio.
    Otherwise an existing patch is removed, a new one is applied if the requested ratio is above 0,
    and the requested ratio is stored on the model.
    """
    applied_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if applied_ratio != token_merging_ratio:
        if applied_ratio > 0:
            tomesd.remove_patch(sd_model)

        if token_merging_ratio > 0:
            tomesd.apply_patch(
                sd_model,
                ratio=token_merging_ratio,
                use_rand=False,  # can cause issues with some samplers
                merge_attn=True,
                merge_crossattn=False,
                merge_mlp=False
            )

        sd_model.applied_token_merged_ratio = token_merging_ratio