diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a0461..074ee9192 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -175,76 +175,87 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) +AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"]) +AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing, None), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None), - AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None), - AxisOption("Prompt S/R", str, apply_prompt, format_value, None), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None), - AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), - AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), - AxisOption("VAE", str, apply_vae, format_value_add_label, None), - AxisOption("Styles", str, apply_styles, format_value_add_label, None), + AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), + AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), + AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), + AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), + AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), + AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), + AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), + AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), + AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] # Temporary list of all the images that are generated to be populated into the grid. # Will be filled with empty images for any individual step that fails to process properly - image_cache = [] + image_cache = [None] * (len(xs) * len(ys)) processed_result = None cell_mode = "P" - cell_size = (1,1) + cell_size = (1, 1) state.job_count = len(xs) * len(ys) * p.n_iter - for iy, y in enumerate(ys): + def process_cell(x, y, ix, iy): + nonlocal image_cache, processed_result, cell_mode, cell_size + + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed: Processed = cell(x, y) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache[ix + iy * len(xs)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + + if swap_axes_processing_order: for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed:Processed = cell(x, y) - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] - - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - - image_cache.append(processed_image) - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache.append(Image.new(cell_mode, cell_size)) + for iy, y in enumerate(ys): + process_cell(x, y, ix, iy) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, ix, iy) if not processed_result: print("Unexpected error: draw_xy_grid failed to return even a single processed image") @@ -405,6 +416,11 @@ class Script(scripts.Script): grid_infotext = [None] + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` loop. + swap_axes_processing_order = x_opt.cost > y_opt.cost + def cell(x, y): if shared.state.interrupted: return Processed(p, [], p.seed, "") @@ -443,7 +459,8 @@ class Script(scripts.Script): y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, draw_legend=draw_legend, - include_lone_images=include_lone_images + include_lone_images=include_lone_images, + swap_axes_processing_order=swap_axes_processing_order ) if opts.grid_save: