From ec8774729e17f87a8ffa5a3c5328d12834cbb02a Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:53:35 -0500 Subject: [PATCH] swaps xyz axes internally if one costs more --- scripts/xyz_grid.py | 66 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index a16653dac..828c2d122 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,7 +205,7 @@ axis_options = [ ] -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order): +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed): hor_texts = [[images.GridAnnotation(x)] for x in x_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] @@ -224,7 +224,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend nonlocal image_cache, processed_result, cell_mode, cell_size def index(ix, iy, iz): - return ix + iy*len(xs) + iz*len(xs)*len(ys) + return ix + iy * len(xs) + iz * len(xs) * len(ys) state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" @@ -251,16 +251,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend except: image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) - if swap_axes_processing_order: + if first_axes_processed == 'x': for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - for iy, y in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': + for iy, y in enumerate(ys): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) if not processed_result: print("Unexpected error: draw_xyz_grid failed to return even a single processed image") @@ -322,7 +342,7 @@ class Script(scripts.Script): with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) with gr.Row(): z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) @@ -487,7 +507,26 @@ class Script(scripts.Script): # If one of the axes is very slow to change between (like SD model # checkpoint), then make sure it is in the outer iteration of the nested # `for` loop. - swap_axes_processing_order = x_opt.cost > y_opt.cost + first_axes_processed = 'x' + second_axes_processed = 'y' + if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: + first_axes_processed = 'x' + if y_opt.cost > z_opt.cost: + second_axes_processed = 'y' + else: + second_axes_processed = 'z' + elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: + first_axes_processed = 'y' + if x_opt.cost > z_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'z' + elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: + first_axes_processed = 'z' + if x_opt.cost > y_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'y' def cell(x, y, z): if shared.state.interrupted: @@ -538,7 +577,8 @@ class Script(scripts.Script): draw_legend=draw_legend, include_lone_images=include_lone_images, include_sub_grids=include_sub_grids, - swap_axes_processing_order=swap_axes_processing_order + first_axes_processed=first_axes_processed, + second_axes_processed=second_axes_processed ) if opts.grid_save: