diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9d..4e2865587 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" diff --git a/modules/processing.py b/modules/processing.py index 8086a2b02..ae58b108a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +class DecodedSamples(list): + already_decoded = True + + def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): - samples = [] + samples = DecodedSamples() for i in range(batch.shape[0]): sample = decode_first_stage(model, batch[i:i + 1])[0] @@ -788,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + if getattr(samples_ddim, 'already_decoded', False): + x_samples_ddim = samples_ddim + else: + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -930,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): cached_hr_uc = [None, None] cached_hr_c = [None, None] - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength @@ -941,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y self.hr_upscale_to_x = hr_resize_x self.hr_upscale_to_y = hr_resize_y + self.hr_checkpoint_name = hr_checkpoint_name + self.hr_checkpoint_info = None self.hr_sampler_name = hr_sampler_name self.hr_prompt = hr_prompt self.hr_negative_prompt = hr_negative_prompt self.all_hr_prompts = None self.all_hr_negative_prompts = None + self.latent_scale_mode = None if firstphase_width != 0 or firstphase_height != 0: self.hr_upscale_to_x = self.width @@ -968,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if self.hr_checkpoint_name: + self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) + + if self.hr_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}') + + self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title + if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name @@ -977,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") + if self.enable_hr and self.latent_scale_mode is None: + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): self.hr_resize_x = self.width self.hr_resize_y = self.height @@ -1015,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f - # special case: the user has chosen to do nothing - if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height: - self.enable_hr = False - self.denoising_strength = None - self.extra_generation_params.pop("Hires upscale", None) - self.extra_generation_params.pop("Hires resize", None) - return - if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter @@ -1040,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") - if self.enable_hr and latent_scale_mode is None: - if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): - raise Exception(f"could not find upscaler named {self.hr_upscaler}") - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x if not self.enable_hr: return samples + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None + + current = shared.sd_model.sd_checkpoint_info + try: + if self.hr_checkpoint_info is not None: + self.sampler = None + sd_models.reload_model_weights(info=self.hr_checkpoint_info) + devices.torch_gc() + + return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) + finally: + self.sampler = None + sd_models.reload_model_weights(info=current) + devices.torch_gc() + + def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): self.is_hr_pass = True target_width = self.hr_upscale_to_x @@ -1068,11 +1099,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix") - if latent_scale_mode is not None: + img2img_sampler_name = self.hr_sampler_name or self.sampler_name + + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' + + self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) + + if self.latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -1081,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: image_conditioning = self.txt2img_image_conditioning(samples) else: - decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) batch_images = [] @@ -1100,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = torch.from_numpy(np.array(batch_images)) decoded_samples = decoded_samples.to(shared.device) decoded_samples = 2. * decoded_samples - 1. + decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae) samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) @@ -1107,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = self.hr_sampler_name or self.sampler_name - - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory - x = None devices.torch_gc() if not self.disable_extra_networks: @@ -1138,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + self.is_hr_pass = False - return samples + return decoded_samples def close(self): super().close() @@ -1179,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_c is not None: return - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) + hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() @@ -1188,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = None self.hr_c = None - if self.enable_hr: + if self.enable_hr and self.hr_checkpoint_info is None: if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() diff --git a/modules/sd_models.py b/modules/sd_models.py index ba15b4518..1608b37f5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -67,6 +67,7 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) @@ -87,6 +88,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -107,14 +109,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -137,11 +133,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -151,6 +150,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None @@ -585,7 +589,10 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + sd_model.to(device="meta") + + devices.torch_gc() load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model diff --git a/modules/shared.py b/modules/shared.py index 55199cb93..86ffb48fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -220,12 +220,19 @@ class State: return import modules.sd_samplers - if opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) - else: - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - self.current_image_sampling_step = self.sampling_step + try: + if opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during genration, VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() def assign_current_image(self, image): self.current_image = image @@ -521,7 +528,7 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), })) diff --git a/modules/txt2img.py b/modules/txt2img.py index 29d94e8cb..935ed4181 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, diff --git a/modules/ui.py b/modules/ui.py index 843a75ef8..c49538b67 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -456,6 +456,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: @@ -533,6 +537,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + hr_checkpoint_name, hr_sampler_index, hr_prompt, hr_negative_prompt, @@ -599,8 +604,9 @@ def create_ui(): (hr_second_pass_steps, "Hires steps"), (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), + (hr_checkpoint_name, "Hires checkpoint"), (hr_sampler_index, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),