diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index d37b428fc..a45e6d611 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -85,6 +85,23 @@ def confirm_checkpoints(p, xs): if modules.sd_models.get_closet_checkpoint_match(x) is None: raise RuntimeError(f"Unknown checkpoint: {x}") +def apply_refiner_checkpoint(p, x, xs): + if x == 'None': + p.override_settings['sd_refiner_checkpoint'] = 'None' + return + + info = modules.sd_models.get_closet_checkpoint_match(x) + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + p.override_settings['sd_refiner_checkpoint'] = info.name + +def confirm_refiner_checkpoints(p, xs): + for x in xs: + if x == 'None': + continue + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -241,6 +258,8 @@ axis_options = [ AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')), AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')), AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)), + AxisOption("Refiner checkpoint", str, apply_refiner_checkpoint, format_value=format_remove_path, confirm=confirm_refiner_checkpoints, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), + AxisOption("Refiner switch at", float, apply_override('sd_refiner_switch_at')) ]