2023-08-13 13:24:16 +08:00
from __future__ import annotations
2022-09-03 17:08:45 +08:00
import json
2023-05-29 13:54:13 +08:00
import logging
2022-09-03 17:08:45 +08:00
import math
import os
import sys
2023-04-07 00:42:26 +08:00
import hashlib
2023-08-13 13:24:16 +08:00
from dataclasses import dataclass , field
2022-09-03 17:08:45 +08:00
import torch
import numpy as np
2023-06-10 03:47:27 +08:00
from PIL import Image , ImageOps
2022-09-03 17:08:45 +08:00
import random
2022-09-13 17:51:57 +08:00
import cv2
from skimage import exposure
2023-08-13 13:24:16 +08:00
from typing import Any
2022-09-03 17:08:45 +08:00
2022-09-05 08:25:37 +08:00
import modules . sd_hijack
2024-01-01 22:25:30 +08:00
from modules import devices , prompt_parser , masking , sd_samplers , lowvram , infotext_utils , extra_networks , sd_vae_approx , scripts , sd_samplers_common , sd_unet , errors , rng
2023-08-10 02:24:16 +08:00
from modules . rng import slerp # noqa: F401
2022-09-03 17:08:45 +08:00
from modules . sd_hijack import model_hijack
2023-08-04 18:23:14 +08:00
from modules . sd_samplers_common import images_tensor_to_samples , decode_first_stage , approximation_indexes
2022-09-03 17:08:45 +08:00
from modules . shared import opts , cmd_opts , state
import modules . shared as shared
2023-01-26 00:15:42 +08:00
import modules . paths as paths
2022-09-07 17:32:28 +08:00
import modules . face_restoration
2022-09-03 17:08:45 +08:00
import modules . images as images
2022-09-10 04:16:02 +08:00
import modules . styles
2022-11-29 11:11:29 +08:00
import modules . sd_models as sd_models
import modules . sd_vae as sd_vae
2022-12-09 08:14:35 +08:00
from ldm . data . util import AddMiDaS
from ldm . models . diffusion . ddpm import LatentDepth2ImageDiffusion
2022-09-03 17:08:45 +08:00
2022-12-09 08:14:35 +08:00
from einops import repeat , rearrange
2022-12-12 07:03:36 +08:00
from blendmodes . blend import blendLayers , BlendType
2023-04-10 16:37:15 +08:00
2022-09-03 17:08:45 +08:00
# Fixed model constants. They are intentionally not exposed as user options,
# because changing them would break the model.
# NOTE(review): presumably opt_C is the number of latent channels and opt_f the
# image-to-latent downsampling factor — confirm against their uses elsewhere.
opt_C = 4
opt_f = 8
2022-09-13 17:51:57 +08:00
def setup_color_correction(image):
    """Capture the colour reference used by ``apply_color_correction``.

    Args:
        image: RGB image (anything ``np.asarray`` accepts after ``.copy()``).

    Returns:
        The image converted to the LAB colour space, as a numpy array.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
2022-12-12 07:03:36 +08:00
def apply_color_correction(correction, original_image):
    """Match the colours of *original_image* to a previously captured reference.

    Args:
        correction: LAB array as returned by ``setup_color_correction``.
        original_image: RGB PIL image to correct.

    Returns:
        An RGB PIL image. Histograms are matched per channel in LAB space, and the
        matched image is then combined with *original_image* via
        ``blendLayers(..., BlendType.LUMINOSITY)``.
    """
    logging.info("Applying color correction.")

    source_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")
    corrected = Image.fromarray(matched_rgb)

    blended = blendLayers(corrected, original_image, BlendType.LUMINOSITY)
    return blended.convert('RGB')
2022-09-13 17:51:57 +08:00
2022-10-24 14:15:26 +08:00
2023-12-03 12:07:02 +08:00
def uncrop(image, dest_size, paste_loc):
    """Place *image* back into a transparent canvas of size *dest_size*.

    Args:
        image: image to paste.
        dest_size: ``(width, height)`` of the new RGBA canvas.
        paste_loc: ``(x, y, w, h)``. The image is resized to ``w`` x ``h`` with
            ``images.resize_image`` (mode 1) and pasted at ``(x, y)``.

    Returns:
        The new RGBA canvas.
    """
    left, top, box_width, box_height = paste_loc
    canvas = Image.new('RGBA', dest_size)
    fitted = images.resize_image(1, image, box_width, box_height)
    canvas.paste(fitted, (left, top))
    return canvas
2022-10-24 14:15:26 +08:00
2023-12-03 12:07:02 +08:00
2023-12-07 12:16:27 +08:00
def apply_overlay(image, paste_loc, overlay):
    """Alpha-composite *overlay* over *image* and return the result as RGB.

    When *overlay* is None, *image* is returned untouched. When *paste_loc* is
    given, *image* is first uncropped into a canvas the size of *overlay*.
    """
    if overlay is None:
        return image

    if paste_loc is not None:
        canvas_size = (overlay.width, overlay.height)
        image = uncrop(image, canvas_size, paste_loc)

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
2022-09-13 17:51:57 +08:00
2023-12-04 16:57:21 +08:00
def create_binary_mask(image, round=True):
    """Return a single-channel ("L") mask derived from *image*.

    If *image* is RGBA and its alpha channel is not uniformly 255, the alpha
    channel becomes the mask. With ``round=True`` it is thresholded so that
    values above 128 become 255 and everything else becomes 0. Any other image
    is simply converted to greyscale.
    """
    uses_alpha = image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255)
    if not uses_alpha:
        return image.convert('L')

    alpha = image.split()[-1].convert("L")
    if not round:
        return alpha
    return alpha.point(lambda value: 255 if value > 128 else 0)
2022-10-09 08:13:13 +08:00
2023-01-04 22:58:07 +08:00
def _txt2img_inpainting_conditioning(x, width, height):
    """Build inpainting conditioning for a txt2img pass where the whole image is masked."""
    # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
    image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
    image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

    # Add the fake full 1s mask: pad (0, 0, 0, 0, 1, 0) prepends one entry along the
    # third-from-last dimension (dim 1 if the latent is 4-D (N, C, H, W) — TODO confirm).
    image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
    return image_conditioning.to(x.dtype)


def txt2img_image_conditioning(sd_model, x, width, height):
    """Return the image conditioning passed to the model during txt2img.

    Args:
        sd_model: the loaded model.
        x: latent batch. Only its batch size, device and dtype are used.
        width, height: image size used for the fake masked image.

    Returns:
        - inpainting models (``conditioning_key`` 'hybrid'/'concat', or a UNet
          whose first input conv has 9 input channels): a fully-masked
          conditioning built by ``_txt2img_inpainting_conditioning``;
        - unCLIP models ('crossattn-adm'): zeros of shape
          ``(batch, 2 * noise_augmentor.time_embed.dim)``;
        - anything else: a dummy ``(batch, 5, 1, 1)`` zero tensor.
    """
    conditioning_key = sd_model.model.conditioning_key

    if conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
        return _txt2img_inpainting_conditioning(x, width, height)

    if conditioning_key == "crossattn-adm":  # UnCLIP models
        return x.new_zeros(x.shape[0], 2 * sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    # Inpainting checkpoints that do not declare an inpainting conditioning_key are
    # recognised by a 9-channel UNet input (presumably 4 latent + 1 mask + 4 masked-image channels).
    sd = sd_model.model.state_dict()
    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
    if diffusion_model_input is not None and diffusion_model_input.shape[1] == 9:
        return _txt2img_inpainting_conditioning(x, width, height)

    # Dummy zero conditioning if we're not using inpainting or unclip models.
    # Still takes up a bit of memory, but no encoder call.
    # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
    return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
2023-01-04 22:58:07 +08:00
2023-08-13 13:24:16 +08:00
@dataclass(repr=False)
class StableDiffusionProcessing:
    """Settings and run-time state shared by all generation jobs.

    Fields accepted by ``__init__`` describe the request (prompts, seeds,
    sampler, image size, ...). Fields declared with ``field(..., init=False)``
    are not constructor arguments; they are filled in while the job runs.
    ``init`` is a no-op and ``sample`` raises NotImplementedError here, so a
    subclass has to provide the actual sampling.
    """

    # Accepted for compatibility only. The ``sd_model`` property defined below
    # always returns ``shared.sd_model``, and its setter discards the value.
    sd_model: object = None
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    seed_resize_from_h: int = -1
    seed_resize_from_w: int = -1
    seed_enable_extras: bool = True
    sampler_name: str = None
    batch_size: int = 1
    n_iter: int = 1
    steps: int = 50
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    restore_faces: bool = None
    tiling: bool = None
    do_not_save_samples: bool = False
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None
    overlay_images: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = None
    ddim_discretize: str = None
    s_min_uncond: float = None
    s_churn: float = None
    s_tmax: float = None
    s_tmin: float = None
    s_noise: float = None
    override_settings: dict[str, Any] = None
    override_settings_restore_afterwards: bool = True
    sampler_index: int = None
    refiner_checkpoint: str = None
    refiner_switch_at: float = None
    # No annotation: plain class attributes, not ``__init__`` parameters.
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False

    # Backing storage for the ``scripts`` / ``script_args`` properties below.
    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
    scripts_setup_complete: bool = field(default=False, init=False)

    # Class-level (unannotated, so not dataclass fields) conditioning caches, each
    # ``[cache key, cached result]``. They are shared by instances until ``close``
    # replaces them.
    cached_uc = [None, None]
    cached_c = [None, None]

    comments: dict = None
    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
    is_using_inpainting_conditioning: bool = field(default=False, init=False)
    paste_to: tuple | None = field(default=None, init=False)
    is_hr_pass: bool = field(default=False, init=False)
    c: tuple = field(default=None, init=False)
    uc: tuple = field(default=None, init=False)
    rng: rng.ImageRNG | None = field(default=None, init=False)
    step_multiplier: int = field(default=1, init=False)
    color_corrections: list = field(default=None, init=False)
    all_prompts: list = field(default=None, init=False)
    all_negative_prompts: list = field(default=None, init=False)
    all_seeds: list = field(default=None, init=False)
    all_subseeds: list = field(default=None, init=False)
    iteration: int = field(default=0, init=False)
    main_prompt: str = field(default=None, init=False)
    main_negative_prompt: str = field(default=None, init=False)
    prompts: list = field(default=None, init=False)
    negative_prompts: list = field(default=None, init=False)
    seeds: list = field(default=None, init=False)
    subseeds: list = field(default=None, init=False)
    extra_network_data: dict = field(default=None, init=False)
    user: str = field(default=None, init=False)
    sd_model_name: str = field(default=None, init=False)
    sd_model_hash: str = field(default=None, init=False)
    sd_vae_name: str = field(default=None, init=False)
    sd_vae_hash: str = field(default=None, init=False)

    is_api: bool = field(default=False, init=False)

    def __post_init__(self):
        """Normalise optional arguments and fill unset sampler values from ``opts``."""
        if self.sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        # NOTE(review): this discards any ``comments`` passed to the constructor.
        self.comments = {}
        if self.styles is None:
            self.styles = []

        self.sampler_noise_scheduler_override = None
        # Sampler parameters left as None fall back to the global options.
        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
        # A falsy s_tmax (e.g. 0) is treated as "no upper bound".
        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
        self.extra_generation_params = self.extra_generation_params or {}
        self.override_settings = self.override_settings or {}
        # Goes through the ``script_args`` property setter.
        self.script_args = self.script_args or {}

        self.refiner_checkpoint_info = None

        if not self.seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # Bind the caches that are current now; ``close`` may later rebind the
        # class attributes to fresh lists.
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c

    @property
    def sd_model(self):
        """The globally loaded model (``shared.sd_model``)."""
        return shared.sd_model

    @sd_model.setter
    def sd_model(self, value):
        # Intentionally ignored; see the ``sd_model`` field comment.
        pass

    @property
    def scripts(self):
        return self.scripts_value

    @scripts.setter
    def scripts(self, value):
        # Scripts are set up once both the runner and its arguments are available.
        self.scripts_value = value
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    @property
    def script_args(self):
        return self.script_args_value

    @script_args.setter
    def script_args(self, value):
        # Mirrors the ``scripts`` setter: whichever of the two is assigned last triggers setup.
        self.script_args_value = value
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    def setup_scripts(self):
        """Mark scripts as set up and call the runner's setup hook."""
        self.scripts_setup_complete = True

        # NOTE(review): "setup_scrips" must match the ScriptRunner method name — confirm before renaming.
        self.scripts.setup_scrips(self, is_ui=not self.is_api)

    def comment(self, text):
        """Record *text* as a comment. Duplicates are stored once (dict keys)."""
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Module-level ``txt2img_image_conditioning`` using this job's model and size by default."""
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        """Return a depth map for *source_image*, rescaled to [-1, 1] at latent resolution.

        Only ``source_image[0]`` is passed to the depth model, and its input is
        repeated ``batch_size`` times. Assumes *source_image* is (N, C, H, W) in
        [-1, 1], since it is shifted by ``* 0.5 + 0.5`` before encoding — TODO confirm.
        """
        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        conditioning_image = images_tensor_to_samples(source_image * 0.5 + 0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        # Min-max normalise the depth map into [-1, 1].
        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        """Encode *source_image* with the first stage and return the distribution's ``mode()``."""
        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        """Embed *source_image* for unCLIP models.

        When the model has a noise augmentor, the embedding is passed through it
        at noise level 0, and the noise-level embedding is concatenated along dim 1.
        """
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0  # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        """Build the ``c_concat`` conditioning (mask + masked-image latent) for inpainting models.

        Args:
            source_image: image tensor to be masked and encoded.
            latent_image: latent whose spatial size the mask is resized to.
            image_mask: a tensor used as-is, or an image converted to "L" and
                scaled to [0, 1]. When None, an all-ones mask is used.
            round_image_mask: round a provided mask to exactly 0.0 / 1.0.

        Returns:
            ``cat([mask, encoded masked image], dim=1)`` on ``shared.device`` in the model dtype.
        """
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

            if round_image_mask:
                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)

        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        """Dispatch to the image conditioning matching the loaded model type.

        Checks, in order: depth2img, edit (instruct-pix2pix style ``cond_stage_key``),
        inpainting by conditioning key, unCLIP, inpainting detected by a 9-channel
        UNet input. Otherwise a dummy ``(batch, 5, 1, 1)`` zero tensor is returned.
        ``round_image_mask`` is forwarded to both inpainting paths.
        """
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        sd = self.sampler.model_wrap.inner_model.model.state_dict()
        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
        if diffusion_model_input is not None:
            if diffusion_model_input.shape[1] == 9:
                # Forward round_image_mask here too; previously it was dropped and the mask was always rounded.
                return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        raise NotImplementedError()

    def close(self):
        """Drop per-run references. Reset the shared cond caches unless ``opts.persistent_cond_cache`` is set."""
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        """Return the first truthy ratio, job values before ``opts``. For hires, hr-specific values are tried first."""
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Expand ``prompt`` / ``negative_prompt`` into equally long per-image lists with styles applied.

        Raises:
            RuntimeError: when the two resulting lists differ in length.
        """
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        elif isinstance(self.negative_prompt, list):
            self.all_prompts = [self.prompt] * len(self.negative_prompt)
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)

        if len(self.all_prompts) != len(self.all_negative_prompts):
            raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

        self.main_prompt = self.all_prompts[0]
        self.main_negative_prompt = self.all_negative_prompts[0]

    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
            hires_steps,
            use_old_scheduling,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
            opts.fp8_storage,
            opts.cache_fp16_weight,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
        """
        Returns the result of calling
        function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments (see ``cached_params``), or None
        if no arguments have been used before. The second element is where the
        previously computed result is stored.

        caches is a list with items described above. On a miss the result is
        stored in ``caches[0]``.
        """

        if shared.opts.use_old_scheduling:
            # Flag in the infotext when the old prompt-editing timeline differs from the new one.
            old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)
            new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)
            if old_schedules != new_schedules:
                self.extra_generation_params["Old prompt editing timelines"] = True

        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch from cache) ``self.uc`` and ``self.c`` for the current prompts."""
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.firstpass_steps = total_steps

        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        """Return ``(c, uc)``."""
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        """Strip extra-network syntax from ``self.prompts`` and keep the parsed data."""
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
2022-09-03 17:08:45 +08:00
class Processed :
2023-01-01 04:40:55 +08:00
def __init__ ( self , p : StableDiffusionProcessing , images_list , seed = - 1 , info = " " , subseed = None , all_prompts = None , all_negative_prompts = None , all_seeds = None , all_subseeds = None , index_of_first_image = 0 , infotexts = None , comments = " " ) :
2022-09-03 17:08:45 +08:00
self . images = images_list
self . prompt = p . prompt
2022-09-13 00:57:31 +08:00
self . negative_prompt = p . negative_prompt
2022-09-03 17:08:45 +08:00
self . seed = seed
2022-09-17 03:20:56 +08:00
self . subseed = subseed
self . subseed_strength = p . subseed_strength
2022-09-03 17:08:45 +08:00
self . info = info
2023-08-13 20:07:37 +08:00
self . comments = " " . join ( f " { comment } \n " for comment in p . comments )
2022-09-03 17:08:45 +08:00
self . width = p . width
self . height = p . height
2022-11-19 17:01:51 +08:00
self . sampler_name = p . sampler_name
2022-09-03 17:08:45 +08:00
self . cfg_scale = p . cfg_scale
2023-02-04 08:15:32 +08:00
self . image_cfg_scale = getattr ( p , ' image_cfg_scale ' , None )
2022-09-03 17:08:45 +08:00
self . steps = p . steps
2022-09-19 14:02:10 +08:00
self . batch_size = p . batch_size
self . restore_faces = p . restore_faces
self . face_restoration_model = opts . face_restoration_model if p . restore_faces else None
2023-08-13 11:07:30 +08:00
self . sd_model_name = p . sd_model_name
self . sd_model_hash = p . sd_model_hash
self . sd_vae_name = p . sd_vae_name
self . sd_vae_hash = p . sd_vae_hash
2022-09-19 14:02:10 +08:00
self . seed_resize_from_w = p . seed_resize_from_w
self . seed_resize_from_h = p . seed_resize_from_h
self . denoising_strength = getattr ( p , ' denoising_strength ' , None )
self . extra_generation_params = p . extra_generation_params
self . index_of_first_image = index_of_first_image
2022-10-05 01:13:09 +08:00
self . styles = p . styles
2022-10-05 01:17:15 +08:00
self . job_timestamp = state . job_timestamp
2022-10-09 05:28:42 +08:00
self . clip_skip = opts . CLIP_stop_at_last_layers
2023-05-18 01:22:38 +08:00
self . token_merging_ratio = p . token_merging_ratio
self . token_merging_ratio_hr = p . token_merging_ratio_hr
2022-09-19 14:02:10 +08:00
2022-09-28 10:11:03 +08:00
self . eta = p . eta
2022-09-26 22:40:47 +08:00
self . ddim_discretize = p . ddim_discretize
self . s_churn = p . s_churn
self . s_tmin = p . s_tmin
self . s_tmax = p . s_tmax
self . s_noise = p . s_noise
2023-05-17 15:26:32 +08:00
self . s_min_uncond = p . s_min_uncond
2022-09-30 08:44:38 +08:00
self . sampler_noise_scheduler_override = p . sampler_noise_scheduler_override
2023-08-19 09:56:15 +08:00
self . prompt = self . prompt if not isinstance ( self . prompt , list ) else self . prompt [ 0 ]
self . negative_prompt = self . negative_prompt if not isinstance ( self . negative_prompt , list ) else self . negative_prompt [ 0 ]
self . seed = int ( self . seed if not isinstance ( self . seed , list ) else self . seed [ 0 ] ) if self . seed is not None else - 1
self . subseed = int ( self . subseed if not isinstance ( self . subseed , list ) else self . subseed [ 0 ] ) if self . subseed is not None else - 1
2022-11-19 17:47:52 +08:00
self . is_using_inpainting_conditioning = p . is_using_inpainting_conditioning
2022-09-19 14:02:10 +08:00
2022-11-19 18:23:25 +08:00
self . all_prompts = all_prompts or p . all_prompts or [ self . prompt ]
self . all_negative_prompts = all_negative_prompts or p . all_negative_prompts or [ self . negative_prompt ]
self . all_seeds = all_seeds or p . all_seeds or [ self . seed ]
self . all_subseeds = all_subseeds or p . all_subseeds or [ self . subseed ]
2022-09-28 22:05:23 +08:00
self . infotexts = infotexts or [ info ]
2023-09-08 05:19:52 +08:00
self . version = program_version ( )
2022-09-03 17:08:45 +08:00
def js ( self ) :
obj = {
2022-11-19 18:23:25 +08:00
" prompt " : self . all_prompts [ 0 ] ,
2022-09-19 14:02:10 +08:00
" all_prompts " : self . all_prompts ,
2022-11-19 18:23:25 +08:00
" negative_prompt " : self . all_negative_prompts [ 0 ] ,
" all_negative_prompts " : self . all_negative_prompts ,
2022-09-19 14:02:10 +08:00
" seed " : self . seed ,
" all_seeds " : self . all_seeds ,
" subseed " : self . subseed ,
" all_subseeds " : self . all_subseeds ,
2022-09-17 03:20:56 +08:00
" subseed_strength " : self . subseed_strength ,
2022-09-03 17:08:45 +08:00
" width " : self . width ,
" height " : self . height ,
2022-11-19 17:01:51 +08:00
" sampler_name " : self . sampler_name ,
2022-09-03 17:08:45 +08:00
" cfg_scale " : self . cfg_scale ,
" steps " : self . steps ,
2022-09-19 14:02:10 +08:00
" batch_size " : self . batch_size ,
" restore_faces " : self . restore_faces ,
" face_restoration_model " : self . face_restoration_model ,
2023-08-13 11:07:30 +08:00
" sd_model_name " : self . sd_model_name ,
2022-09-19 14:02:10 +08:00
" sd_model_hash " : self . sd_model_hash ,
2023-08-13 11:07:30 +08:00
" sd_vae_name " : self . sd_vae_name ,
" sd_vae_hash " : self . sd_vae_hash ,
2022-09-19 14:02:10 +08:00
" seed_resize_from_w " : self . seed_resize_from_w ,
" seed_resize_from_h " : self . seed_resize_from_h ,
" denoising_strength " : self . denoising_strength ,
" extra_generation_params " : self . extra_generation_params ,
" index_of_first_image " : self . index_of_first_image ,
2022-09-28 22:05:23 +08:00
" infotexts " : self . infotexts ,
2022-10-05 01:13:09 +08:00
" styles " : self . styles ,
2022-10-05 01:17:15 +08:00
" job_timestamp " : self . job_timestamp ,
2022-10-09 03:21:15 +08:00
" clip_skip " : self . clip_skip ,
2022-11-19 17:47:52 +08:00
" is_using_inpainting_conditioning " : self . is_using_inpainting_conditioning ,
2023-09-08 05:19:52 +08:00
" version " : self . version ,
2022-09-03 17:08:45 +08:00
}
return json . dumps ( obj )
2022-12-14 05:05:40 +08:00
def infotext(self, p: StableDiffusionProcessing, index):
    """Return the infotext for image number ``index``, counted across all batches.

    ``index`` is split into the batch number (``iteration``) and the position inside
    that batch using ``self.batch_size``; prompts and seeds are taken from this object.
    """
    iteration, position_in_batch = divmod(index, self.batch_size)
    return create_infotext(
        p,
        self.all_prompts,
        self.all_seeds,
        self.all_subseeds,
        comments=[],
        position_in_batch=position_in_batch,
        iteration=iteration,
    )
2023-05-18 01:22:38 +08:00
def get_token_merging_ratio(self, for_hr=False):
    """Return the hires-pass token merging ratio when ``for_hr`` is truthy, else the first-pass ratio."""
    if for_hr:
        return self.token_merging_ratio_hr
    return self.token_merging_ratio
2022-09-19 14:02:10 +08:00
2022-09-14 02:49:58 +08:00
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Return the first noise tensor produced by an ``rng.ImageRNG`` built from these arguments.

    ``p`` is accepted but not used by this function.
    """
    return rng.ImageRNG(
        shape,
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    ).next()
2022-09-03 17:08:45 +08:00
2023-07-30 20:30:33 +08:00
class DecodedSamples(list):
    """A plain ``list`` of samples that have already been decoded by the VAE.

    Callers can detect this via ``getattr(obj, 'already_decoded', False)`` and skip
    a second decode.
    """

    # Class-level marker checked by callers; always True for this type.
    already_decoded = True
2023-07-20 01:23:30 +08:00
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    """Decode a batch of latents one sample at a time and return them as a DecodedSamples list.

    Each ``batch[i:i + 1]`` slice is decoded with ``decode_first_stage`` and the first
    element of the result is kept; the first dimension of ``batch`` indexes samples.

    Args:
        model: model whose ``first_stage_model`` does the decoding.
        batch: latent tensor; ``batch.shape[0]`` is the number of samples.
        target_device: if not None, each decoded sample is moved there with ``.to()``.
        check_for_nans: if true, each decoded sample is passed to ``devices.test_for_nans``.
            On NaNs, if ``auto_vae_precision_bfloat16`` or ``auto_vae_precision`` is enabled
            and the VAE is not already in that dtype, the global ``devices.dtype_vae``, the
            VAE model and the remaining batch are converted and the sample is decoded again.

    Raises:
        devices.NansException: when NaNs are found and no automatic fallback applies.
    """
    samples = DecodedSamples()

    for i in range(batch.shape[0]):
        sample = decode_first_stage(model, batch[i:i + 1])[0]

        if check_for_nans:

            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException as e:
                # Choose the fallback dtype; bfloat16 takes precedence over float32 when both options are on.
                if shared.opts.auto_vae_precision_bfloat16:
                    autofix_dtype = torch.bfloat16
                    autofix_dtype_text = "bfloat16"
                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
                    autofix_dtype_comment = ""
                elif shared.opts.auto_vae_precision:
                    autofix_dtype = torch.float32
                    autofix_dtype_text = "32-bit float"
                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
                else:
                    raise e

                # Already running in the fallback dtype: converting again cannot help.
                if devices.dtype_vae == autofix_dtype:
                    raise e

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
                )

                # Global switch: later decodes (including the rest of this batch) use the new dtype.
                devices.dtype_vae = autofix_dtype
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                # NOTE(review): the re-decoded sample is not checked for NaNs again.
                sample = decode_first_stage(model, batch[i:i + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        samples.append(sample)

    return samples
2022-10-04 22:36:39 +08:00
def get_fixed_seed(seed):
    """Resolve a user-supplied seed to a concrete value.

    ``None``, ``''``, ``-1`` and strings that do not parse as an integer all mean
    "random": a fresh value from ``random.randrange(4294967294)`` is returned.
    Numeric strings are converted with ``int()``. Any other value (for example an
    int or a list of seeds) is returned unchanged unless it equals ``-1``.

    Args:
        seed: the seed as received from the UI/API.

    Returns:
        The resolved seed.
    """
    if seed == '' or seed is None:
        seed = -1
    elif isinstance(seed, str):
        try:
            seed = int(seed)
        except ValueError:
            # int() on a str signals an unparsable value only via ValueError;
            # treat it as a request for a random seed.
            seed = -1

    if seed == -1:
        return int(random.randrange(4294967294))

    return seed
2022-09-09 22:54:04 +08:00
def fix_seed(p):
    """Replace ``p.seed`` and then ``p.subseed`` in place with the values returned by get_fixed_seed."""
    for attr_name in ("seed", "subseed"):
        setattr(p, attr_name, get_fixed_seed(getattr(p, attr_name)))
2022-09-07 06:44:44 +08:00
2023-05-08 20:23:49 +08:00
def program_version():
    """Return the git tag reported by ``launch.git_tag()``, or None when it reports ``"<none>"``."""
    import launch

    tag = launch.git_tag()
    return None if tag == "<none>" else tag
2023-07-26 11:36:06 +08:00
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """Build the infotext ("parameters" text) for one generated image.

    Args:
        p: processing object; most generation settings are read from it.
        all_prompts, all_seeds, all_subseeds: per-image lists indexed by ``index``.
        comments: accepted but not used by this function.
        iteration, position_in_batch: used to compute ``index`` when it is not given,
            as ``position_in_batch + iteration * p.batch_size``.
        use_main_prompt: if true, use ``p.main_prompt`` / ``p.main_negative_prompt`` and
            ``p.all_seeds[0]`` / ``p.all_subseeds[0]`` instead of the per-image values.
        index: explicit index into the per-image lists.
        all_negative_prompts: per-image negative prompts; defaults to ``p.all_negative_prompts``.

    Returns:
        ``"<prompt>\\nNegative prompt: <negative>\\n<key: value, ...>"`` stripped of
        surrounding whitespace; the negative-prompt line is omitted when that prompt is empty.
    """
    if index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    # ENSD is only reported when it is non-zero and the sampler actually uses it.
    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    # Entries whose value is None are left out of the text below.
    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
        "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        "Tiling": "True" if p.tiling else None,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # An entry whose value equals its key is written as the bare key (a flag).
    generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
    # Decide whether to emit the negative-prompt line from the text that is actually
    # printed, so the main negative prompt is tested when use_main_prompt is set.
    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
2022-09-19 14:02:10 +08:00
2022-09-03 17:08:45 +08:00
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run the generation job ``p`` with its per-job option overrides applied.

    Calls the ``before_process`` script hooks, applies ``p.override_settings`` to the
    global ``opts`` (reloading model / VAE weights when those options are overridden),
    enables token merging and delegates to ``process_images_inner``. Token merging is
    always reset afterwards; when ``p.override_settings_restore_afterwards`` is set,
    the captured original option values are written back.

    Returns:
        The ``Processed`` result from ``process_images_inner``.
    """
    if p.scripts is not None:
        p.scripts.before_process(p)

    # Snapshot of the current values of the options about to be overridden.
    # NOTE(review): the trailing `if k in opts.data` filter makes the
    # `else opts.get_default(k)` branch unreachable; only keys already present in
    # opts.data are captured (and therefore restored).
    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}

    try:
        # If there is no checkpoint override, or the override checkpoint can't be found, remove the override entry and load the opts checkpoint.
        # This also swaps back to the main model if the refiner model was left loaded; a valid model override is loaded again further below.
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            opts.set(k, v, is_api=True, run_callbacks=False)

            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        # Always disable token merging, even if generation raised.
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                # Only the VAE is reloaded eagerly here; a restored checkpoint option is not reloaded in this function.
                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Calls ``p.init`` once inside the no-grad / EMA / autocast scopes. Then, for each of
    ``p.n_iter`` batches, it calls ``p.sample``, decodes the latents, post-processes,
    optionally saves every image and collects it with its infotext. Finally it may build
    an image grid and returns a ``Processed`` holding all images and infotexts.
    Hooks on ``p.scripts`` are called at each stage when ``p.scripts`` is not None.
    """

    if isinstance(p.prompt, list):
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    # Resolve blank / -1 seeds to concrete random values (lists pass through unchanged).
    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    # None means "use the global setting".
    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    if p.refiner_checkpoint not in (None, "", "None", "none"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    # Record the currently loaded model / VAE (read later when building infotexts).
    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    p.setup_prompts()

    # One seed per prompt: consecutive seeds normally; when subseed_strength != 0 every
    # image keeps the same base seed and only the subseeds differ.
    if isinstance(seed, list):
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if isinstance(subseed, list):
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            # "Skip" only affects the job in progress; reset the flag for this batch.
            if state.skipped:
                state.skipped = False

            if state.interrupted or state.stopping_generation:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            # Slice the per-image lists down to the current batch.
            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            # Noise generator for this batch; opt_C / opt_f are module-level constants
            # (presumably latent channels and downsampling factor -- defined outside this view).
            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # No prompts left for this batch (e.g. the slice above ran past the end of all_prompts).
            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            p.setup_conds()

            for comment in model_hijack.comments:
                p.comment(comment)

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            # Rescales a cumulative-alpha schedule so its last value is (almost) zero while the
            # first value is unchanged. NOTE(review): redefined on every batch; could be hoisted.
            def rescale_zero_terminal_snr_abar(alphas_cumprod):
                alphas_bar_sqrt = alphas_cumprod.sqrt()

                # Store old values.
                alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
                alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

                # Shift so the last timestep is zero.
                alphas_bar_sqrt -= (alphas_bar_sqrt_T)

                # Scale so the first timestep is back to the old value.
                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

                # Convert alphas_bar_sqrt to betas
                alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
                # Small non-zero final value instead of exactly 0.
                alphas_bar[-1] = 4.8973451890853435e-08
                return alphas_bar

            # Start every batch from the original schedule, then apply the optional adjustments.
            if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

                if opts.use_downcasted_alpha_bar:
                    p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
                    p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
                if opts.sd_noise_schedule == "Zero Terminal SNR":
                    p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                    p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            # A sampler/script may return images that were already decoded (e.g. DecodedSamples).
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # Map decoder output (presumably in [-1, 1]) to [0, 1].
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            state.nextjob()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                # Re-slice the batch prompts from the full lists before the batch-list hook.
                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            # Infotext for image `index` of the current batch (reads p.prompts / p.seeds at call time).
            # NOTE(review): this closure is also used by the grid code after the loop; if the loop
            # exits before reaching this line (n_iter == 0, interrupt, or empty prompts on the first
            # batch), the grid branch would raise NameError when a grid is still requested.
            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # Move axis 0 to the end (channel-first -> channel-last, assuming CHW) and scale to 0..255.
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)
                overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                # If the intention is to show the output from the model
                # that is being composited over the original image,
                # we need to keep the original image around
                # and use it in the composite step.
                original_denoised_image = image.copy()

                if p.paste_to is not None:
                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

                image = apply_overlay(image, p.paste_to, overlay_image)

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

                # Optional extra outputs for inpainting: the mask itself and the denoised image cut out by the mask.
                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
                            output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

        # Ensure Processed always gets at least one infotext, even when no image was produced.
        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                # The grid occupies slot 0, so real images start at index 1.
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
2022-09-03 17:08:45 +08:00
2023-01-09 19:57:47 +08:00
def old_hires_fix_first_pass_dimensions(width, height):
    """Legacy hires-fix first-pass size for a requested final ``width`` x ``height``.

    Both sides are scaled by the same factor so the area becomes 512*512 pixels,
    then each side is rounded up to a multiple of 64.

    Returns:
        ``(width, height)`` of the first pass.
    """
    scale = math.sqrt((512 * 512) / (width * height))

    def round_up_to_64(side):
        return math.ceil(scale * side / 64) * 64

    return round_up_to_64(width), round_up_to_64(height)
2023-08-13 13:24:16 +08:00
@dataclass ( repr = False )
2022-09-03 17:08:45 +08:00
class StableDiffusionProcessingTxt2Img ( StableDiffusionProcessing ) :
2023-08-13 13:24:16 +08:00
enable_hr : bool = False
denoising_strength : float = 0.75
firstphase_width : int = 0
firstphase_height : int = 0
hr_scale : float = 2.0
hr_upscaler : str = None
hr_second_pass_steps : int = 0
hr_resize_x : int = 0
hr_resize_y : int = 0
hr_checkpoint_name : str = None
hr_sampler_name : str = None
hr_prompt : str = ' '
hr_negative_prompt : str = ' '
2023-12-15 16:57:17 +08:00
force_task_id : str = None
2023-08-13 13:24:16 +08:00
2023-06-08 12:53:02 +08:00
cached_hr_uc = [ None , None ]
cached_hr_c = [ None , None ]
2022-09-19 21:42:56 +08:00
2023-08-13 13:24:16 +08:00
hr_checkpoint_info : dict = field ( default = None , init = False )
hr_upscale_to_x : int = field ( default = 0 , init = False )
hr_upscale_to_y : int = field ( default = 0 , init = False )
truncate_x : int = field ( default = 0 , init = False )
truncate_y : int = field ( default = 0 , init = False )
applied_old_hires_behavior_to : tuple = field ( default = None , init = False )
latent_scale_mode : dict = field ( default = None , init = False )
hr_c : tuple | None = field ( default = None , init = False )
hr_uc : tuple | None = field ( default = None , init = False )
all_hr_prompts : list = field ( default = None , init = False )
all_hr_negative_prompts : list = field ( default = None , init = False )
hr_prompts : list = field ( default = None , init = False )
hr_negative_prompts : list = field ( default = None , init = False )
hr_extra_network_data : list = field ( default = None , init = False )
def __post_init__(self):
    """Finish dataclass initialisation of a txt2img job.

    Runs the parent ``__post_init__``. When ``firstphase_width`` or ``firstphase_height``
    is non-zero, the requested width/height become the hires upscale target and the
    first pass uses the ``firstphase_*`` size. Finally, it binds the instance to the
    class-level ``cached_hr_uc`` / ``cached_hr_c`` lists, which are shared by all instances.
    """
    super().__post_init__()

    uses_first_phase_size = self.firstphase_width != 0 or self.firstphase_height != 0
    if uses_first_phase_size:
        self.hr_upscale_to_x, self.hr_upscale_to_y = self.width, self.height
        self.width, self.height = self.firstphase_width, self.firstphase_height

    owner = StableDiffusionProcessingTxt2Img
    self.cached_hr_uc = owner.cached_hr_uc
    self.cached_hr_c = owner.cached_hr_c
2023-05-19 01:16:09 +08:00
2023-08-10 16:20:46 +08:00
def calculate_target_resolution(self):
    """Compute the hires-fix target size and crop amounts from the first-pass size.

    With ``opts.use_old_hires_fix_width_height``, the requested width/height are treated
    as the final size: they are stored as the resize/upscale target, and the first pass is
    shrunk with ``old_hires_fix_first_pass_dimensions``. This is applied once per size,
    tracked via ``applied_old_hires_behavior_to``.

    Then it sets ``hr_upscale_to_x`` / ``hr_upscale_to_y``:
      * from ``hr_scale`` when both ``hr_resize_x`` and ``hr_resize_y`` are 0;
      * otherwise from the resize target, keeping the first-pass aspect ratio. When both
        resize sides are given, the image is scaled to cover the target, and
        ``truncate_x`` / ``truncate_y`` hold the overshoot divided (floor) by ``opt_f``.
    The chosen mode is recorded in ``extra_generation_params``.
    """
    if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
        self.hr_resize_x = self.hr_upscale_to_x = self.width
        self.hr_resize_y = self.hr_upscale_to_y = self.height

        self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
        self.applied_old_hires_behavior_to = (self.width, self.height)

    width, height = self.width, self.height
    resize_x, resize_y = self.hr_resize_x, self.hr_resize_y

    if resize_x == 0 and resize_y == 0:
        self.extra_generation_params["Hires upscale"] = self.hr_scale
        self.hr_upscale_to_x = int(width * self.hr_scale)
        self.hr_upscale_to_y = int(height * self.hr_scale)
        return

    self.extra_generation_params["Hires resize"] = f"{resize_x}x{resize_y}"

    # Decide which side is pinned to the resize target; the other follows the aspect ratio.
    crop_to_target = False
    if resize_y == 0:
        match_width = True
    elif resize_x == 0:
        match_width = False
    else:
        crop_to_target = True
        match_width = width / height < resize_x / resize_y

    if match_width:
        self.hr_upscale_to_x = resize_x
        self.hr_upscale_to_y = resize_x * height // width
    else:
        self.hr_upscale_to_x = resize_y * width // height
        self.hr_upscale_to_y = resize_y

    if crop_to_target:
        # Overshoot beyond the requested box, presumably in latent units (opt_f is defined elsewhere).
        self.truncate_x = (self.hr_upscale_to_x - resize_x) // opt_f
        self.truncate_y = (self.hr_upscale_to_y - resize_y) // opt_f
2022-09-19 21:42:56 +08:00
def init(self, all_prompts, all_seeds, all_subseeds):
    """Validate the hires-fix settings and record them before sampling starts.

    Does nothing unless ``enable_hr`` is set. Otherwise it:

    - resolves the optional hires checkpoint;
    - records the hires sampler, prompts, steps and upscaler in
      ``extra_generation_params``;
    - selects ``latent_scale_mode`` and computes the target resolution;
    - once per job, doubles the job count so progress covers both passes.

    Args:
        all_prompts, all_seeds, all_subseeds: part of the shared ``init`` signature;
            not used here.

    Raises:
        Exception: if ``hr_checkpoint_name`` matches no checkpoint, or if
            ``hr_upscaler`` is neither a latent upscale mode nor a known upscaler.
    """
    if not self.enable_hr:
        return

    if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
        self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

        if self.hr_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

        self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

    if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
        self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

    if tuple(self.hr_prompt) != tuple(self.prompt):
        self.extra_generation_params["Hires prompt"] = self.hr_prompt

    if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
        self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

    if self.hr_upscaler is not None:
        # None here means hr_upscaler is not a latent mode, so it must name an image upscaler.
        self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None)
    else:
        # The fallback must have the same shape as the table entries: sample_hr_pass
        # reads latent_scale_mode["mode"] and latent_scale_mode["antialias"], so a bare
        # string here would raise TypeError there.
        self.latent_scale_mode = shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, {"mode": "nearest", "antialias": False})

    if self.latent_scale_mode is None and not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
        raise Exception(f"could not find upscaler named {self.hr_upscaler}")

    self.calculate_target_resolution()

    if not state.processing_has_refined_job_count:
        if state.job_count == -1:
            state.job_count = self.n_iter

        # Every iteration now runs two sampling passes.
        shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
        state.job_count = state.job_count * 2
        state.processing_has_refined_job_count = True

    if self.hr_second_pass_steps:
        self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps

    if self.hr_upscaler is not None:
        self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
2022-10-15 04:19:05 +08:00
2023-05-19 01:16:09 +08:00
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Run the first txt2img pass and, if hires fix is enabled, hand off to the second pass.

    Returns the first-pass latents when ``enable_hr`` is off. Otherwise returns
    whatever ``sample_hr_pass`` produces.
    """
    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

    noise = self.rng.next()
    first_pass = self.sampler.sample(self, noise, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(noise))
    del noise

    if not self.enable_hr:
        return first_pass

    devices.torch_gc()

    # Image (non-latent) upscalers work on pixels, so decode the first pass for them.
    # Latent modes interpolate the latents directly and need no decode.
    decoded = None
    if self.latent_scale_mode is None:
        decoded = torch.stack(decode_latent_batch(self.sd_model, first_pass, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)

    # Load the hires checkpoint (if any). Going by its name, SkipWritingToConfig keeps
    # this switch out of the saved settings; confirm.
    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=self.hr_checkpoint_info)

    return self.sample_hr_pass(first_pass, decoded, seeds, subseeds, subseed_strength, prompts)
2023-07-30 18:48:27 +08:00
2023-07-30 20:12:09 +08:00
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
    """Run the hires-fix second pass on the first-pass latents.

    The first-pass result is upscaled to (``hr_upscale_to_x``, ``hr_upscale_to_y``)
    in one of two ways:

    - in latent space with ``torch.nn.functional.interpolate``, when
      ``latent_scale_mode`` is set;
    - otherwise in pixel space from ``decoded_samples``, using the ``hr_upscaler``
      image upscaler, and then re-encoded to latents.

    The upscaled latents are cropped by ``truncate_x`` / ``truncate_y``. They are then
    refined with an img2img pass using the hires sampler, the hires conds and
    ``hr_second_pass_steps or steps`` steps.

    Args:
        samples: first-pass latents. Indexed via ``shape[0..3]``, so presumably
            (batch, channels, h, w); TODO confirm.
        decoded_samples: first-pass images decoded to pixels. Only read in the
            non-latent branch. Presumably in [-1, 1], given the ``(x + 1) / 2`` clamp
            below; TODO confirm.
        seeds, prompts: per-image values. Used here only when saving the
            pre-hires intermediate images.
        subseeds, subseed_strength: not read here. The hires noise uses
            ``self.seeds`` / ``self.subseeds`` / ``self.subseed_strength`` instead.

    Returns:
        ``samples`` unchanged if processing was interrupted before the pass started.
        Otherwise, the result of ``decode_latent_batch`` on the refined latents (on CPU).
    """
    if shared.state.interrupted:
        return samples

    # get_conds() and setup_conds() switch to the hires conds while this flag is set.
    # NOTE(review): the flag is not reset if the second pass raises.
    self.is_hr_pass = True
    target_width = self.hr_upscale_to_x
    target_height = self.hr_upscale_to_y

    def save_intermediate(image, index):
        """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""

        if not self.save_samples() or not opts.save_images_before_highres_fix:
            return

        if not isinstance(image, Image.Image):
            image = sd_samplers.sample_to_image(image, index, approximation=0)

        info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
        images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

    img2img_sampler_name = self.hr_sampler_name or self.sampler_name

    self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

    if self.latent_scale_mode is not None:
        # Latent upscale: save intermediates first, then resize the latents directly.
        for i in range(samples.shape[0]):
            save_intermediate(samples, i)

        samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

        # Avoid making the inpainting conditioning unless necessary as
        # this does need some extra compute to decode / encode the image again.
        if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
            image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
        else:
            image_conditioning = self.txt2img_image_conditioning(samples)
    else:
        # Image upscale: map to [0, 1], convert each image to uint8 PIL, upscale it with
        # hr_upscaler, then stack back to a float batch and re-encode to latents.
        lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

        batch_images = []
        for i, x_sample in enumerate(lowres_samples):
            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
            x_sample = x_sample.astype(np.uint8)
            image = Image.fromarray(x_sample)

            save_intermediate(image, i)

            image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)
            batch_images.append(image)

        decoded_samples = torch.from_numpy(np.array(batch_images))
        decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
        samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

        image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

    shared.state.nextjob()

    # Crop truncate_y rows / truncate_x columns off the upscaled latents: half from the
    # start, and the remaining (possibly larger) half from the end.
    samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

    # Fresh noise generator sized to the (cropped) latent shape.
    self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
    noise = self.rng.next()

    # GC now before running the next img2img to prevent running out of memory
    devices.torch_gc()

    # Activate the extra networks parsed from the hires prompts for the second pass.
    if not self.disable_extra_networks:
        with devices.autocast():
            extra_networks.activate(self, self.hr_extra_network_data)

    with devices.autocast():
        self.calculate_hr_conds()

    # Token merging uses the hires ratio during the second pass and is restored afterwards.
    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))

    if self.scripts is not None:
        self.scripts.before_hr(self)

    samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

    sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

    # Drop the sampler reference before collecting memory for the decode.
    self.sampler = None
    devices.torch_gc()

    decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

    self.is_hr_pass = False
    return decoded_samples
2022-09-03 17:08:45 +08:00
2023-05-19 01:16:09 +08:00
def close(self):
    """Release per-run state, including the hires conditionings.

    The class-level hires cond caches are also reset unless
    ``opts.persistent_cond_cache`` is enabled.
    """
    super().close()

    self.hr_c = self.hr_uc = None

    if opts.persistent_cond_cache:
        return

    StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
    StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
2023-05-19 01:16:09 +08:00
def setup_prompts(self):
    """Build the per-image hires prompt lists ``all_hr_prompts`` / ``all_hr_negative_prompts``.

    An empty hires prompt falls back to the corresponding first-pass prompt. A
    non-list prompt is repeated ``batch_size * n_iter`` times. Styles are then applied
    to every entry. Does nothing beyond the base setup when hires fix is disabled.
    """
    super().setup_prompts()

    if not self.enable_hr:
        return

    if self.hr_prompt == '':
        self.hr_prompt = self.prompt
    if self.hr_negative_prompt == '':
        self.hr_negative_prompt = self.negative_prompt

    total = self.batch_size * self.n_iter

    def expand(value):
        # Lists are taken as-is; anything else is repeated once per image.
        return value if isinstance(value, list) else total * [value]

    positives = expand(self.hr_prompt)
    negatives = expand(self.hr_negative_prompt)

    self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(text, self.styles) for text in positives]
    self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(text, self.styles) for text in negatives]
2023-06-04 18:07:22 +08:00
def calculate_hr_conds(self):
    """Compute the hires-pass conditionings ``hr_uc`` and ``hr_c``, or fetch them from cache.

    This is a no-op when ``hr_c`` is already set. The conds are built for the hires
    target size. Their schedule length comes from the hires sampler's ``total_steps``
    for ``hr_second_pass_steps or steps``, falling back to the raw step count when no
    sampler config is found.
    """
    if self.hr_c is not None:
        return

    size = dict(width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
    positive = prompt_parser.SdConditioning(self.hr_prompts, **size)
    negative = prompt_parser.SdConditioning(self.hr_negative_prompts, is_negative_prompt=True, **size)

    config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
    hr_steps = self.hr_second_pass_steps or self.steps
    total_steps = hr_steps if not config else config.total_steps(hr_steps)

    # The hires caches are checked first, then the first-pass caches.
    self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
    self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, positive, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
2023-06-04 18:07:22 +08:00
2023-05-19 01:16:09 +08:00
def setup_conds(self):
    """Prepare the conditionings for the upcoming sampling.

    During the hires pass (the refiner calls back in here), only the hires conds are
    rebuilt. Otherwise the first-pass conds are set up and the hires conds are
    cleared. When hires fix uses the same checkpoint, the hires conds may also be
    computed immediately.
    """
    if self.is_hr_pass:
        # Called from the refiner mid-hires-pass: no first-pass conds, no model switch.
        self.hr_c = None
        self.calculate_hr_conds()
        return

    super().setup_conds()

    self.hr_uc = None
    self.hr_c = None

    if not self.enable_hr or self.hr_checkpoint_info is not None:
        return

    if shared.opts.hires_fix_use_firstpass_conds:
        self.calculate_hr_conds()
        return

    # In lowvram mode the conds must be calculated right away, before the cond network
    # is unloaded. Do it with the hires extra networks active, then restore the
    # first-pass ones.
    lowvram_on_selected_model = lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint()
    if not lowvram_on_selected_model:
        return

    with devices.autocast():
        extra_networks.activate(self, self.hr_extra_network_data)

    self.calculate_hr_conds()

    with devices.autocast():
        extra_networks.activate(self, self.extra_network_data)
2023-05-19 01:16:09 +08:00
2023-08-06 22:53:33 +08:00
def get_conds(self):
    """Return the ``(c, uc)`` pair: the hires pair during the hires pass, else the base class's."""
    return (self.hr_c, self.hr_uc) if self.is_hr_pass else super().get_conds()
2023-05-19 01:16:09 +08:00
def parse_extra_network_prompts(self):
    """Slice this iteration's hires prompts and strip their extra-network tags.

    Sets ``hr_prompts``, ``hr_negative_prompts`` and ``hr_extra_network_data`` when
    hires fix is enabled. Returns the base class's result unchanged.
    """
    res = super().parse_extra_network_prompts()

    if self.enable_hr:
        window = slice(self.iteration * self.batch_size, (self.iteration + 1) * self.batch_size)

        self.hr_prompts = self.all_hr_prompts[window]
        self.hr_negative_prompts = self.all_hr_negative_prompts[window]
        self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

    return res
2022-09-03 17:08:45 +08:00
2023-08-13 13:24:16 +08:00
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Processing parameters and setup for img2img and inpainting.

    ``init()`` turns ``init_images``, plus an optional inpainting mask, into the
    starting latents ``init_latent``. When a mask is given it also builds the
    latent-space blend weights ``mask`` / ``nmask``. ``sample()`` runs the img2img
    sampler from those latents and, with a mask, blends the result with the originals.
    """

    init_images: list = None  # source images; flattened onto img2img_background_color in init()
    resize_mode: int = 0  # passed to images.resize_image; 3 = resize the latents after encoding instead (see init)
    denoising_strength: float = 0.75
    image_cfg_scale: float = None  # kept only when shared.sd_model.cond_stage_key == "edit"; set to None otherwise in init()
    mask: Any = None  # caller's inpainting mask; moved to image_mask in __post_init__, then reused for the latent keep-weights
    mask_blur_x: int = 4  # horizontal Gaussian blur sigma for the mask (0 disables)
    mask_blur_y: int = 4  # vertical Gaussian blur sigma for the mask (0 disables)
    # NOTE: the ``mask_blur`` property defined below rebinds this name, so this field's
    # dataclass default is the property object itself. The generated __init__ assigns
    # it after mask_blur_x/mask_blur_y. The setter ignores non-int values, so leaving
    # mask_blur unset keeps the per-axis values, while passing an int overrides both.
    mask_blur: int = None
    mask_round: bool = True  # passed to create_binary_mask; also enables np.around on the latent mask
    inpainting_fill: int = 0  # 1: keep masked pixels as-is; others pre-fill with masking.fill; 2/3 then seed noise / zero the masked latents
    inpaint_full_res: bool = True  # inpaint a crop around the mask, resized to width x height, instead of the whole image
    inpaint_full_res_padding: int = 0  # padding passed to masking.get_crop_region for that crop
    inpainting_mask_invert: int = 0  # truthy: invert the mask before use
    initial_noise_multiplier: float = None  # None -> opts.initial_noise_multiplier (resolved in __post_init__)
    latent_mask: Image = None  # optional separate mask for pixel fill and latent blending; defaults to the processed image mask
    force_task_id: str = None

    image_mask: Any = field(default=None, init=False)  # the caller's mask, taken from ``mask`` in __post_init__

    nmask: torch.Tensor = field(default=None, init=False)  # latent weight of newly sampled content in sample()'s blend
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)  # md5 of the last init image, used as its saved filename
    mask_for_overlay: Image = field(default=None, init=False)  # mask used to cut the unmasked original into overlay_images
    init_latent: torch.Tensor = field(default=None, init=False)  # encoded init images; altered in the mask area by inpainting_fill 2/3

    def __post_init__(self):
        super().__post_init__()

        # Keep the caller's mask as image_mask; self.mask is rebuilt in init() as the
        # latent-space keep-weights (or stays None when there is no mask).
        self.image_mask = self.mask
        self.mask = None
        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier

    @property
    def mask_blur(self):
        """Shared blur sigma when both axes agree, otherwise None."""
        if self.mask_blur_x == self.mask_blur_y:
            return self.mask_blur_x
        return None

    @mask_blur.setter
    def mask_blur(self, value):
        # Only ints are applied (to both axes); None and the dataclass default are ignored.
        if isinstance(value, int):
            self.mask_blur_x = value
            self.mask_blur_y = value

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the mask, overlays, color corrections and initial latents.

        Args:
            all_prompts, all_subseeds: not used here.
            all_seeds: only the first ``batch`` seeds are used, to generate fill noise
                for ``inpainting_fill == 2``.

        Raises:
            RuntimeError: if more init images are given than ``batch_size``.
        """
        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask

        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
            image_mask = create_binary_mask(image_mask, round=self.mask_round)

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)

            # Separable Gaussian blur: a horizontal pass with sigma mask_blur_x, then a
            # vertical pass with sigma mask_blur_y. Kernel size is 2*round(2.5*sigma)+1,
            # which is always odd.
            if self.mask_blur_x > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
                image_mask = Image.fromarray(np_mask)

            if self.mask_blur_y > 0:
                np_mask = np.array(image_mask)
                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                image_mask = Image.fromarray(np_mask)

            if self.inpaint_full_res:
                # Work on a crop around the mask (padded, then expanded via
                # expand_crop_region; presumably to the target aspect ratio, confirm),
                # resized to width x height. paste_to records (x, y, w, h) of that crop.
                self.mask_for_overlay = image_mask
                mask = image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                np_mask = np.array(image_mask)
                # The overlay mask is strengthened: values doubled, clipped to 255.
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask

        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        # NOTE(review): an empty init_images list makes batch_size 0 below and fails
        # later; None raises TypeError here. Consider validating up front.
        for img in self.init_images:

            # Save init image
            if opts.save_init_img:
                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)

            image = images.flatten(img, opts.img2img_background_color)

            # resize_mode 3 skips the pixel resize; the latents are resized after encoding instead.
            if crop_region is None and self.resize_mode != 3:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if image_mask is not None:
                # Overlay = original pixels outside the mask, transparent inside it
                # (pasted through the inverted overlay mask onto a blank RGBa canvas).
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            # crop_region is not None if we are doing inpaint full res
            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if image_mask is not None:
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # PIL HWC uint8 -> float32 in [0, 1], channels first.
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # A single init image is repeated to fill the whole batch.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            # Several init images: the batch shrinks to the number supplied.
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        image = image.to(shared.device, dtype=devices.dtype_vae)

        if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
        devices.torch_gc()

        if self.resize_mode == 3:
            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

        if image_mask is not None:
            # Downsample the mask to the latent grid. self.mask (1 - mask) keeps the
            # original latents; self.nmask weights the newly sampled ones. Both are
            # tiled to 4 channels.
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            if self.mask_round:
                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            # (it currently uses all_seeds[0:batch] at init time, independent of iteration)
            if self.inpainting_fill == 2:
                # Replace the masked latent area with seeded random noise.
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                # Zero the masked latent area.
                self.init_latent = self.init_latent * self.mask

        # image is in [0, 1] here; image * 2 - 1 maps it to [-1, 1].
        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Run img2img sampling from ``init_latent`` and return the latents.

        With a mask, the result is ``samples * nmask + init_latent * mask``. Scripts may
        replace that blend via ``on_mask_blend``.
        """
        x = self.rng.next()

        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

            samples = blended_samples

        del x
        devices.torch_gc()

        return samples

    def get_token_merging_ratio(self, for_hr=False):
        """Return the token-merging ratio for img2img; ``for_hr`` is not used here.

        Precedence:

        1. this run's ``token_merging_ratio``;
        2. ``opts.token_merging_ratio``, when ``override_settings`` contains that key
           (presumably already applied to opts; confirm);
        3. ``opts.token_merging_ratio_img2img``;
        4. ``opts.token_merging_ratio``.

        The first truthy value wins.
        """
        return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio