mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
add refiner to StableDiffusionProcessing class
write out correct model name in infotext, rather than the refiner model
This commit is contained in:
parent
b2080756fc
commit
fa9370b741
@ -111,7 +111,7 @@ class StableDiffusionProcessing:
|
|||||||
cached_uc = [None, None]
|
cached_uc = [None, None]
|
||||||
cached_c = [None, None]
|
cached_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None):
|
||||||
if sampler_index is not None:
|
if sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
@ -153,10 +153,14 @@ class StableDiffusionProcessing:
|
|||||||
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
||||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
||||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||||
|
self.refiner_checkpoint = refiner_checkpoint
|
||||||
|
self.refiner_switch_at = refiner_switch_at
|
||||||
|
|
||||||
self.is_using_inpainting_conditioning = False
|
self.is_using_inpainting_conditioning = False
|
||||||
self.disable_extra_networks = False
|
self.disable_extra_networks = False
|
||||||
self.token_merging_ratio = 0
|
self.token_merging_ratio = 0
|
||||||
self.token_merging_ratio_hr = 0
|
self.token_merging_ratio_hr = 0
|
||||||
|
self.refiner_checkpoint_info = None
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if not seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
@ -191,6 +195,11 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
self.user = None
|
self.user = None
|
||||||
|
|
||||||
|
self.sd_model_name = None
|
||||||
|
self.sd_model_hash = None
|
||||||
|
self.sd_vae_name = None
|
||||||
|
self.sd_vae_hash = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
@ -408,7 +417,10 @@ class Processed:
|
|||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
self.restore_faces = p.restore_faces
|
self.restore_faces = p.restore_faces
|
||||||
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
||||||
self.sd_model_hash = shared.sd_model.sd_model_hash
|
self.sd_model_name = p.sd_model_name
|
||||||
|
self.sd_model_hash = p.sd_model_hash
|
||||||
|
self.sd_vae_name = p.sd_vae_name
|
||||||
|
self.sd_vae_hash = p.sd_vae_hash
|
||||||
self.seed_resize_from_w = p.seed_resize_from_w
|
self.seed_resize_from_w = p.seed_resize_from_w
|
||||||
self.seed_resize_from_h = p.seed_resize_from_h
|
self.seed_resize_from_h = p.seed_resize_from_h
|
||||||
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
||||||
@ -459,7 +471,10 @@ class Processed:
|
|||||||
"batch_size": self.batch_size,
|
"batch_size": self.batch_size,
|
||||||
"restore_faces": self.restore_faces,
|
"restore_faces": self.restore_faces,
|
||||||
"face_restoration_model": self.face_restoration_model,
|
"face_restoration_model": self.face_restoration_model,
|
||||||
|
"sd_model_name": self.sd_model_name,
|
||||||
"sd_model_hash": self.sd_model_hash,
|
"sd_model_hash": self.sd_model_hash,
|
||||||
|
"sd_vae_name": self.sd_vae_name,
|
||||||
|
"sd_vae_hash": self.sd_vae_hash,
|
||||||
"seed_resize_from_w": self.seed_resize_from_w,
|
"seed_resize_from_w": self.seed_resize_from_w,
|
||||||
"seed_resize_from_h": self.seed_resize_from_h,
|
"seed_resize_from_h": self.seed_resize_from_h,
|
||||||
"denoising_strength": self.denoising_strength,
|
"denoising_strength": self.denoising_strength,
|
||||||
@ -578,10 +593,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||||
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
|
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||||
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||||
"VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
|
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
|
||||||
"VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
|
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
@ -670,8 +685,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.tiling is None:
|
if p.tiling is None:
|
||||||
p.tiling = opts.tiling
|
p.tiling = opts.tiling
|
||||||
|
|
||||||
p.loaded_vae_name = sd_vae.get_loaded_vae_name()
|
if p.refiner_checkpoint not in (None, "", "None"):
|
||||||
p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
|
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
|
||||||
|
if p.refiner_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
|
||||||
|
|
||||||
|
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
|
||||||
|
p.sd_model_hash = shared.sd_model.sd_model_hash
|
||||||
|
p.sd_vae_name = sd_vae.get_loaded_vae_name()
|
||||||
|
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
@ -41,15 +41,9 @@ class ScriptRefiner(scripts.Script):
|
|||||||
def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||||
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||||
|
|
||||||
p.refiner_checkpoint_info = None
|
|
||||||
p.refiner_switch_at = None
|
|
||||||
|
|
||||||
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||||
return
|
p.refiner_checkpoint_info = None
|
||||||
|
p.refiner_switch_at = None
|
||||||
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
|
else:
|
||||||
if refiner_checkpoint_info is None:
|
p.refiner_checkpoint = refiner_checkpoint
|
||||||
raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
|
p.refiner_switch_at = refiner_switch_at
|
||||||
|
|
||||||
p.refiner_checkpoint_info = refiner_checkpoint_info
|
|
||||||
p.refiner_switch_at = refiner_switch_at
|
|
||||||
|
@ -145,7 +145,7 @@ def apply_refiner(cfg_denoiser):
|
|||||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
|
|
||||||
if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
|
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||||
|
Loading…
Reference in New Issue
Block a user