add refiner to StableDiffusionProcessing class

write out correct model name in infotext, rather than the refiner model
This commit is contained in:
AUTOMATIC1111 2023-08-13 06:07:30 +03:00
parent b2080756fc
commit fa9370b741
3 changed files with 36 additions and 20 deletions

View File

@ -111,7 +111,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None] cached_uc = [None, None]
cached_c = [None, None] cached_c = [None, None]
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None):
if sampler_index is not None: if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@ -153,10 +153,14 @@ class StableDiffusionProcessing:
self.s_noise = s_noise if s_noise is not None else opts.s_noise self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.refiner_checkpoint = refiner_checkpoint
self.refiner_switch_at = refiner_switch_at
self.is_using_inpainting_conditioning = False self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False self.disable_extra_networks = False
self.token_merging_ratio = 0 self.token_merging_ratio = 0
self.token_merging_ratio_hr = 0 self.token_merging_ratio_hr = 0
self.refiner_checkpoint_info = None
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
@ -191,6 +195,11 @@ class StableDiffusionProcessing:
self.user = None self.user = None
self.sd_model_name = None
self.sd_model_hash = None
self.sd_vae_name = None
self.sd_vae_hash = None
@property @property
def sd_model(self): def sd_model(self):
return shared.sd_model return shared.sd_model
@ -408,7 +417,10 @@ class Processed:
self.batch_size = p.batch_size self.batch_size = p.batch_size
self.restore_faces = p.restore_faces self.restore_faces = p.restore_faces
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
self.sd_model_hash = shared.sd_model.sd_model_hash self.sd_model_name = p.sd_model_name
self.sd_model_hash = p.sd_model_hash
self.sd_vae_name = p.sd_vae_name
self.sd_vae_hash = p.sd_vae_hash
self.seed_resize_from_w = p.seed_resize_from_w self.seed_resize_from_w = p.seed_resize_from_w
self.seed_resize_from_h = p.seed_resize_from_h self.seed_resize_from_h = p.seed_resize_from_h
self.denoising_strength = getattr(p, 'denoising_strength', None) self.denoising_strength = getattr(p, 'denoising_strength', None)
@ -459,7 +471,10 @@ class Processed:
"batch_size": self.batch_size, "batch_size": self.batch_size,
"restore_faces": self.restore_faces, "restore_faces": self.restore_faces,
"face_restoration_model": self.face_restoration_model, "face_restoration_model": self.face_restoration_model,
"sd_model_name": self.sd_model_name,
"sd_model_hash": self.sd_model_hash, "sd_model_hash": self.sd_model_hash,
"sd_vae_name": self.sd_vae_name,
"sd_vae_hash": self.sd_vae_hash,
"seed_resize_from_w": self.seed_resize_from_w, "seed_resize_from_w": self.seed_resize_from_w,
"seed_resize_from_h": self.seed_resize_from_h, "seed_resize_from_h": self.seed_resize_from_h,
"denoising_strength": self.denoising_strength, "denoising_strength": self.denoising_strength,
@ -578,10 +593,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": opts.face_restoration_model if p.restore_faces else None, "Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra), "Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None, "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
"VAE": p.loaded_vae_name if opts.add_model_name_to_info else None, "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@ -670,8 +685,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None: if p.tiling is None:
p.tiling = opts.tiling p.tiling = opts.tiling
p.loaded_vae_name = sd_vae.get_loaded_vae_name() if p.refiner_checkpoint not in (None, "", "None"):
p.loaded_vae_hash = sd_vae.get_loaded_vae_hash() p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
p.sd_model_hash = shared.sd_model.sd_model_hash
p.sd_vae_name = sd_vae.get_loaded_vae_name()
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments() modules.sd_hijack.model_hijack.clear_comments()

View File

@ -41,15 +41,9 @@ class ScriptRefiner(scripts.Script):
def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner # the actual implementation is in sd_samplers_common.py, apply_refiner
p.refiner_checkpoint_info = None
p.refiner_switch_at = None
if not enable_refiner or refiner_checkpoint in (None, "", "None"): if not enable_refiner or refiner_checkpoint in (None, "", "None"):
return p.refiner_checkpoint_info = None
p.refiner_switch_at = None
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint) else:
if refiner_checkpoint_info is None: p.refiner_checkpoint = refiner_checkpoint
raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}') p.refiner_switch_at = refiner_switch_at
p.refiner_checkpoint_info = refiner_checkpoint_info
p.refiner_switch_at = refiner_switch_at

View File

@ -145,7 +145,7 @@ def apply_refiner(cfg_denoiser):
refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
if refiner_switch_at is not None and completed_ratio <= refiner_switch_at: if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False return False
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: