2022-09-15 00:41:55 +08:00
from collections import namedtuple
2022-09-12 06:55:34 +08:00
import numpy as np
from tqdm import trange
import modules . scripts as scripts
import gradio as gr
2023-02-04 06:50:38 +08:00
from modules import processing , shared , sd_samplers , prompt_parser , sd_samplers_common
2022-09-12 06:55:34 +08:00
from modules . processing import Processed
from modules . shared import opts , cmd_opts , state
import torch
import k_diffusion as K
from PIL import Image
from torch import autocast
from einops import rearrange , repeat
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
    """Invert ``p.init_latent`` back into noise by walking the sigma schedule in reverse.

    Each step evaluates the model with classifier-free guidance (``cfg_scale``)
    and takes an Euler step from ``sigmas[i - 1]`` to ``sigmas[i]`` of the
    flipped schedule.

    Args:
        p: processing object; ``p.init_latent`` is the starting latent and
            ``p.image_conditioning`` is fed to the model as ``c_concat``.
        cond: learned conditioning for the original prompt.
        uncond: learned conditioning for the original negative prompt.
        cfg_scale: guidance scale applied while inverting.
        steps: number of sigmas requested from the denoiser schedule.

    Returns:
        The final latent divided by its own standard deviation.
    """
    x = p.init_latent
    s_in = x.new_ones([x.shape[0]])
    dnw = K.external.CompVisDenoiser(shared.sd_model)
    # Reverse the sampling schedule so the loop integrates in the opposite
    # direction to normal sampling (presumably low -> high noise).
    sigmas = dnw.get_sigmas(steps).flip(0)

    shared.state.sampling_steps = steps

    # The conditioning does not change between steps, so build the batched
    # (uncond, cond) pair and the hybrid-conditioning dict once, outside the loop.
    image_conditioning = torch.cat([p.image_conditioning] * 2)
    cond_in = {"c_concat": [image_conditioning], "c_crossattn": [torch.cat([uncond, cond])]}

    for i in trange(1, len(sigmas)):
        shared.state.sampling_step += 1

        # Duplicate the latent so the uncond and cond halves share one model call.
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigmas[i] * s_in] * 2)

        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
        t = dnw.sigma_to_t(sigma_in)

        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

        # Euler step from sigmas[i - 1] to sigmas[i].
        d = (x - denoised) / sigmas[i]
        dt = sigmas[i] - sigmas[i - 1]
        x = x + d * dt

        sd_samplers_common.store_latent(x)

        # This shouldn't be necessary, but solved some VRAM issues
        del x_in, sigma_in, c_out, c_in, t
        del eps, denoised_uncond, denoised_cond, denoised, d, dt

    shared.state.nextjob()

    return x / x.std()
2022-09-15 00:41:55 +08:00
2022-09-27 04:13:23 +08:00
# Snapshot of one noise inversion together with every setting that produced it,
# so Script.run can reuse the noise when the input latent and decode settings repeat.
Cached = namedtuple(
    "Cached",
    "noise cfg_scale steps latent original_prompt original_negative_prompt sigma_adjustment",
)
# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
    """Invert ``p.init_latent`` into noise, evaluating the model at the previous sigma of each step.

    Differences from ``find_noise_for_image``:
    - The model sees ``sigmas[i - 1]``, not ``sigmas[i]``.
    - The derivative is divided by that same previous sigma.
    - The first step (``i == 1``) is special-cased.
    - The result is scaled by the last sigma instead of by its own std.

    Args:
        p: processing object; ``p.init_latent`` is the starting latent and
            ``p.image_conditioning`` is fed to the model as ``c_concat``.
        cond: learned conditioning for the original prompt.
        uncond: learned conditioning for the original negative prompt.
        cfg_scale: guidance scale applied while inverting.
        steps: number of sigmas requested from the denoiser schedule.

    Returns:
        The final latent divided by ``sigmas[-1]`` of the flipped schedule.
    """
    x = p.init_latent
    s_in = x.new_ones([x.shape[0]])
    dnw = K.external.CompVisDenoiser(shared.sd_model)
    # Reverse the sampling schedule so the loop integrates in the opposite
    # direction to normal sampling (presumably low -> high noise).
    sigmas = dnw.get_sigmas(steps).flip(0)

    shared.state.sampling_steps = steps

    # The conditioning does not change between steps, so build it once.
    image_conditioning = torch.cat([p.image_conditioning] * 2)
    cond_in = {"c_concat": [image_conditioning], "c_crossattn": [torch.cat([uncond, cond])]}

    for i in trange(1, len(sigmas)):
        shared.state.sampling_step += 1

        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)

        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]

        # On the first step sigma_in is built from sigmas[0]. The timestep and
        # the divisor below use sigmas[1] instead. NOTE(review): presumably
        # sigmas[0] is the zero that k-diffusion appends to the schedule, which
        # has no valid timestep and cannot be divided by. Confirm before changing.
        if i == 1:
            t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
        else:
            t = dnw.sigma_to_t(sigma_in)

        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale

        if i == 1:
            d = (x - denoised) / (2 * sigmas[i])
        else:
            d = (x - denoised) / sigmas[i - 1]
        dt = sigmas[i] - sigmas[i - 1]
        x = x + d * dt

        sd_samplers_common.store_latent(x)

        # This shouldn't be necessary, but solved some VRAM issues
        del x_in, sigma_in, c_out, c_in, t
        del eps, denoised_uncond, denoised_cond, denoised, d, dt

    shared.state.nextjob()

    return x / sigmas[-1]
2022-09-15 00:41:55 +08:00
2022-09-12 06:55:34 +08:00
class Script(scripts.Script):
    """img2img script that starts sampling from noise recovered from the input image.

    The noise is obtained by inverting the input latent with
    ``find_noise_for_image`` (or ``find_noise_for_image_sigma_adjustment``).
    It is optionally blended with seeded random noise and handed to the
    selected sampler's ``sample_img2img``. The recovered noise is cached on the
    instance and reused while the input latent and decode settings stay the same.
    """

    def __init__(self):
        # Most recent Cached(...) inversion result, or None before the first run.
        self.cache = None

    def title(self):
        return "img2img alternative test"

    def show(self, is_img2img):
        # Visible only on the img2img tab.
        return is_img2img

    def ui(self, is_img2img):
        def eid(name):
            return self.elem_id(name)

        # Component creation order defines the on-screen layout and the
        # positional order of run()'s arguments.
        hint = gr.Markdown('''
* `CFG Scale` should be 2 or lower.
''')

        force_euler = gr.Checkbox(
            label="Override `Sampling method` to Euler?(this method is built for it)",
            value=True,
            elem_id=eid("override_sampler"),
        )

        force_prompt = gr.Checkbox(
            label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)",
            value=True,
            elem_id=eid("override_prompt"),
        )
        source_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=eid("original_prompt"))
        source_negative_prompt = gr.Textbox(
            label="Original negative prompt",
            lines=1,
            elem_id=eid("original_negative_prompt"),
        )

        force_steps = gr.Checkbox(
            label="Override `Sampling Steps` to the same value as `Decode steps`?",
            value=True,
            elem_id=eid("override_steps"),
        )
        decode_steps = gr.Slider(
            label="Decode steps",
            minimum=1,
            maximum=150,
            step=1,
            value=50,
            elem_id=eid("st"),
        )

        force_strength = gr.Checkbox(
            label="Override `Denoising strength` to 1?",
            value=True,
            elem_id=eid("override_strength"),
        )

        decode_cfg = gr.Slider(
            label="Decode CFG scale",
            minimum=0.0,
            maximum=15.0,
            step=0.1,
            value=1.0,
            elem_id=eid("cfg"),
        )
        noise_mix = gr.Slider(
            label="Randomness",
            minimum=0.0,
            maximum=1.0,
            step=0.01,
            value=0.0,
            elem_id=eid("randomness"),
        )
        use_sigma_adjustment = gr.Checkbox(
            label="Sigma adjustment for finding noise for image",
            value=False,
            elem_id=eid("sigma_adjustment"),
        )

        return [
            hint,
            force_euler,
            force_prompt, source_prompt, source_negative_prompt,
            force_steps, decode_steps,
            force_strength,
            decode_cfg, noise_mix, use_sigma_adjustment,
        ]

    def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
        # Apply the requested overrides before any sampling happens.
        if override_sampler:
            p.sampler_name = "Euler"
        if override_prompt:
            p.prompt = original_prompt
            p.negative_prompt = original_negative_prompt
        if override_steps:
            p.steps = st
        if override_strength:
            p.denoising_strength = 1.0

        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
            # Coarse integer fingerprint of the input latent, used to detect "same image".
            fingerprint = (p.init_latent.cpu().numpy() * 10).astype(int)
            previous = self.cache

            settings_match = (
                previous is not None
                and previous.cfg_scale == cfg
                and previous.steps == st
                and previous.original_prompt == original_prompt
                and previous.original_negative_prompt == original_negative_prompt
                and previous.sigma_adjustment == sigma_adjustment
            )
            cache_hit = (
                settings_match
                and previous.latent.shape == fingerprint.shape
                and np.abs(previous.latent - fingerprint).sum() < 100
            )

            if cache_hit:
                rec_noise = previous.noise
            else:
                shared.state.job_count += 1
                cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
                uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
                invert = find_noise_for_image_sigma_adjustment if sigma_adjustment else find_noise_for_image
                rec_noise = invert(p, cond, uncond, cfg, st)
                self.cache = Cached(rec_noise, cfg, st, fingerprint, original_prompt, original_negative_prompt, sigma_adjustment)

            rand_noise = processing.create_random_tensors(
                p.init_latent.shape[1:],
                seeds=seeds,
                subseeds=subseeds,
                subseed_strength=p.subseed_strength,
                seed_resize_from_h=p.seed_resize_from_h,
                seed_resize_from_w=p.seed_resize_from_w,
                p=p,
            )
            # Blend recovered and random noise; the divisor keeps the mix at the
            # same scale when both inputs are independent unit-variance noise.
            mix_norm = (randomness ** 2 + (1 - randomness) ** 2) ** 0.5
            combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / mix_norm

            sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
            sigmas = sampler.model_wrap.get_sigmas(p.steps)
            noise_dt = combined_noise - (p.init_latent / sigmas[0])

            p.seed = p.seed + 1

            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

        p.sample = sample_extra

        # Insertion order matters: it is the order these values appear in the infotext.
        p.extra_generation_params.update({
            "Decode prompt": original_prompt,
            "Decode negative prompt": original_negative_prompt,
            "Decode CFG scale": cfg,
            "Decode steps": st,
            "Randomness": randomness,
            "Sigma Adjustment": sigma_adjustment,
        })

        return processing.process_images(p)