From 304222ef94d1c3c60fab466a96c448868f391bce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 17 Sep 2022 13:49:36 +0300 Subject: [PATCH] X/Y plot support for switching checkpoints. --- modules/sd_models.py | 4 ++-- script.js | 2 ++ scripts/xy_grid.py | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 036af0e4f..4bd70fc5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -127,9 +127,9 @@ def load_model(): return sd_model -def reload_model_weights(sd_model): +def reload_model_weights(sd_model, info=None): from modules import lowvram, devices - checkpoint_info = select_checkpoint() + checkpoint_info = info or select_checkpoint() if sd_model.sd_model_checkpint == checkpoint_info.filename: return diff --git a/script.js b/script.js index 4a70e51d6..e63e06956 100644 --- a/script.js +++ b/script.js @@ -66,6 +66,8 @@ titles = { "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both", "Apply style": "Insert selected styles into prompt fields", "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.", + + "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.", } function gradioApp(){ diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index eccfda877..680dd7025 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,9 @@ import gradio as gr from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state +import modules.shared as shared import modules.sd_samplers +import modules.sd_models import re @@ -41,6 +43,15 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index +def apply_checkpoint(p, x, xs): + applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title] + assert len(applicable) > 0, f'Checkpoint {x} for found' + + info = applicable[0] + + modules.sd_models.reload_model_weights(shared.sd_model, info) + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -74,6 +85,7 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Sampler", str, apply_sampler, format_value), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -215,4 +227,7 @@ class Script(scripts.Script): if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) + # restore checkpoint in case it was changed by axes + modules.sd_models.reload_model_weights(shared.sd_model) + return processed