From cccc5a20fce4bde9a4299f8790366790735f1d05 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Sun, 16 Oct 2022 12:10:07 -0700 Subject: [PATCH] Safeguard setting restore logic against exceptions also useful for keeping settings cache and restore logic together, and nice for code reuse (other third party scripts can import this class) --- scripts/xy_grid.py | 48 +++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 88ad3bf78..5cca168a1 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ return processed_result +class SharedSettingsStackHelper(object): + def __enter__(self): + self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers + self.hypernetwork = opts.sd_hypernetwork + self.model = shared.sd_model + + def __exit__(self, exc_type, exc_value, tb): + modules.sd_models.reload_model_weights(self.model) + + hypernetwork.load_hypernetwork(self.hypernetwork) + hypernetwork.apply_strength() + + opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + + re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") @@ -267,9 +282,6 @@ class Script(scripts.Script): if not opts.return_grid: p.batch_size = 1 - - CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -367,27 +379,19 @@ class Script(scripts.Script): return process_images(pc) - processed = draw_xy_grid( - p, - xs=xs, - ys=ys, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images - ) + with SharedSettingsStackHelper(): + processed = draw_xy_grid( + p, + xs=xs, + ys=ys, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images + ) if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) - # restore checkpoint in case it was changed by axes - modules.sd_models.reload_model_weights(shared.sd_model) - - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - hypernetwork.apply_strength() - - - opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers - return processed