Merge pull request #15836 from AUTOMATIC1111/xyz-override-rework

XYZ override rework
This commit is contained in:
AUTOMATIC1111 2024-06-08 10:40:14 +03:00 committed by GitHub
commit cbac72d8ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -95,17 +95,6 @@ def confirm_checkpoints_or_none(p, xs):
raise RuntimeError(f"Unknown checkpoint: {x}") raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
def apply_upscale_latent_space(p, x, xs):
if x.lower().strip() != '0':
opts.data["use_scale_latent_for_hires_fix"] = True
else:
opts.data["use_scale_latent_for_hires_fix"] = False
def apply_size(p, x: str, xs) -> None: def apply_size(p, x: str, xs) -> None:
try: try:
width, _, height = x.partition('x') width, _, height = x.partition('x')
@ -118,21 +107,16 @@ def apply_size(p, x: str, xs) -> None:
def find_vae(name: str): def find_vae(name: str):
if name.lower() in ['auto', 'automatic']: match name := name.lower().strip():
return modules.sd_vae.unspecified case 'auto', 'automatic':
if name.lower() == 'none': return 'Automatic'
return None case 'none':
else: return 'None'
choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
if len(choices) == 0:
print(f"No VAE found for {name}; using automatic")
return modules.sd_vae.unspecified
else:
return modules.sd_vae.vae_dict[choices[0]]
def apply_vae(p, x, xs): def apply_vae(p, x, xs):
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) p.override_settings['sd_vae'] = find_vae(x)
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@ -140,7 +124,7 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
def apply_uni_pc_order(p, x, xs): def apply_uni_pc_order(p, x, xs):
opts.data["uni_pc_order"] = min(x, p.steps - 1) p.override_settings['uni_pc_order'] = min(x, p.steps - 1)
def apply_face_restore(p, opt, x): def apply_face_restore(p, opt, x):
@ -264,13 +248,13 @@ axis_options = [
AxisOption("Schedule max sigma", float, apply_override("sigma_max")), AxisOption("Schedule max sigma", float, apply_override("sigma_max")),
AxisOption("Schedule rho", float, apply_override("rho")), AxisOption("Schedule rho", float, apply_override("rho")),
AxisOption("Eta", float, apply_field("eta")), AxisOption("Eta", float, apply_field("eta")),
AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Clip skip", int, apply_override('CLIP_stop_at_last_layers')),
AxisOption("Denoising", float, apply_field("denoising_strength")), AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")), AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")),
AxisOption("Extra noise", float, apply_override("img2img_extra_noise")), AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value), AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@ -399,18 +383,12 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
class SharedSettingsStackHelper(object): class SharedSettingsStackHelper(object):
def __enter__(self): def __enter__(self):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers pass
self.vae = opts.sd_vae
self.uni_pc_order = opts.uni_pc_order
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb):
opts.data["sd_vae"] = self.vae
opts.data["uni_pc_order"] = self.uni_pc_order
modules.sd_models.reload_model_weights() modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights() modules.sd_vae.reload_vae_weights()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")