mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-04 13:55:06 +08:00
Optimize XY grid to run slower axes fewer times
This commit is contained in:
parent
dd292a925e
commit
029260b4ca
@ -175,56 +175,58 @@ def str_permutations(x):
|
|||||||
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
|
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
|
||||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
|
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
|
||||||
|
|
||||||
|
|
||||||
axis_options = [
|
axis_options = [
|
||||||
AxisOption("Nothing", str, do_nothing, format_nothing, None),
|
AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
|
||||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
|
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
|
||||||
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
|
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
|
||||||
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
|
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
|
||||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
|
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
|
||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
|
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
|
||||||
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
|
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
|
||||||
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
|
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
|
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
|
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
|
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
|
||||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
|
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
|
||||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
|
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
|
||||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
|
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
|
||||||
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
|
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
|
||||||
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
|
AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
|
||||||
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
|
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
|
||||||
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
|
AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
|
||||||
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
|
AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
|
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
|
||||||
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
|
||||||
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
|
||||||
|
|
||||||
# Temporary list of all the images that are generated to be populated into the grid.
|
# Temporary list of all the images that are generated to be populated into the grid.
|
||||||
# Will be filled with empty images for any individual step that fails to process properly
|
# Will be filled with empty images for any individual step that fails to process properly
|
||||||
image_cache = []
|
image_cache = [None] * (len(xs) * len(ys))
|
||||||
|
|
||||||
processed_result = None
|
processed_result = None
|
||||||
cell_mode = "P"
|
cell_mode = "P"
|
||||||
cell_size = (1,1)
|
cell_size = (1, 1)
|
||||||
|
|
||||||
state.job_count = len(xs) * len(ys) * p.n_iter
|
state.job_count = len(xs) * len(ys) * p.n_iter
|
||||||
|
|
||||||
for iy, y in enumerate(ys):
|
def process_cell(x, y, ix, iy):
|
||||||
for ix, x in enumerate(xs):
|
nonlocal image_cache, processed_result, cell_mode, cell_size
|
||||||
|
|
||||||
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
|
||||||
|
|
||||||
processed:Processed = cell(x, y)
|
processed: Processed = cell(x, y)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this dereference will throw an exception if the image was not processed
|
# this dereference will throw an exception if the image was not processed
|
||||||
# (this happens in cases such as if the user stops the process from the UI)
|
# (this happens in cases such as if the user stops the process from the UI)
|
||||||
@ -237,14 +239,23 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
|
|||||||
cell_size = processed_image.size
|
cell_size = processed_image.size
|
||||||
processed_result.images = [Image.new(cell_mode, cell_size)]
|
processed_result.images = [Image.new(cell_mode, cell_size)]
|
||||||
|
|
||||||
image_cache.append(processed_image)
|
image_cache[ix + iy * len(xs)] = processed_image
|
||||||
if include_lone_images:
|
if include_lone_images:
|
||||||
processed_result.images.append(processed_image)
|
processed_result.images.append(processed_image)
|
||||||
processed_result.all_prompts.append(processed.prompt)
|
processed_result.all_prompts.append(processed.prompt)
|
||||||
processed_result.all_seeds.append(processed.seed)
|
processed_result.all_seeds.append(processed.seed)
|
||||||
processed_result.infotexts.append(processed.infotexts[0])
|
processed_result.infotexts.append(processed.infotexts[0])
|
||||||
except:
|
except:
|
||||||
image_cache.append(Image.new(cell_mode, cell_size))
|
image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)
|
||||||
|
|
||||||
|
if swap_axes_processing_order:
|
||||||
|
for ix, x in enumerate(xs):
|
||||||
|
for iy, y in enumerate(ys):
|
||||||
|
process_cell(x, y, ix, iy)
|
||||||
|
else:
|
||||||
|
for iy, y in enumerate(ys):
|
||||||
|
for ix, x in enumerate(xs):
|
||||||
|
process_cell(x, y, ix, iy)
|
||||||
|
|
||||||
if not processed_result:
|
if not processed_result:
|
||||||
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
|
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
|
||||||
@ -405,6 +416,11 @@ class Script(scripts.Script):
|
|||||||
|
|
||||||
grid_infotext = [None]
|
grid_infotext = [None]
|
||||||
|
|
||||||
|
# If one of the axes is very slow to change between (like SD model
|
||||||
|
# checkpoint), then make sure it is in the outer iteration of the nested
|
||||||
|
# `for` loop.
|
||||||
|
swap_axes_processing_order = x_opt.cost > y_opt.cost
|
||||||
|
|
||||||
def cell(x, y):
|
def cell(x, y):
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
return Processed(p, [], p.seed, "")
|
return Processed(p, [], p.seed, "")
|
||||||
@ -443,7 +459,8 @@ class Script(scripts.Script):
|
|||||||
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
|
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
|
||||||
cell=cell,
|
cell=cell,
|
||||||
draw_legend=draw_legend,
|
draw_legend=draw_legend,
|
||||||
include_lone_images=include_lone_images
|
include_lone_images=include_lone_images,
|
||||||
|
swap_axes_processing_order=swap_axes_processing_order
|
||||||
)
|
)
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
|
Loading…
Reference in New Issue
Block a user