2022-09-04 00:32:45 +08:00
from collections import namedtuple
from copy import copy
2022-10-06 18:55:21 +08:00
from itertools import permutations , chain
2022-09-04 00:32:45 +08:00
import random
2022-10-06 18:55:21 +08:00
import csv
from io import StringIO
2022-09-26 21:46:18 +08:00
from PIL import Image
2022-09-09 22:54:04 +08:00
import numpy as np
2022-09-04 00:32:45 +08:00
import modules . scripts as scripts
import gradio as gr
2023-01-16 22:36:56 +08:00
from modules import images , paths , sd_samplers , processing , sd_models , sd_vae
2022-11-19 17:01:51 +08:00
from modules . processing import process_images , Processed , StableDiffusionProcessingTxt2Img
2022-09-04 00:32:45 +08:00
from modules . shared import opts , cmd_opts , state
2022-09-17 18:49:36 +08:00
import modules . shared as shared
2022-09-04 00:32:45 +08:00
import modules . sd_samplers
2022-09-17 18:49:36 +08:00
import modules . sd_models
2022-12-18 23:47:02 +08:00
import modules . sd_vae
import glob
import os
2022-09-06 15:11:25 +08:00
import re
2022-09-04 00:32:45 +08:00
2023-01-16 22:36:56 +08:00
from modules . ui_components import ToolButton
2022-09-04 00:32:45 +08:00
2023-01-16 22:36:56 +08:00
# Glyph shown on the "fill axis values" tool buttons next to each axis textbox.
fill_values_symbol = "\U0001f4d2"  # 📒

# (axis option, parsed value list) pair published on `state` for the plot being run.
AxisInfo = namedtuple("AxisInfo", "axis values")
2023-01-16 13:41:58 +08:00
2022-09-04 00:32:45 +08:00
def apply_field(field):
    """Return an axis applier that writes the axis value into attribute `field` of `p`.

    The returned callable has the applier signature `(p, x, xs)`; `xs` is ignored.
    """
    return lambda p, x, xs: setattr(p, field, x)
def apply_prompt(p, x, xs):
    """Prompt S/R: replace the search term `xs[0]` with `x` in both prompts.

    Raises:
        RuntimeError: if the search term occurs in neither the prompt nor the
            negative prompt.
    """
    search_term = xs[0]
    if not (search_term in p.prompt or search_term in p.negative_prompt):
        raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")

    p.prompt = p.prompt.replace(search_term, x)
    p.negative_prompt = p.negative_prompt.replace(search_term, x)
2022-09-04 00:32:45 +08:00
2022-10-04 10:20:09 +08:00
def apply_order(p, x, xs):
    """Reorder the tokens listed in `x` inside `p.prompt`.

    `x` is one permutation of the "Prompt order" axis values. Each token is
    located in the prompt; the slots they occupy (taken in order of earliest
    occurrence) are then refilled, in order, with the tokens of `x`.

    Raises:
        RuntimeError: if a token of `x` cannot be found in the prompt. Without
            this check `str.find` returns -1 and the slicing below silently
            corrupts the prompt. `p.prompt` is left untouched when this is raised.
    """
    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen.
    token_order = []
    for token in x:
        position = p.prompt.find(token)
        if position == -1:
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        token_order.append((position, token))
    token_order.sort(key=lambda t: t[0])

    # Split the prompt up, cutting each token out in positional order.
    remainder = p.prompt
    prompt_parts = []
    for _, token in token_order:
        n = remainder.find(token)
        if n == -1:
            # Overlapping tokens (one a prefix of another at the same spot) can't be reordered.
            raise RuntimeError(f"Prompt order did not find {token} in prompt.")
        prompt_parts.append(remainder[:n])
        remainder = remainder[n + len(token):]

    # Rebuild the prompt with the tokens in the order we want.
    p.prompt = "".join(part + token for part, token in zip(prompt_parts, x)) + remainder
2022-09-04 00:32:45 +08:00
def apply_sampler(p, x, xs):
    """Set `p.sampler_name` from the case-insensitive sampler lookup of `x`.

    Raises:
        RuntimeError: if `x` is not a key of `sd_samplers.samplers_map` (lower-cased).
    """
    resolved_name = sd_samplers.samplers_map.get(x.lower())
    if resolved_name is None:
        raise RuntimeError(f"Unknown sampler: {x}")
    p.sampler_name = resolved_name
2022-09-04 00:32:45 +08:00
2022-10-10 01:20:35 +08:00
def confirm_samplers(p, xs):
    """Validate all sampler names before any image is generated.

    Raises:
        RuntimeError: for the first name whose lower-cased form is not in
            `sd_samplers.samplers_map`.
    """
    for sampler in xs:
        if sampler.lower() in sd_samplers.samplers_map:
            continue
        raise RuntimeError(f"Unknown sampler: {sampler}")
2022-09-04 00:32:45 +08:00
2022-09-17 18:49:36 +08:00
def apply_checkpoint(p, x, xs):
    """Load the checkpoint that best matches `x` into `shared.sd_model`.

    `p` itself is not modified; the model swap is global.

    Raises:
        RuntimeError: if no checkpoint matches `x`.
    """
    models = modules.sd_models
    checkpoint_info = models.get_closet_checkpoint_match(x)
    if checkpoint_info is None:
        raise RuntimeError(f"Unknown checkpoint: {x}")
    models.reload_model_weights(shared.sd_model, checkpoint_info)
2022-10-10 01:20:35 +08:00
def confirm_checkpoints(p, xs):
    """Validate all checkpoint names before any image is generated.

    Raises:
        RuntimeError: for the first name with no matching checkpoint.
    """
    for checkpoint_name in xs:
        if modules.sd_models.get_closet_checkpoint_match(checkpoint_name) is not None:
            continue
        raise RuntimeError(f"Unknown checkpoint: {checkpoint_name}")
2022-10-09 23:58:55 +08:00
def apply_clip_skip(p, x, xs):
    """Set the global "CLIP_stop_at_last_layers" option to `x`.

    This writes shared options rather than `p`; SharedSettingsStackHelper
    restores the previous value once the plot finishes.
    """
    setting_key = "CLIP_stop_at_last_layers"
    opts.data[setting_key] = x
2022-10-09 23:58:55 +08:00
2022-12-18 23:47:02 +08:00
def apply_upscale_latent_space(p, x, xs):
    """Toggle the global "use_scale_latent_for_hires_fix" option.

    Any value other than "0" (after lower-casing and stripping) enables it.
    """
    opts.data["use_scale_latent_for_hires_fix"] = x.lower().strip() != '0'
def find_vae(name: str):
    """Resolve a VAE axis value to the argument passed to `reload_vae_weights`.

    The name is lower-cased and stripped once, and that normalized form is used
    for every comparison. Previously only the substring search stripped it, so
    " auto" or " none" fell through to the search.

    Returns:
        `modules.sd_vae.unspecified` for "auto"/"automatic"; None for "none".
        Otherwise, the `vae_dict` value whose key contains the name
        (case-insensitive), preferring the shortest such key. If nothing
        matches, a message is printed and `unspecified` is returned.
    """
    query = name.lower().strip()
    if query in ('auto', 'automatic'):
        return modules.sd_vae.unspecified
    if query == 'none':
        return None

    # Shortest key first, so the most specific (least decorated) match wins.
    choices = [key for key in sorted(modules.sd_vae.vae_dict, key=len) if query in key.lower()]
    if not choices:
        print(f"No VAE found for {name}; using automatic")
        return modules.sd_vae.unspecified
    return modules.sd_vae.vae_dict[choices[0]]
2022-12-18 23:47:02 +08:00
def apply_vae(p, x, xs):
    """Reload the VAE selected by axis value `x` (resolved through find_vae)."""
    selected_vae = find_vae(x)
    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=selected_vae)
2022-12-18 23:47:02 +08:00
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
    """Append the comma-separated style names in `x` to `p.styles`.

    Each name is stripped, so "a, b" selects the styles "a" and "b" rather
    than "a" and " b". Empty entries (e.g. from "a,,b" or an empty cell value)
    are ignored instead of being appended as "".
    """
    p.styles.extend(name.strip() for name in x.split(',') if name.strip())
2022-09-17 18:49:36 +08:00
2022-09-04 00:32:45 +08:00
def format_value_add_label(p, opt, x):
    """Render an axis value as "<label>: <value>" for grid legends.

    Plain floats are rounded to 8 decimals to hide float noise from ranges.
    """
    shown = round(x, 8) if type(x) == float else x
    return f"{opt.label}: {shown}"
def format_value ( p , opt , x ) :
2022-09-09 23:05:43 +08:00
if type ( x ) == float :
x = round ( x , 8 )
2022-09-04 00:32:45 +08:00
return x
2022-10-04 14:18:00 +08:00
def format_value_join_list(p, opt, x):
    """Legend label for a "Prompt order" permutation: its tokens joined by ", "."""
    separator = ", "
    return separator.join(x)
2022-09-09 22:54:04 +08:00
def do_nothing(p, x, xs):
    """Applier for the "Nothing" axis: leaves `p` untouched."""
    return None
2022-10-04 14:18:00 +08:00
2022-09-09 22:54:04 +08:00
def format_nothing(p, opt, x):
    """Legend formatter for the "Nothing" axis: always an empty label."""
    empty_label = ""
    return empty_label
2022-09-04 00:32:45 +08:00
2022-10-04 14:18:00 +08:00
def str_permutations(x):
    """Marker "type" for axes whose values expand into all permutations.

    `process_axis` compares `opt.type` against this function to decide to run
    `itertools.permutations` over the parsed values. Calling it is the
    identity, so converting each permutation afterwards leaves it unchanged.
    """
    value = x
    return value
2023-01-16 22:36:56 +08:00
class AxisOption:
    """One selectable axis kind of the X/Y/Z plot.

    Attributes:
        label: name shown in the axis-type dropdown.
        type: converter applied to each parsed value (int, float, str, or the
            str_permutations marker).
        apply: callable `(p, value, all_values)` that applies one value to a
            processing copy.
        format_value: callable `(p, opt, value)` producing the legend label.
        confirm: optional callable `(p, all_values)` validating values up front.
        cost: relative cost of switching values; costlier axes are looped outermost.
        choices: optional zero-argument callable listing values for the fill button.
    """

    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
        self.label, self.type, self.apply = label, type, apply
        self.format_value, self.confirm = format_value, confirm
        self.cost, self.choices = cost, choices
class AxisOptionImg2Img(AxisOption):
    """An AxisOption that is only offered on the img2img tab."""

    def __init__(self, *args, **kwargs):
        super(AxisOptionImg2Img, self).__init__(*args, **kwargs)
        # Script.ui keeps subclass options only when this flag matches the tab.
        self.is_img2img = True
class AxisOptionTxt2Img(AxisOption):
    """An AxisOption that is only offered on the txt2img tab."""

    def __init__(self, *args, **kwargs):
        super(AxisOptionTxt2Img, self).__init__(*args, **kwargs)
        # Script.ui keeps subclass options only when this flag matches the tab.
        self.is_img2img = False
2022-09-04 00:32:45 +08:00
# All axis kinds the X/Y/Z plot offers. Script.ui shows plain AxisOption entries on
# both tabs and AxisOptionTxt2Img/AxisOptionImg2Img entries only on their own tab.
# Each entry: label, value converter, applier, then optional legend formatter,
# up-front validator (confirm), loop-ordering cost, and fill-button choices.
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    # str_permutations makes process_axis expand the values into every ordering.
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    # Two "Sampler" entries with the same label: each tab offers its own sampler list.
    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
    # Highest cost: Script.run moves this axis to the outermost loop so weights reload least often.
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
    AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
]
2023-02-05 16:44:56 +08:00
def _xyz_processing_order(xs, ys, zs, first_axes_processed, second_axes_processed):
    """Yield (ix, iy, iz) cell indices with `first_axes_processed` as the outermost loop.

    The middle loop is `second_axes_processed` when it names one of the two
    remaining axes; otherwise it is the later of the two remaining axes in x/y/z
    order. The innermost loop is the axis left over. Nothing is yielded when
    `first_axes_processed` is not 'x', 'y' or 'z'.
    """
    lengths = {'x': len(xs), 'y': len(ys), 'z': len(zs)}
    if first_axes_processed not in lengths:
        return

    remaining = [axis for axis in 'xyz' if axis != first_axes_processed]
    if second_axes_processed == remaining[0]:
        middle, inner = remaining
    else:
        inner, middle = remaining

    for i_outer in range(lengths[first_axes_processed]):
        for i_middle in range(lengths[middle]):
            for i_inner in range(lengths[inner]):
                position = {first_axes_processed: i_outer, middle: i_middle, inner: i_inner}
                yield position['x'], position['y'], position['z']


def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):
    """Run `cell(x, y, z)` for every axis combination and assemble the result grids.

    Cells are processed with `first_axes_processed` as the outer loop and
    `second_axes_processed` as the middle loop, so a costly axis changes rarely.
    The returned Processed is a copy of the first cell's result and has
    `images` laid out as

        [z_grid, sub_grid_0, ..., sub_grid_{len(zs)-1}, cell_0, ..., cell_{N-1}]

    Cell k is (ix, iy, iz) with k = ix + iy*len(xs) + iz*len(xs)*len(ys).
    `infotexts` follows the same layout. `all_prompts` and `all_seeds` have no
    entry for the z grid, and `index_of_first_image` is set to 1.
    A cell that produced no images is filled with a blank placeholder image.

    Returns `Processed(p, [])` if no cell ran or nothing usable was produced.
    `include_lone_images` and `include_sub_grids` are accepted for the caller's
    convenience but are not used here.
    """
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]

    list_size = len(xs) * len(ys) * len(zs)

    processed_result = None
    cells_started = 0

    state.job_count = list_size * p.n_iter

    def index(ix, iy, iz):
        # Grid position of a cell: x varies fastest, then y, then z.
        return ix + iy * len(xs) + iz * len(xs) * len(ys)

    def process_cell(x, y, z, ix, iy, iz):
        nonlocal processed_result, cells_started

        # Progress counts cells in processing order. The grid index would jump
        # around whenever a costly axis has been moved to the outer loop.
        cells_started += 1
        state.job = f"{cells_started} out of {list_size}"

        processed: Processed = cell(x, y, z)

        if processed_result is None:
            # Use our first processed result object as a template container to hold our full results
            processed_result = copy(processed)
            processed_result.images = [None] * list_size
            processed_result.all_prompts = [None] * list_size
            processed_result.all_seeds = [None] * list_size
            processed_result.infotexts = [None] * list_size
            processed_result.index_of_first_image = 1

        idx = index(ix, iy, iz)
        if processed.images:
            # Non-empty list indicates some degree of success.
            processed_result.images[idx] = processed.images[0]
            processed_result.all_prompts[idx] = processed.prompt
            processed_result.all_seeds[idx] = processed.seed
            processed_result.infotexts[idx] = processed.infotexts[0]
        else:
            # No image for this cell (e.g. interrupted): insert a blank placeholder,
            # sized like the first cell's image when one exists.
            cell_mode = "P"
            cell_size = (processed_result.width, processed_result.height)
            if processed_result.images[0] is not None:
                cell_mode = processed_result.images[0].mode
                # This corrects size in case of batches:
                cell_size = processed_result.images[0].size
            processed_result.images[idx] = Image.new(cell_mode, cell_size)

    for ix, iy, iz in _xyz_processing_order(xs, ys, zs, first_axes_processed, second_axes_processed):
        process_cell(xs[ix], ys[iy], zs[iz], ix, iy, iz)

    if processed_result is None:
        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
        print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
        return Processed(p, [])
    elif not any(processed_result.images):
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        return Processed(p, [])

    z_count = len(zs)
    for i in range(z_count):
        # Each finished sub-grid is inserted at the front. The i grids inserted
        # so far shift the cells of z slice i right by i positions.
        start_index = (i * len(xs) * len(ys)) + i
        end_index = start_index + len(xs) * len(ys)
        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
        if draw_legend:
            grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
        processed_result.images.insert(i, grid)
        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])

    sub_grid_size = processed_result.images[0].size
    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
    if draw_legend:
        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
    processed_result.images.insert(0, z_grid)
    # TODO: Deeper aspects of the program rely on index 0 "grid" images only having partial
    # information, so all_prompts/all_seeds deliberately get no entry for the z grid.
    processed_result.infotexts.insert(0, processed_result.infotexts[0])

    return processed_result
2022-09-04 00:32:45 +08:00
2022-10-17 03:10:07 +08:00
class SharedSettingsStackHelper:
    """Context manager that snapshots shared settings an X/Y/Z plot may change.

    On enter it records `opts.CLIP_stop_at_last_layers` and `opts.sd_vae`.
    On exit it restores `sd_vae` and then calls `reload_model_weights()` and
    `reload_vae_weights()` with no arguments. It restores
    `CLIP_stop_at_last_layers` in a `finally` block, so clip skip is reset even
    if a reload raises. Exceptions are not suppressed.
    """

    def __enter__(self):
        self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
        self.vae = opts.sd_vae

    def __exit__(self, exc_type, exc_value, tb):
        opts.data["sd_vae"] = self.vae
        try:
            # NOTE(review): presumably these reload from the (restored) options — confirm.
            modules.sd_models.reload_model_weights()
            modules.sd_vae.reload_vae_weights()
        finally:
            opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
2022-09-04 00:32:45 +08:00
2022-09-06 15:11:25 +08:00
# Integer range "start-end" with optional signed step: "1-5", "10-20 (+2)", "10-1 (-3)".
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
# Float flavour of re_range, e.g. "0.5-1.5 (+0.25)". The decimal point is escaped:
# an unescaped "." matched any character, so "5-10-15" parsed with start "5-10".
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")

# "start-end [count]": `count` evenly spaced values from start to end, e.g. "1-10 [4]".
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
2023-01-05 00:19:11 +08:00
2022-09-04 00:32:45 +08:00
class Script(scripts.Script):
    """The "X/Y/Z plot" script: one image per combination of up to three axis values."""

    def title(self):
        return "X/Y/Z plot"

    def ui(self, is_img2img):
        # Plain AxisOption entries are offered on both tabs; the Txt2Img/Img2Img
        # subclasses only on the tab matching their is_img2img flag.
        self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]

        with gr.Row():
            with gr.Column(scale=19):
                with gr.Row():
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)

                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)

                with gr.Row():
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
                    z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
                    fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

        with gr.Row(variant="compact", elem_id="axis_options"):
            with gr.Column():
                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
            with gr.Column():
                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
            with gr.Column():
                margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))

        with gr.Row(variant="compact", elem_id="swap_axes"):
            swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
            swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
            swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")

        def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
            # Dropdowns use type="index" on input but are updated by label.
            return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values

        xy_swap_args = [x_type, x_values, y_type, y_values]
        swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
        yz_swap_args = [y_type, y_values, z_type, z_values]
        swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
        xz_swap_args = [x_type, x_values, z_type, z_values]
        swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)

        def fill(x_type):
            axis = self.current_axis_options[x_type]
            return ", ".join(axis.choices()) if axis.choices else gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
        fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
        fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])

        def select_axis(x_type):
            # The fill button is only useful for axes that can list their choices.
            return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)

        x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
        y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
        z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])

        self.infotext_fields = (
            (x_type, "X Type"),
            (x_values, "X Values"),
            (y_type, "Y Type"),
            (y_values, "Y Values"),
            (z_type, "Z Type"),
            (z_values, "Z Values"),
        )

        return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]

    def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
        """Parse the three axes, render every cell, and return the assembled grids.

        The returned Processed starts with the main (z) grid. It is followed by
        the per-z sub-grids when `include_sub_grids` is set, then by the
        individual cell images when `include_lone_images` is set.
        """
        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals):
            """Parse the comma-separated (CSV-quoted) value string of one axis."""
            if opt.label == 'Nothing':
                return [0]

            valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]

            if opt.type == int:
                valslist_ext = []

                for val in valslist:
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        start = int(m.group(1))
                        step = int(m.group(3)) if m.group(3) is not None else 1
                        # The end is inclusive. Extend it one unit in the direction of travel,
                        # so descending ranges ("10-1 (-3)" -> 10, 7, 4, 1) keep it too.
                        end = int(m.group(2)) + (1 if step > 0 else -1)
                        valslist_ext += list(range(start, end, step))
                    elif mc is not None:
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == float:
                valslist_ext = []

                for val in valslist:
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1
                        valslist_ext += np.arange(start, end + step, step).tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type == str_permutations:
                valslist = list(permutations(valslist))

            valslist = [opt.type(x) for x in valslist]

            # Confirm options are valid before starting
            if opt.confirm:
                opt.confirm(p, valslist)

            return valslist

        x_opt = self.current_axis_options[x_type]
        xs = process_axis(x_opt, x_values)

        y_opt = self.current_axis_options[y_type]
        ys = process_axis(y_opt, y_values)

        z_opt = self.current_axis_options[z_type]
        zs = process_axis(z_opt, z_values)

        def fix_axis_seeds(axis_opt, axis_list):
            # Replace random (-1 / empty) seeds with concrete ones so every cell is reproducible.
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)
            zs = fix_axis_seeds(z_opt, zs)

        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys) * len(zs)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs) * len(zs)
        elif z_opt.label == 'Steps':
            total_steps = sum(zs) * len(xs) * len(ys)
        else:
            total_steps = p.steps * len(xs) * len(ys) * len(zs)

        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            if x_opt.label == "Hires steps":
                total_steps += sum(xs) * len(ys) * len(zs)
            elif y_opt.label == "Hires steps":
                total_steps += sum(ys) * len(xs) * len(zs)
            elif z_opt.label == "Hires steps":
                total_steps += sum(zs) * len(xs) * len(ys)
            elif p.hr_second_pass_steps:
                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)
            else:
                total_steps *= 2

        total_steps *= p.n_iter

        image_cell_count = p.n_iter * p.batch_size
        cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else ""
        plural_s = 's' if len(zs) > 1 else ''
        print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
        shared.total_tqdm.updateTotal(total_steps)

        # Infotext for the main grid; filled in by the first cell that actually runs.
        grid_infotext = [None]

        state.xyz_plot_x = AxisInfo(x_opt, xs)
        state.xyz_plot_y = AxisInfo(y_opt, ys)
        state.xyz_plot_z = AxisInfo(z_opt, zs)

        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        first_axes_processed = 'z'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
            if y_opt.cost > z_opt.cost:
                second_axes_processed = 'y'
            else:
                second_axes_processed = 'z'
        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
            first_axes_processed = 'y'
            if x_opt.cost > z_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'z'
        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
            first_axes_processed = 'z'
            if x_opt.cost > y_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'y'

        def cell(x, y, z):
            if shared.state.interrupted:
                return Processed(p, [], p.seed, "")

            pc = copy(p)
            # Copy the styles list so apply_styles cannot leak into p or other cells.
            pc.styles = pc.styles[:]
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)
            z_opt.apply(pc, z, zs)

            res = process_images(pc)

            if grid_infotext[0] is None:
                pc.extra_generation_params = copy(pc.extra_generation_params)
                pc.extra_generation_params['Script'] = self.title()

                if x_opt.label != 'Nothing':
                    pc.extra_generation_params["X Type"] = x_opt.label
                    pc.extra_generation_params["X Values"] = x_values
                    if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs])

                if y_opt.label != 'Nothing':
                    pc.extra_generation_params["Y Type"] = y_opt.label
                    pc.extra_generation_params["Y Values"] = y_values
                    if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])

                if z_opt.label != 'Nothing':
                    pc.extra_generation_params["Z Type"] = z_opt.label
                    pc.extra_generation_params["Z Values"] = z_values
                    if z_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
                        pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs])

                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)

            return res

        with SharedSettingsStackHelper():
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
                zs=zs,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],
                cell=cell,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed,
                margin_size=margin_size
            )

        if not processed.images:
            # draw_xyz_grid already reported the failure. There are no grids to
            # label, save or trim, and the deletions below would raise IndexError.
            return processed

        # Give the main grid the infotext that carries the X/Y/Z settings, so
        # they are saved with the grid and can be read back via infotext_fields.
        if grid_infotext[0] is not None:
            processed.infotexts[0] = grid_infotext[0]

        z_count = len(zs)

        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count + 1]

        if opts.grid_save:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                # all_prompts/all_seeds have no entry for the main grid, so sub-grid
                # image g (g >= 1) maps to entry g - 1. The main grid uses entry 0.
                info_index = max(g - 1, 0)
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[info_index], seed=processed.all_seeds[info_index], grid=True, p=processed)

        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for _ in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]

        return processed