diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5582a6e5d..44d02349a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -2,16 +2,14 @@ function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); + var search = searchDiv.querySelector('textarea'); var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - search.classList.add('search'); - sort.classList.add('sort'); - sortOrder.classList.add('sortorder'); sort.dataset.sortkey = 'sortDefault'; - tabs.appendChild(search); + tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); @@ -179,7 +177,7 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksSearchButton(tabs_id, event) { - var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); + var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0713dbf0a..593abfef3 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" diff --git a/modules/images.py b/modules/images.py index 38aa933d6..ba3c43a45 100644 --- a/modules/images.py +++ b/modules/images.py @@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' 
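# Editor's note: illustrative sketch only, not part of the patch. Any character
# in invalid_filename_chars is dropped during filename sanitization, so adding
# \r and \t keeps pasted multi-line prompts from producing broken filenames.
# strip_invalid_chars() below is a hypothetical helper, not this module's code:
def strip_invalid_chars(text: str) -> str:
    return ''.join(c for c in text if c not in invalid_filename_chars)

assert strip_invalid_chars('a\r\nb\tc') == 'abc'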
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') diff --git a/modules/lowvram.py b/modules/lowvram.py index 3f8306643..96f52b7b4 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,9 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, use_medvram): + if getattr(sd_model, 'lowvram', False): + return + sd_model.lowvram = True parents = {} diff --git a/modules/processing.py b/modules/processing.py index a9ee7507e..b9900ded2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -539,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +class DecodedSamples(list): + already_decoded = True + + def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): - samples = [] + samples = DecodedSamples() for i in range(batch.shape[0]): sample = decode_first_stage(model, batch[i:i + 1])[0] @@ -788,8 +792,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) - p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + if getattr(samples_ddim, 'already_decoded', False): + x_samples_ddim = samples_ddim + else: + p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -931,7 +939,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): cached_hr_uc = [None, None] cached_hr_c = [None, None] - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength @@ -942,11 +950,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y self.hr_upscale_to_x = hr_resize_x self.hr_upscale_to_y = hr_resize_y + self.hr_checkpoint_name = hr_checkpoint_name + self.hr_checkpoint_info = None self.hr_sampler_name = hr_sampler_name self.hr_prompt = hr_prompt self.hr_negative_prompt = hr_negative_prompt self.all_hr_prompts = None self.all_hr_negative_prompts = None + self.latent_scale_mode = None if firstphase_width != 0 or firstphase_height != 0: self.hr_upscale_to_x = self.width @@ -969,6 +980,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if 
self.hr_checkpoint_name: + self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) + + if self.hr_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}') + + self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title + if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name @@ -978,6 +997,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") + if self.enable_hr and self.latent_scale_mode is None: + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): self.hr_resize_x = self.width self.hr_resize_y = self.height @@ -1016,14 +1040,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f - # special case: the user has chosen to do nothing - if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height: - self.enable_hr = False - self.denoising_strength = None - self.extra_generation_params.pop("Hires upscale", None) - self.extra_generation_params.pop("Hires resize", None) - return - if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter @@ -1041,17 +1057,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") - if self.enable_hr and latent_scale_mode is None: - if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): - raise Exception(f"could not find upscaler named {self.hr_upscaler}") - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x if not self.enable_hr: return samples + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None + + current = shared.sd_model.sd_checkpoint_info + try: + if self.hr_checkpoint_info is not None: + self.sampler = None + sd_models.reload_model_weights(info=self.hr_checkpoint_info) + devices.torch_gc() + + return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) + finally: + 
self.sampler = None + sd_models.reload_model_weights(info=current) + devices.torch_gc() + + def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): self.is_hr_pass = True target_width = self.hr_upscale_to_x @@ -1069,11 +1100,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix") - if latent_scale_mode is not None: + img2img_sampler_name = self.hr_sampler_name or self.sampler_name + + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' + + self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) + + if self.latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -1082,7 +1120,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: image_conditioning = self.txt2img_image_conditioning(samples) else: - decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) batch_images = [] @@ -1108,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = self.hr_sampler_name or self.sampler_name - - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory - x = None devices.torch_gc() if not self.disable_extra_networks: @@ -1139,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + self.is_hr_pass = False - return samples + return decoded_samples def close(self): super().close() @@ -1180,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_c is not None: return - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps 
* self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) + hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() @@ -1189,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = None self.hr_c = None - if self.enable_hr: + if self.enable_hr and self.hr_checkpoint_info is None: if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 609fd56c4..9ad981998 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 -ldm.modules.attention.print = lambda *args: None -ldm.modules.diffusionmodules.model.print = lambda *args: None +ldm.modules.attention.print = shared.ldm_print +ldm.modules.diffusionmodules.model.print = shared.ldm_print +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print + +sd_hijack_inpainting.do_inpainting_hijack() optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index c1977b194..2d44b8566 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F def do_inpainting_hijack(): - # p_sample_plms is needed because PLMS can't work with dicts as conditionings - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_models.py b/modules/sd_models.py index ba15b4518..f60516046 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -67,6 +66,7 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if 
self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) @@ -87,6 +87,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -107,14 +108,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -137,11 +132,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -151,6 +149,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None @@ -430,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -444,6 +448,7 @@ class SdModelData: try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -455,11 +460,24 @@ class SdModelData: def set_sd_model(self, v): self.sd_model = v + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + self.loaded_sd_models.insert(0, v) + model_data = SdModelData() def get_empty_cond(sd_model): + from modules import extra_networks, processing + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -467,19 +485,43 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) +def send_model_to_cpu(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def send_model_to_device(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) + else: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") + devices.torch_gc() 
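# Editor's note: illustrative sketch only, not part of the patch. Moving a
# module to the "meta" device replaces its parameter storage with shape-only
# placeholders, which is why send_model_to_trash() frees memory without first
# copying weights anywhere. A standalone demonstration in plain torch:
import torch

_layer = torch.nn.Linear(4, 4)
_layer.to(device="meta")
assert _layer.weight.is_meta  # tensor keeps its shape but holds no real storage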
+ + def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import lowvram, sd_hijack + from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + timer = Timer() + if model_data.sd_model: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + send_model_to_trash(model_data.sd_model) model_data.sd_model = None - gc.collect() devices.torch_gc() - do_inpainting_hijack() - - timer = Timer() + timer.record("unload existing model") if already_loaded_state_dict is not None: state_dict = already_loaded_state_dict @@ -519,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + timer.record("load weights from state dict") - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - sd_model.to(shared.device) - + send_model_to_device(sd_model) timer.record("move model to device") sd_hijack.model_hijack.hijack(sd_model) @@ -532,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("hijack") sd_model.eval() - model_data.sd_model = sd_model + model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -553,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks whether the checkpoint described by checkpoint_info is already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit). 
+ """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename: + already_loaded = loaded_model + continue + + if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: + print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") + model_data.loaded_sd_models.pop(i) + send_model_to_trash(loaded_model) + timer.record("send model to trash") + + if shared.opts.sd_checkpoints_keep_in_cpu: + send_model_to_cpu(sd_model) + timer.record("send model to cpu") + + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + + model_data.set_sd_model(already_loaded) + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + return model_data.sd_model + elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: + print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") + + model_data.sd_model = None + load_model(checkpoint_info) + return model_data.sd_model + elif len(model_data.loaded_sd_models) > 0: + sd_model = model_data.loaded_sd_models.pop() + model_data.sd_model = sd_model + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model + else: + return None + + def reload_model_weights(sd_model=None, info=None): - from modules import lowvram, devices, sd_hijack + from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() + timer = Timer() + if not sd_model: sd_model = model_data.sd_model @@ -565,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None): else: current_checkpoint_info = sd_model.sd_checkpoint_info if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return + return sd_model + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + + if sd_model is not None: sd_unet.apply_unet("None") - - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) - + send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) - timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) @@ -585,7 +673,9 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + send_model_to_trash(sd_model) + load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -608,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + return sd_model diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index bc2195087..011233216 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -98,10 +98,10 @@ def extend_sdxl(model): model.conditioner.wrapped = torch.nn.Module() -sgm.modules.attention.print = lambda *args: None 
-sgm.modules.diffusionmodules.model.print = lambda *args: None -sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None -sgm.modules.encoders.modules.print = lambda *args: None +sgm.modules.attention.print = shared.ldm_print +sgm.modules.diffusionmodules.model.print = shared.ldm_print +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True diff --git a/modules/shared.py b/modules/shared.py index df454d4a9..516ad7e80 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -220,12 +220,19 @@ class State: return import modules.sd_samplers - if opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) - else: - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - self.current_image_sampling_step = self.sampling_step + try: + if opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during generation, the VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() def assign_current_image(self, image): self.current_image = image @@ -393,6 +400,7 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), })) options_templates.update(options_section(('training', "Training"), { @@ -412,17 +420,14 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), 
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"), - "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), - "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), - "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), + "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(), "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), @@ -441,6 +446,22 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) + +options_templates.update(options_section(('img2img', "img2img"), { + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), + "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), + "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}), + "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(), + "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(), + "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint 
mask").needs_restart(), + "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(), + "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), + "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), +})) + + options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), @@ -460,7 +481,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), })) -options_templates.update(options_section(('interrogate', "Interrogate Options"), { +options_templates.update(options_section(('interrogate', "Interrogate"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"), "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"), "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), @@ -493,10 +514,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { options_templates.update(options_section(('ui', "User interface"), { "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(), "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(), - "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(), "return_grid": OptionInfo(True, "Show grid in results for web"), - "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), - "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), @@ -515,11 +533,12 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear 
first").needs_restart(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), })) + options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), @@ -892,3 +911,10 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) + + +def ldm_print(*args, **kwargs): + if opts.hide_ldm_prints: + return + + print(*args, **kwargs) diff --git a/modules/txt2img.py b/modules/txt2img.py index 29d94e8cb..935ed4181 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, diff --git a/modules/ui.py b/modules/ui.py index 6cf3dff88..1af6b4c88 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,7 +261,6 @@ class Toprow: with gr.Row(): self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - self.button_interrogate = None self.button_deepbooru = None if is_img2img: @@ -290,7 +289,6 @@ class Toprow: with gr.Row(elem_id=f"{id_part}_tools"): self.paste = ToolButton(value=paste_symbol, elem_id="paste") 
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) @@ -404,11 +402,10 @@ def create_ui(): dummy_component = gr.Label(visible=False) - with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: - from modules import ui_extra_networks - extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img') + extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs") + extra_tabs.__enter__() - with gr.Row(equal_height=False): + with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): scripts.scripts_txt2img.prepare_ui() @@ -456,6 +453,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: @@ -533,6 +534,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + hr_checkpoint_name, hr_sampler_index, hr_prompt, hr_negative_prompt, @@ -599,8 +601,9 @@ def create_ui(): (hr_second_pass_steps, "Hires steps"), (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), + (hr_checkpoint_name, "Hires checkpoint"), (hr_sampler_index, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), @@ -625,7 +628,11 @@ def create_ui(): toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + from modules import ui_extra_networks + extra_networks_ui 
= ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + + extra_tabs.__exit__() scripts.scripts_current = scripts.scripts_img2img scripts.scripts_img2img.initialize_scripts(is_img2img=True) @@ -633,11 +640,10 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: toprow = Toprow(is_img2img=True) - with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: - from modules import ui_extra_networks - extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img') + extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") + extra_tabs.__enter__() - with FormRow(equal_height=False): + with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] copy_image_destinations = {} @@ -663,15 +669,15 @@ def create_ui(): add_copy_image_controls('img2img', init_img) with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height) + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) add_copy_image_controls('sketch', sketch) with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) add_copy_image_controls('inpaint', init_img_with_mask) with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) inpaint_color_sketch_orig = gr.State(None) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) @@ -959,8 +965,6 @@ def create_ui(): toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) - ui_extra_networks.setup_ui(extra_networks_ui_img2img, 
img2img_gallery) - img2img_paste_fields = [ (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), @@ -989,6 +993,12 @@ def create_ui(): paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, )) + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + extra_tabs.__exit__() + scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c6390db79..3a73c89e8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -155,7 +155,7 @@ class ExtraNetworksPage: subdirs = {"": 1, **subdirs} subdirs_html = "".join([f""" - """ for subdir in subdirs]) @@ -347,7 +347,7 @@ def pages_in_preferred_order(pages): return sorted(pages, key=lambda x: tab_scores[x.name]) -def create_ui(container, button, tabname): +def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -355,48 +355,41 @@ def create_ui(container, button, tabname): ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname - with gr.Tabs(elem_id=tabname+"_extra_tabs"): - for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page): - elem_id = f"{tabname}_{page.id_page}_cards_html" - page_elem = gr.HTML('Loading...', elem_id=elem_id) - ui.pages.append(page_elem) + related_tabs = [] - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + for page in ui.stored_extra_pages: + with gr.Tab(page.title, id=page.id_page) as tab: + elem_id = f"{tabname}_{page.id_page}_cards_html" + page_elem = gr.HTML('Loading...', elem_id=elem_id) + ui.pages.append(page_elem) - editor = page.create_user_metadata_editor(ui, tabname) - editor.create_ui() - ui.user_metadata_editors.append(editor) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) - gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) - ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) + + related_tabs.append(tab) + + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) ui.button_save_preview = 
gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - def toggle_visibility(is_visible): - is_visible = not is_visible + for tab in unrelated_tabs: + tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) - return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) - - def fill_tabs(is_empty): - """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time.""" + for tab in related_tabs: + tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) + def pages_html(): if not ui.pages_contents: - refresh() + return refresh() - if is_empty: - return True, *ui.pages_contents - - return True, *[gr.update() for _ in ui.pages_contents] - - state_visible = gr.State(value=False) - button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False) - - state_empty = gr.State(value=True) - button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False) + return ui.pages_contents def refresh(): for pg in ui.stored_extra_pages: @@ -406,6 +399,7 @@ def create_ui(container, button, tabname): return ui.pages_contents + interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/style.css b/style.css index 14e6c0114..52919f719 100644 --- a/style.css +++ b/style.css @@ -779,9 +779,14 @@ footer { /* extra networks UI */ .extra-network-cards{ - height: 725px; - overflow: scroll; + height: calc(100vh - 24rem); + overflow: clip scroll; resize: vertical; + min-height: 52rem; +} + +.extra-networks > div.tab-nav{ + min-height: 3.4rem; } .extra-networks > div > [id *= '_extra_']{ @@ -797,7 +802,6 @@ footer { } .extra-networks .tab-nav .search, .extra-networks .tab-nav .sort{ - display: inline-block; margin: 0.3em; align-self: center; }
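# Editor's note: illustrative sketch only, not part of the patch. The new
# SdModelData.loaded_sd_models list plus reuse_model_from_already_loaded()
# amount to a most-recently-used cache capped by the sd_checkpoints_limit
# setting; reduced to plain Python with stand-in checkpoint names:
loaded = []  # most recently used checkpoint first

def set_active(name, limit=2):
    if name in loaded:
        loaded.remove(name)
    loaded.insert(0, name)  # the active checkpoint moves to the front
    while len(loaded) > limit:
        loaded.pop()  # least recently used checkpoints are evicted

set_active("sd15")
set_active("sdxl")
set_active("anything-v5")
assert loaded == ["anything-v5", "sdxl"]  # "sd15" was evicted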