2022-09-03 17:08:45 +08:00
import torch
2024-06-09 21:47:08 +08:00
from modules import prompt_parser , sd_samplers_common
2022-09-03 17:08:45 +08:00
2023-01-30 14:51:06 +08:00
from modules . shared import opts , state
2022-09-03 17:08:45 +08:00
import modules . shared as shared
2022-11-02 08:38:17 +08:00
from modules . script_callbacks import CFGDenoiserParams , cfg_denoiser_callback
2023-02-11 10:18:38 +08:00
from modules . script_callbacks import CFGDenoisedParams , cfg_denoised_callback
2023-05-14 09:49:41 +08:00
from modules . script_callbacks import AfterCFGCallbackParams , cfg_after_cfg_callback
2022-09-03 17:08:45 +08:00
2022-10-23 01:48:13 +08:00
2023-07-12 02:16:43 +08:00
def catenate_conds(conds):
    """Concatenate a list of conditionings along dim 0.

    Elements are either all tensors or all dicts of tensors. For dicts, the
    result is a plain dict whose keys follow the first element, each value being
    the dim-0 concatenation of that key across every element.
    """
    first = conds[0]
    if isinstance(first, dict):
        return {name: torch.cat([item[name] for item in conds]) for name in first}
    return torch.cat(conds)
def subscript_cond(cond, a, b):
    """Slice a conditioning to entries a:b along dim 0.

    A dict conditioning is sliced value-by-value into a new plain dict; any other
    object is sliced directly.
    """
    if isinstance(cond, dict):
        return {key: value[a:b] for key, value in cond.items()}
    return cond[a:b]
def pad_cond(tensor, repeats, empty):
    """Append `repeats` copies of the `empty` embedding to `tensor` along dim 1.

    `empty` is tiled to (tensor.shape[0], repeats * ..., ...) via Tensor.repeat
    and concatenated after `tensor` on dim 1.

    For a dict conditioning only the 'crossattn' entry is padded. The dict is
    updated in place and the same object is returned, so its concrete (possibly
    subclassed) type is preserved for the caller.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)
2022-09-03 17:08:45 +08:00
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, sampler):
        super().__init__()
        # Reset to None by update_inner_model(); presumably populated by subclasses
        # that implement `inner_model` (TODO confirm).
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None
        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # Incremented once per forward() call.
        self.step = 0
        self.image_cfg_scale = None
        # Set by forward() when cond/uncond had to be padded on this call.
        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        self.sampler = sampler
        self.p = None

        # Extra factor applied to cond_scale in forward(); the attribute name is
        # spelled this way and is kept for compatibility with existing users.
        self.cond_scale_miltiplier = 1.0

        # When set, forward() stores a copy of the unconditional model output.
        self.need_last_noise_uncond = False
        self.last_noise_uncond = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """The wrapped model called by forward(); subclasses must override this."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """Apply classifier-free guidance with weighted (AND-composed) sub-prompts.

        The last `uncond.shape[0]` rows of `x_out` are the unconditional outputs.
        `conds_list[i]` lists `(row index into x_out, weight)` pairs for image i.
        Each conditional row pulls image i away from its unconditional output by
        `weight * cond_scale`. `x_out` itself is not modified.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """CFG for the edit (InstructPix2Pix) model.

        `x_out` is split into three equal chunks along dim 0:
        (text+image cond, image cond only, unconditional). Text guidance is
        scaled by `cond_scale` and image guidance by `self.image_cfg_scale`.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return the model's x0 prediction; this base implementation returns `x_out` unchanged."""
        return x_out

    def update_inner_model(self):
        """Drop the cached model wrapper and refresh cond/uncond in the sampler's extra args from `self.p`."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def pad_cond_uncond(self, cond, uncond):
        """Pad the shorter of cond/uncond (by dim 1) with the model's empty-prompt embedding.

        Padding is added in whole multiples of the empty-prompt length, computed
        by floor division of the length difference. Sets
        `self.padded_cond_uncond` when padding was applied.

        Returns:
            tuple: (cond, uncond), one of which may have been padded.
        """
        empty = shared.sd_model.cond_stage_model_empty_prompt
        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

        if num_repeats < 0:
            cond = pad_cond(cond, -num_repeats, empty)
            self.padded_cond_uncond = True
        elif num_repeats > 0:
            uncond = pad_cond(uncond, num_repeats, empty)
            self.padded_cond_uncond = True

        return cond, uncond

    def pad_cond_uncond_v0(self, cond, uncond):
        """
        Pads or truncates the 'uncond' tensor so its dim-1 (token) length matches 'cond'.

        If 'uncond' is a dictionary, its 'crossattn' entry is adjusted in place; if it is a tensor,
        it is adjusted directly.
        If 'uncond' is shorter than 'cond', its last vector is repeated until the lengths match.
        If 'uncond' is longer than 'cond', it is truncated to the length of 'cond'.
        Sets `self.padded_cond_uncond_v0` when either adjustment happens.

        Args:
            cond (torch.Tensor or DictWithShape): The conditioning whose dim-1 length 'uncond' is matched to.
            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.

        Returns:
            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.

        Note:
            This is the padding that was always used in DDIM before version 1.6.0
        """

        is_dict_cond = isinstance(uncond, dict)
        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond

        if uncond_vec.shape[1] < cond.shape[1]:
            last_vector = uncond_vec[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
            self.padded_cond_uncond_v0 = True
        elif uncond_vec.shape[1] > cond.shape[1]:
            uncond_vec = uncond_vec[:, :cond.shape[1]]
            self.padded_cond_uncond_v0 = True

        if is_dict_cond:
            uncond['crossattn'] = uncond_vec
        else:
            uncond = uncond_vec

        return cond, uncond

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Run one guided denoising step.

        Steps:
        1. Expand the latent batch per sub-prompt and append the unconditional
           batch (two extra batches for the edit model).
        2. Run the inner model, optionally skipping the unconditional pass
           (skip-early-CFG / NGMS).
        3. Combine the outputs with CFG.
        4. Apply mask blending, store a live preview, and run the registered
           callbacks.

        Args:
            x: noisy latent batch; indexed per image as x[i].
            sigma: per-image noise levels; sigma[0] is compared against s_min_uncond.
            uncond, cond: prompt conditionings; decoded for `self.step` by prompt_parser.
            cond_scale: CFG scale, multiplied by `self.cond_scale_miltiplier`.
            s_min_uncond: NGMS threshold; 0 disables it.
            image_cond: image conditioning passed to the model as c_concat / c_adm.

        Returns:
            The guided denoised latents, after any after-CFG callbacks.

        Raises:
            sd_samplers_common.InterruptedException: if the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if sd_samplers_common.apply_refiner(self, sigma):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = apply_blend(x)

        batch_size = len(conds_list)
        repeats = [len(conds) for conds in conds_list]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        # Layout along dim 0: one copy of image i per sub-prompt, then the uncond batch
        # (plus, for the edit model, a further batch with zeroed image conditioning).
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        if shared.opts.skip_early_cond != 0. and self.step / self.total_steps <= shared.opts.skip_early_cond:
            skip_uncond = True
            self.p.extra_generation_params["Skip Early CFG"] = shared.opts.skip_early_cond
        elif (self.step % 2 or shared.opts.s_min_uncond_all) and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            self.p.extra_generation_params["NGMS"] = s_min_uncond
            if shared.opts.s_min_uncond_all:
                self.p.extra_generation_params["NGMS all steps"] = shared.opts.s_min_uncond_all

        if skip_uncond:
            # Drop the trailing uncond batch. image_cond_in must be trimmed too, otherwise the
            # batched model call receives more image-conditioning rows than latent rows.
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]
            image_cond_in = image_cond_in[:-batch_size]

        self.padded_cond_uncond = False
        self.padded_cond_uncond_v0 = False
        if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
        elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            tensor, uncond = self.pad_cond_uncond(tensor, uncond)

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            if shared.opts.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # cond and uncond have different token lengths, so they cannot share one model call.
            # skip_uncond is always False in this branch.
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size * 2 if shared.opts.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])

                # Rows a:b of x_in are conditional rows for both regular and edit models.
                c_crossattn = subscript_cond(tensor, a, b)
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

            # Trailing rows are the unconditional passes: one batch normally, two for the edit
            # model (uncond text with the image conditioning, then with zeroed image
            # conditioning), mirroring cond_in = [tensor, uncond, uncond] above.
            if is_edit_model:
                uncond_in = catenate_conds([uncond, uncond])
                uncond_rows = 2 * uncond.shape[0]
            else:
                uncond_in = uncond
                uncond_rows = uncond.shape[0]
            x_out[-uncond_rows:] = self.inner_model(x_in[-uncond_rows:], sigma_in[-uncond_rows:], cond=make_condition_dict(uncond_in, image_cond_in[-uncond_rows:]))

        # Row of x_out holding the first sub-prompt's output for each image.
        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i + 1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        if self.need_last_noise_uncond:
            self.last_noise_uncond = torch.clone(x_out[-uncond.shape[0]:])

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale * self.cond_scale_miltiplier)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale * self.cond_scale_miltiplier)

        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            denoised = apply_blend(denoised)

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([denoised[i:i + 1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised