From c4ee6d9b73300d906b8df4602157d646df2415ef Mon Sep 17 00:00:00 2001 From: Robert Barron Date: Sun, 30 Jul 2023 00:41:10 -0700 Subject: [PATCH 001/311] xyz_grid: allow varying the seed along an axis along with the axis's other changes --- scripts/xyz_grid.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1010845e5..4eb1b197d 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -416,6 +416,10 @@ class Script(scripts.Script): with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + with gr.Column(): + vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) + vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) + vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) @@ -475,9 +479,9 @@ class Script(scripts.Script): (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), ) - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] + return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -648,6 +652,16 @@ class Script(scripts.Script): y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) + xdim = len(xs) if vary_seeds_x else 1 + ydim = len(ys) if vary_seeds_y else 1 + + if vary_seeds_x: + pc.seed += ix + if vary_seeds_y: + pc.seed += iy * xdim + if vary_seeds_z: + pc.seed += iz * xdim * ydim + res = process_images(pc) # Sets subgrid infotexts From 8aa13d5dce2789a7d0bd802e6d62453b3c380496 Mon Sep 17 00:00:00 2001 From: Anthony Fu Date: Mon, 16 Oct 2023 14:12:18 +0800 Subject: [PATCH 002/311] Interrupt after current generation --- modules/call_queue.py | 1 + modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 1 + modules/shared_state.py | 11 +++++++++-- scripts/loopback.py | 6 +++--- scripts/xyz_grid.py | 2 +- 7 files changed, 17 insertions(+), 8 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index ddf0d5738..01c6d17f6 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): shared.state.skipped = False shared.state.interrupted = False + shared.state.interrupted_next = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 52cb577a6..31f8c2aaf 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -49,7 +49,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break try: diff --git a/modules/processing.py b/modules/processing.py index 40598f5cf..e7eecd66d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -819,7 +819,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 32bf73532..4638ef068 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -113,6 +113,7 @@ options_templates.update(options_section(('system', "System"), { "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), + "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index a68789cc8..c72c3f63c 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,6 +12,7 @@ log = logging.getLogger(__name__) class State: skipped = False interrupted = False + interrupted_next = False job = "" job_no = 0 job_count = 0 @@ -76,8 +77,12 @@ class State: log.info("Received skip request") def interrupt(self): - self.interrupted = True - log.info("Received interrupt request") + if shared.opts.interrupt_after_current and self.job_count > 1: + self.interrupted_next = True + log.info("Received interrupt request, interrupt after current job") + else: + self.interrupted = True + log.info("Received interrupt request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -91,6 +96,7 @@ class State: obj = { "skipped": self.skipped, "interrupted": self.interrupted, + "interrupted_next": self.interrupted_next, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -114,6 +120,7 @@ class State: self.id_live_preview = 0 self.skipped = False self.interrupted = False + self.interrupted_next = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/scripts/loopback.py b/scripts/loopback.py index 2d5feaf9b..ad921269a 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ class Script(scripts.Script): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted: + if state.interrupted or state.interrupted_next: break if initial_seed is None: @@ -122,8 +122,8 @@ class Script(scripts.Script): p.inpainting_fill = original_inpainting_fill - if state.interrupted: - break + if state.interrupted or state.interrupted_next: + break if len(history) > 1: grid = images.image_grid(history, rows=1) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc4..495008ada 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -688,7 +688,7 @@ class Script(scripts.Script): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: + if shared.state.interrupted or state.interrupted_next: return Processed(p, [], p.seed, "") pc = copy(p) From 3d15e58b0a30f2ef1e731f9e429f4d3cf1c259c5 Mon Sep 17 00:00:00 2001 From: Anthony Fu Date: Mon, 16 Oct 2023 15:00:17 +0800 Subject: [PATCH 003/311] feat: refactor --- modules/shared_state.py | 12 ++++++------ modules/ui.py | 8 +++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/modules/shared_state.py b/modules/shared_state.py index c72c3f63c..532fdcd8d 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -77,12 +77,12 @@ class State: log.info("Received skip request") def interrupt(self): - if shared.opts.interrupt_after_current and self.job_count > 1: - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") - else: - self.interrupted = True - log.info("Received interrupt request") + self.interrupted = True + log.info("Received interrupt request") + + def interrupt_next(self): + self.interrupted_next = True + log.info("Received interrupt request, interrupt after current job") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: diff --git a/modules/ui.py b/modules/ui.py index bcf391997..c30093d75 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -216,8 +216,14 @@ class Toprow: outputs=[], ) + def interrupt_fn(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.interrupt_next() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + fn=interrupt_fn, inputs=[], outputs=[], ) From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: [PATCH 004/311] Add fp8 for sd unet --- extensions-builtin/Lora/network.py | 2 +- extensions-builtin/Lora/network_full.py | 4 ++-- extensions-builtin/Lora/network_glora.py | 10 +++++----- extensions-builtin/Lora/network_hada.py | 12 ++++++------ extensions-builtin/Lora/network_ia3.py | 2 +- extensions-builtin/Lora/network_lokr.py | 18 +++++++++--------- extensions-builtin/Lora/network_lora.py | 6 +++--- extensions-builtin/Lora/network_norm.py | 4 ++-- extensions-builtin/Lora/networks.py | 6 +++--- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 11 files changed, 36 insertions(+), 32 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8de..a62e5eff9 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e96..f221c95f3 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d48707..efe5c6814 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695f..d95a0fd18 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc42497..96faeaf3e 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab3..fcdaeafd8 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c2..4cc402951 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = [up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce4501580..d25afcbb9 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4c..8ea4ea609 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a842..0f14c71e4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,3 +118,4 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b6cdea18..3b8ff8209 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,6 +391,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + if shared.cmd_opts.opt_unet_fp8_storage: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: [PATCH 005/311] Add sdxl only arg --- modules/cmd_args.py | 1 + modules/sd_models.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 0f14c71e4..20bfb2c44 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -119,3 +119,4 @@ parser.add_argument("--disable-all-extensions", action='store_true', help="preve parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) +parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff8209..08af128fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 01:49:05 +0800 Subject: [PATCH 006/311] Add CPU fp8 support Since norm layer need fp32, I only convert the linear operation layer(conv2d/linear) And TE have some pytorch function not support bf16 amp in CPU. I add a condition to indicate if the autocast is for unet. --- modules/devices.py | 6 +++++- modules/processing.py | 2 +- modules/sd_models.py | 20 ++++++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 1d4eb5635..0cd2b55d2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -71,6 +71,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -93,10 +94,13 @@ def cond_cast_float(input): nv_rng = None -def autocast(disable=False): +def autocast(disable=False, unet=False): if disable: return contextlib.nullcontext() + if unet and fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 40598f5cf..2df8a7ea7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index 08af128fc..c5fe57bfe 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -391,12 +391,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + + if shared.cmd_opts.opt_unet_fp8_storage: + enable_fp8 = True + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + enable_fp8 = True + + if enable_fp8: + devices.fp8 = True + if devices.device == devices.cpu: + for module in model.model.diffusion_model.modules(): + if isinstance(module, torch.nn.Conv2d): + module.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for cpu") + else: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 24 Oct 2023 02:11:27 +0800 Subject: [PATCH 007/311] Fix lint --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index c5fe57bfe..44d4038b2 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,7 +396,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True - + if enable_fp8: devices.fp8 = True if devices.device == devices.cpu: From 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:36:43 +0800 Subject: [PATCH 008/311] fp8 for TE --- modules/sd_models.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b2..693952943 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) timer.record("apply fp8 unet for cpu") else: + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") From 4830b251366436ee8499c003fe87e46ddb4a4581 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 11:53:37 +0800 Subject: [PATCH 009/311] Fix alphas_cumprod dtype --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 693952943..236604545 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -416,6 +416,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: [PATCH 010/311] Fix alphas cumprod --- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 236604545..7ed89a9c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True + else: + enable_fp8 = False if enable_fp8: devices.fp8 = True @@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 011233216..11259a369 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -93,7 +93,7 @@ def extend_sdxl(model): model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() - model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) model.conditioner.wrapped = torch.nn.Module() From dda067f64d3289cee3ffd65767126cb30ae73b13 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 19:53:22 +0800 Subject: [PATCH 011/311] ignore mps for fp8 --- modules/sd_models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 7ed89a9c9..ccb6afd2a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -392,7 +392,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if shared.cmd_opts.opt_unet_fp8_storage: + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.cmd_opts.opt_unet_fp8_storage: enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True From 0beb131c7ffae6f756a6339206da311232a36970 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 20:07:37 +0800 Subject: [PATCH 012/311] change torch version --- modules/launch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8cdbafa50..636da679e 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,8 +308,8 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') From d4d3134f6d2d232c7bcfa80900a362921e644976 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 28 Oct 2023 15:24:26 +0800 Subject: [PATCH 013/311] ManualCast for 10/16 series gpu --- modules/devices.py | 53 +++++++++++++++++++++++++++++++++++++++---- modules/processing.py | 2 +- modules/sd_models.py | 21 +++++++++-------- 3 files changed, 62 insertions(+), 14 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d2..c05f2b35e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward -def autocast(disable=False, unet=False): +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea7..40598f5cf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2a..31bcb913a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 From ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Sat, 28 Oct 2023 16:52:35 +0800 Subject: [PATCH 014/311] Add MPS manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index c05f2b35e..d7c905c28 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_autocast(): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From f2b83517aa49268219706f9718d734c6f242fd93 Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 15:40:13 +0000 Subject: [PATCH 015/311] Add new arguments to known command prompts --- modules/cmd_args.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 4e602a842..c6d4b6123 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -76,7 +76,9 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) -parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False) +parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None) +parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) @@ -90,7 +92,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -112,9 +114,8 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') +parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) -parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) From 844c23975f369f85c7179379ec419e1bd067de18 Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 15:40:58 +0000 Subject: [PATCH 016/311] Add assertions for checking additional settings freezing parameters --- modules/options.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/modules/options.py b/modules/options.py index ab40aff73..1ac32d950 100644 --- a/modules/options.py +++ b/modules/options.py @@ -85,18 +85,35 @@ class Options: if self.data is not None: if key in self.data or key in self.data_labels: + + # Check that settings aren't globally frozen assert not cmd_opts.freeze_settings, "changing settings is disabled" + # Get the info related to the setting being changed info = self.data_labels.get(key, None) if info.do_not_save: return + # Restrict component arguments comp_args = info.component_args if info else None if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: - raise RuntimeError(f"not possible to set {key} because it is restricted") + raise RuntimeError(f"not possible to set '{key}' because it is restricted") + # Check that this section isn't frozen + if cmd_opts.freeze_settings_in_sections is not None: + frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names + section_key = info.section[0] + section_name = info.section[1] + assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections" + + # Check that this section of the settings isn't frozen + if cmd_opts.freeze_specific_settings is not None: + frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys + assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings" + + # Check shorthand option which disables editing options in "saving-paths" if cmd_opts.hide_ui_dir_config and key in self.restricted_opts: - raise RuntimeError(f"not possible to set {key} because it is restricted") + raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config") self.data[key] = value return @@ -210,8 +227,6 @@ class Options: def add_option(self, key, info): self.data_labels[key] = info - if key not in self.data: - self.data[key] = info.default def reorder(self): """reorder settings so that all items related to section always go together""" From be31e7e71a08dc27543d31aa6e6532463ccbf20f Mon Sep 17 00:00:00 2001 From: Nick Harrison <42382362+nickpharrison@users.noreply.github.com> Date: Sun, 29 Oct 2023 16:05:01 +0000 Subject: [PATCH 017/311] Remove blank line whitespace --- modules/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/options.py b/modules/options.py index 1ac32d950..e270a42df 100644 --- a/modules/options.py +++ b/modules/options.py @@ -85,7 +85,7 @@ class Options: if self.data is not None: if key in self.data or key in self.data_labels: - + # Check that settings aren't globally frozen assert not cmd_opts.freeze_settings, "changing settings is disabled" From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: [PATCH 018/311] Use options instead of cmd_args --- modules/cmd_args.py | 2 -- modules/devices.py | 25 +++++++++------- modules/initialize_util.py | 1 + modules/sd_models.py | 61 ++++++++++++++++++++------------------ modules/shared_options.py | 1 + scripts/xyz_grid.py | 1 + 6 files changed, 49 insertions(+), 42 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 088d5deaa..a9fb9bfa3 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) -parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) -parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/devices.py b/modules/devices.py index d7c905c28..03e7bdb7c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d895..1b11ead61 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,7 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa7..eb4914347 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ class SkipWritingToConfig: SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/shared_options.py b/modules/shared_options.py index f1003f218..d27f35e96 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc4..b2250c04d 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ axis_options = [ AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] From 890181e1d456b613bf60f6e8378dc68b39011af9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:54:39 +0800 Subject: [PATCH 019/311] Update the xformers/torch versions --- modules/errors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 8c339464d..a3498c11f 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -93,8 +93,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): From f383af2729ec2d1969200218577ab19dd78f7d48 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:23 +0800 Subject: [PATCH 020/311] update xformers/torch versions --- modules/launch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 636da679e..c225bbc18 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -312,7 +312,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From 043d2edcf6a543f236f1f3cb70ac72e7b3b357b6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:56:31 +0800 Subject: [PATCH 021/311] Better naming --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 03e7bdb7c..c19a7f402 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs): @contextlib.contextmanager -def manual_autocast(): +def manual_cast(): for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward @@ -148,10 +148,10 @@ def autocast(disable=False): return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_autocast() + return manual_cast() if has_mps() and shared.cmd_opts.precision != "full": - return manual_autocast() + return manual_cast() if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() From b2e039d07bed76350120ff448964c907a3b5e4a3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 20 Nov 2023 14:05:32 +0800 Subject: [PATCH 022/311] Update webui-macos-env.sh --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c426..db7e8b1a0 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: [PATCH 023/311] Option for using fp16 weight when apply lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- modules/initialize_util.py | 1 + modules/sd_models.py | 14 +++++++++++--- modules/shared_options.py | 1 + 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfbd..d22ed8438 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 1b11ead61..7fb1d8d50 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,6 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index eb4914347..0a7777f1c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") diff --git a/modules/shared_options.py b/modules/shared_options.py index d27f35e96..eaa9f1352 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -201,6 +201,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From f5d719d1f1baa775d838aa75d9af1971bcc78e8f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 22 Nov 2023 01:45:56 +0800 Subject: [PATCH 024/311] Add forced reload for fp16 cache --- modules/initialize_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 7fb1d8d50..b6767138d 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,7 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) - shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") From 40ac134c553ac824d4a96666bba14d550300daa5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 25 Nov 2023 12:35:09 +0800 Subject: [PATCH 025/311] Fix pre-fp8 --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a7777f1c..90437c87a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - if not check_fp8(model) and devices.fp8: + if devices.fp8: # prevent model to load state dict in fp8 model.half() From 29f04149b60bcf6e8e2b41a161d6cc7e8981710f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 26 Nov 2023 12:07:33 +0300 Subject: [PATCH 026/311] update torch to 2.1.0 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 6 +++--- webui-macos-env.sh | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index eb234a838..c534a5d66 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.0.0" - expected_xformers_version = "0.0.20" + expected_torch_version = "2.1.0" + expected_xformers_version = "0.0.22.post7" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 264ec9ca6..1f2b6c5e8 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -308,11 +308,11 @@ def requirements_met(requirements_file): def prepare_environment(): - torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") + torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 24bc5c426..db7e8b1a0 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" +export TORCH_COMMAND="pip install torch==2.1.0 torchvision==0.16.0" export PYTORCH_ENABLE_MPS_FALLBACK=1 #################################################################### From dec791d35ddcd02ca33563d3d0355e05e45de8ad Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 15:05:01 -0700 Subject: [PATCH 027/311] Removed code which forces the inpainting mask to be 0 or 1. Now fractional values (e.g. 0.5) are accepted. --- modules/processing.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index e124e7f0d..317458f58 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -83,7 +83,7 @@ def apply_overlay(image, paste_loc, index, overlays): def create_binary_mask(image): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -319,9 +319,6 @@ class StableDiffusionProcessing: conditioning_mask = np.array(image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) - - # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 - conditioning_mask = torch.round(conditioning_mask) else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -1504,7 +1501,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) From bbba133f054706c3668b7d03b0e6d0afc15705db Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 15:09:43 -0700 Subject: [PATCH 028/311] Removed conflicting step that replaces the softly inpainted latents with a naive blend with the original latents. --- modules/processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 317458f58..ae894f1a7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1523,9 +1523,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None: - samples = samples * self.nmask + self.init_latent * self.mask - del x devices.torch_gc() From e715e46b6aa7f2e5e147cfa1fa2f49b1d926a074 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:10:22 -0700 Subject: [PATCH 029/311] Implements "scheduling" for blending of the original latents and a latent blending formula that preserves details in blend transition areas. --- modules/sd_samplers_cfg_denoiser.py | 61 ++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index b8101d38d..c4d6fda65 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -43,6 +43,9 @@ class CFGDenoiser(torch.nn.Module): self.model_wrap = None self.mask = None self.nmask = None + self.mask_blend_power = 1 + self.mask_blend_scale = 1 + self.mask_blend_offset = 0 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -56,6 +59,9 @@ class CFGDenoiser(torch.nn.Module): self.sampler = sampler self.model_wrap = None self.p = None + + # NOTE: masking before denoising can cause the original latents to be oversmoothed + # as the original latents do not have noise self.mask_before_denoising = False @property @@ -89,6 +95,55 @@ class CFGDenoiser(torch.nn.Module): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + def latent_blend(a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + # Record the original latent vector magnitudes. + # We bring them to a power so that larger magnitudes are favored over smaller ones. + # 64-bit operations are used here to allow large exponents. + detail_preservation = 32 + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation + + one_minus_t = 1 - t + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation) + + # Linearly interpolate the image vectors. + image_interp = a * one_minus_t + b * t + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) + + return image_interp + + def get_modified_nmask(nmask, _sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset) + if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -105,8 +160,9 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = self.init_latent * self.mask + self.nmask * x + x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -207,8 +263,9 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) From a6e584645305c0a91a3d46f73546e191b249210f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:13:42 -0700 Subject: [PATCH 030/311] Nerfs the aggressive post-processing step of overlaying the original image. --- modules/processing.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index ae894f1a7..12e08e876 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1412,7 +1412,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1423,8 +1428,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + np_mask = np.array(image_mask).astype(np.float32) + np_mask /= 255 + np_mask = 1-pow(1-np_mask, 100) + np_mask *= 255 + np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) self.mask_for_overlay = Image.fromarray(np_mask) self.overlay_images = [] From debf836fcc8d9becc3da8b1a29e33f40b0d9ef3e Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 16:15:36 -0700 Subject: [PATCH 031/311] Added UI elements to control blending parameters. --- modules/img2img.py | 48 +++++++++++++++++++++++++++++++- modules/processing.py | 3 ++ modules/sd_samplers_common.py | 3 ++ modules/ui.py | 9 ++++++ scripts/outpainting_mk_2.py | 10 +++++-- scripts/poor_mans_outpainting.py | 11 ++++++-- test/test_img2img.py | 3 ++ 7 files changed, 82 insertions(+), 5 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 1519e132b..240d05884 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -116,7 +116,47 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal process_images(p) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, + mode: int, + prompt: str, + negative_prompt: str, + prompt_styles, + init_img, + sketch, + init_img_with_mask, + inpaint_color_sketch, + inpaint_color_sketch_orig, + init_img_inpaint, + init_mask_inpaint, + steps: int, + sampler_name: str, + mask_blur: int, + mask_alpha: float, + mask_blend_power: float, + mask_blend_scale: float, + mask_blend_offset: float, + inpainting_fill: int, + n_iter: int, + batch_size: int, + cfg_scale: float, + image_cfg_scale: float, + denoising_strength: float, + selected_scale_tab: int, + height: int, + width: int, + scale_by: float, + resize_mode: int, + inpaint_full_res: bool, + inpaint_full_res_padding: int, + inpainting_mask_invert: int, + img2img_batch_input_dir: str, + img2img_batch_output_dir: str, + img2img_batch_inpaint_mask_dir: str, + override_settings_texts, + img2img_batch_use_png_info: bool, + img2img_batch_png_info_props: list, + img2img_batch_png_info_dir: str, + request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -174,6 +214,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s init_images=[image], mask=mask, mask_blur=mask_blur, + mask_blend_power=mask_blend_power, + mask_blend_scale=mask_blend_scale, + mask_blend_offset=mask_blend_offset, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -194,6 +237,9 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + p.extra_generation_params["Mask blend power"] = mask_blend_power + p.extra_generation_params["Mask blend scale"] = mask_blend_scale + p.extra_generation_params["Mask blend offset"] = mask_blend_offset with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index 12e08e876..da4d6fda9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1349,6 +1349,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None + mask_blend_power: float = 1 + mask_blend_scale: float = 1 + mask_blend_offset: float = 0 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 58efcad23..8904da2fb 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,6 +277,9 @@ class Sampler: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None + self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None + self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 579bab980..86c130869 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,6 +732,9 @@ def create_ui(): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -781,6 +784,9 @@ def create_ui(): sampler_name, mask_blur, mask_alpha, + mask_blend_power, + mask_blend_scale, + mask_blend_offset, inpainting_fill, batch_count, batch_size, @@ -879,6 +885,9 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (mask_blend_power, "Mask blend power"), + (mask_blend_scale, "Mask blend scale"), + (mask_blend_offset, "Mask blend offset"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index c98ab4809..6aa97edfa 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,13 +133,16 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend scale', minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -167,6 +170,9 @@ class Script(scripts.Script): p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index ea0632b68..b10140f14 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,16 +22,23 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 + p.mask_blend_power = mask_blend_power + p.mask_blend_scale = mask_blend_scale + p.mask_blend_offset = mask_blend_offset + p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 117d2d1eb..6289e59e1 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -24,6 +24,9 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, + "mask_blend_power": 1, + "mask_blend_scale": 1, + "mask_blend_offset": 0, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From c5c7fa06aae1ae9f8b6d29ae2da3874921d4729b Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 22:35:07 -0700 Subject: [PATCH 032/311] Added slider for detail preservation strength, removed largely needless offset parameter, changed labels in UI and for saving to/pasting data from PNG files. --- modules/img2img.py | 10 +++++----- modules/processing.py | 2 +- modules/sd_samplers_cfg_denoiser.py | 11 +++++------ modules/sd_samplers_common.py | 2 +- modules/ui.py | 14 +++++++------- scripts/outpainting_mk_2.py | 12 ++++++------ scripts/poor_mans_outpainting.py | 12 ++++++------ test/test_img2img.py | 2 +- 8 files changed, 32 insertions(+), 33 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 240d05884..023808d6c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -134,7 +134,7 @@ def img2img(id_task: str, mask_alpha: float, mask_blend_power: float, mask_blend_scale: float, - mask_blend_offset: float, + inpaint_detail_preservation: float, inpainting_fill: int, n_iter: int, batch_size: int, @@ -216,7 +216,7 @@ def img2img(id_task: str, mask_blur=mask_blur, mask_blend_power=mask_blend_power, mask_blend_scale=mask_blend_scale, - mask_blend_offset=mask_blend_offset, + inpaint_detail_preservation=inpaint_detail_preservation, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -237,9 +237,9 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - p.extra_generation_params["Mask blend power"] = mask_blend_power - p.extra_generation_params["Mask blend scale"] = mask_blend_scale - p.extra_generation_params["Mask blend offset"] = mask_blend_offset + p.extra_generation_params["Mask blending bias"] = mask_blend_power + p.extra_generation_params["Mask blending preservation"] = mask_blend_scale + p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index da4d6fda9..361e8b05d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1351,7 +1351,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur: int = None mask_blend_power: float = 1 mask_blend_scale: float = 1 - mask_blend_offset: float = 0 + inpaint_detail_preservation: float = 16 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index c4d6fda65..598cd4876 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -45,7 +45,7 @@ class CFGDenoiser(torch.nn.Module): self.nmask = None self.mask_blend_power = 1 self.mask_blend_scale = 1 - self.mask_blend_offset = 0 + self.inpaint_detail_preservation = 16 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -105,14 +105,13 @@ class CFGDenoiser(torch.nn.Module): # Record the original latent vector magnitudes. # We bring them to a power so that larger magnitudes are favored over smaller ones. # 64-bit operations are used here to allow large exponents. - detail_preservation = 32 - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** detail_preservation + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation one_minus_t = 1 - t # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / detail_preservation) + interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) # Linearly interpolate the image vectors. image_interp = a * one_minus_t + b * t @@ -142,7 +141,7 @@ class CFGDenoiser(torch.nn.Module): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale + self.mask_blend_offset) + return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 8904da2fb..ecd8ab0a0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -279,7 +279,7 @@ class Sampler: self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.mask_blend_offset = p.mask_blend_offset if hasattr(p, 'mask_blend_offset') else None + self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/ui.py b/modules/ui.py index 86c130869..f5e201477 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -732,9 +732,9 @@ def create_ui(): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_scale") - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id="img2img_mask_blend_offset") + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -786,7 +786,7 @@ def create_ui(): mask_alpha, mask_blend_power, mask_blend_scale, - mask_blend_offset, + inpaint_detail_preservation, inpainting_fill, batch_count, batch_size, @@ -885,9 +885,9 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blend power"), - (mask_blend_scale, "Mask blend scale"), - (mask_blend_offset, "Mask blend offset"), + (mask_blend_power, "Mask blending bias"), + (mask_blend_scale, "Mask blending preservation"), + (inpaint_detail_preservation, "Mask blending detail boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 6aa97edfa..54d95825a 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -133,16 +133,16 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend scale', minimum=-4, maximum=4, step=0.1, value=1, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation] + return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -172,7 +172,7 @@ class Script(scripts.Script): p.mask_blur_y = mask_blur_y*4 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index b10140f14..e3acb3d47 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,22 +22,22 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Mask blend power', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Mask blend scale', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - mask_blend_offset = gr.Slider(label='Mask blend offset', minimum=-4, maximum=4, step=0.1, value=0, elem_id=self.elem_id("mask_blend_offset")) + mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction] + return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, mask_blend_offset, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 p.mask_blend_power = mask_blend_power p.mask_blend_scale = mask_blend_scale - p.mask_blend_offset = mask_blend_offset + p.inpaint_detail_preservation = inpaint_detail_preservation p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 6289e59e1..88b06eb8d 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -26,7 +26,7 @@ def simple_img2img_request(img2img_basic_image_base64): "mask_blur": 4, "mask_blend_power": 1, "mask_blend_scale": 1, - "mask_blend_offset": 0, + "inpaint_detail_preservation": 16, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 284fd8f415ec70e14ae5de0b7f5ce738007a6b7f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:03:50 -0700 Subject: [PATCH 033/311] Tweaked UI sliders and labels. --- modules/img2img.py | 2 +- modules/ui.py | 6 +++--- scripts/outpainting_mk_2.py | 4 ++-- scripts/poor_mans_outpainting.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 023808d6c..0ae163654 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -239,7 +239,7 @@ def img2img(id_task: str, p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blending bias"] = mask_blend_power p.extra_generation_params["Mask blending preservation"] = mask_blend_scale - p.extra_generation_params["Mask blending detail boost"] = inpaint_detail_preservation + p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index f5e201477..3a9038b22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -733,8 +733,8 @@ def create_ui(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=1, elem_id="img2img_mask_blend_scale") - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id="img2img_mask_blend_offset") + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -887,7 +887,7 @@ def create_ui(): (mask_blur, "Mask blur"), (mask_blend_power, "Mask blending bias"), (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending detail boost"), + (inpaint_detail_preservation, "Mask blending contrast boost"), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 54d95825a..bd9cb61bf 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -134,8 +134,8 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index e3acb3d47..5388f5db4 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -23,8 +23,8 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending detail boost', minimum=1, maximum=32, step=0.5, value=16, elem_id=self.elem_id("inpaint_detail_preservation")) + mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) + inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) From c7a1ff87207544dd4bcf3aefffa67a4a38678c16 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Tue, 28 Nov 2023 23:31:10 -0700 Subject: [PATCH 034/311] Tweaked default values. --- modules/processing.py | 4 ++-- modules/sd_samplers_cfg_denoiser.py | 4 ++-- test/test_img2img.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 361e8b05d..92fdebadd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1350,8 +1350,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_y: int = 4 mask_blur: int = None mask_blend_power: float = 1 - mask_blend_scale: float = 1 - inpaint_detail_preservation: float = 16 + mask_blend_scale: float = 0.5 + inpaint_detail_preservation: float = 4 inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 598cd4876..ceb612d79 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -44,8 +44,8 @@ class CFGDenoiser(torch.nn.Module): self.mask = None self.nmask = None self.mask_blend_power = 1 - self.mask_blend_scale = 1 - self.inpaint_detail_preservation = 16 + self.mask_blend_scale = 0.5 + self.inpaint_detail_preservation = 4 self.init_latent = None self.steps = None """number of steps as specified by user in UI""" diff --git a/test/test_img2img.py b/test/test_img2img.py index 88b06eb8d..5cda2dbae 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -25,8 +25,8 @@ def simple_img2img_request(img2img_basic_image_base64): "mask": None, "mask_blur": 4, "mask_blend_power": 1, - "mask_blend_scale": 1, - "inpaint_detail_preservation": 16, + "mask_blend_scale": 0.5, + "inpaint_detail_preservation": 4, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From b25c126ccdbc4da22ade46597a9addf808998989 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:38:53 -0500 Subject: [PATCH 035/311] Protect alphas_cumprod from downcasting --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index 841402e86..de80a493a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -387,7 +387,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.upcast_sampling and depth_model: model.depth_model = None + alphas_cumprod = model.alphas_cumprod + model.alphas_cumprod = None model.half() + model.alphas_cumprod = alphas_cumprod + model.alphas_cumprod_original = alphas_cumprod model.first_stage_model = vae if depth_model: model.depth_model = depth_model @@ -642,6 +646,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: weight_dtype_conversion = { 'first_stage_model': None, + 'alphas_cumprod': None, '': torch.float16, } From 588a52891dca4d030ca7028dd9c0b56022a68b57 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:40:23 -0500 Subject: [PATCH 036/311] Add options for zero terminal SNR --- modules/shared_options.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index 04e68a712..515967770 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,6 +218,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd" "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -335,6 +336,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { From 6d0a8dcd892f7ad9b399fed6edbad6ede13c5f69 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:42:07 -0500 Subject: [PATCH 037/311] Implement zero terminal SNR schedule option --- modules/processing.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index ac58ef869..c88eec705 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,6 +863,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From ec6ee5c13bf3453f8703e225a191333a9bbcf10a Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:27 -0500 Subject: [PATCH 038/311] Fix infotext for ztSNR --- modules/shared_options.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 515967770..bc3d56dec 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -218,7 +218,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd" "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), - "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.") + "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod") })) options_templates.update(options_section(('interrogate', "Interrogate"), { @@ -336,7 +336,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), - 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise schedule for sampling").info("for use with zero terminal SNR trained models") + 'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models") })) options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), { From ffa7f8201d849636bb327b3b40298e7c169ff204 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:10:43 -0500 Subject: [PATCH 039/311] Lint --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c88eec705..f3883d5b3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,7 +863,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() @@ -881,7 +881,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar = alphas_bar_sqrt**2 # Revert sqrt alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: From de79597ab9894965e3702939b8536ec3dcc3c859 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 29 Nov 2023 18:33:32 -0500 Subject: [PATCH 040/311] Only apply ztSNR related code if alphas_cumprod exists --- modules/processing.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f3883d5b3..7e73d7e29 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 668ae34e21df848ef4909b8b49c4142a3674701b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 22:48:31 -0500 Subject: [PATCH 041/311] remove debug print --- modules/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 7e73d7e29..d73c8bfc1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -890,7 +890,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) if opts.sd_noise_schedule == "Zero Terminal SNR": p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - print("rescaling noise schedule for zero snr") p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): From 50a21cb09fe3e9ea2d4fe058e0484e192c8a86e3 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 2 Dec 2023 22:06:47 +0800 Subject: [PATCH 042/311] Ensure the cached weight will not be affected --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b8a9ae65..dcf816b3a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -435,9 +435,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer for module in model.modules(): if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): if shared.opts.cache_fp16_weight: - module.fp16_weight = module.weight.clone().half() + module.fp16_weight = module.weight.data.clone().cpu().half() if module.bias is not None: - module.fp16_bias = module.bias.clone().half() + module.fp16_bias = module.bias.data.clone().cpu().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") From 309a606c2fa645b6b8623f96ea56117e685a47fb Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:07:45 -0500 Subject: [PATCH 043/311] ensure that original alpha bar always exists --- modules/processing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index d73c8bfc1..bfa59038b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,15 +882,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From 81c4ddf6ebebe6f18338de3b0391da1d8521a525 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 2 Dec 2023 13:11:00 -0500 Subject: [PATCH 044/311] fix linting --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index bfa59038b..eeccea743 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod - + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if opts.use_downcasted_alpha_bar: From 4a43334376d9e116f7a1446f042f9af9c0484fc6 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:05:42 -0500 Subject: [PATCH 045/311] Revert 309a606c --- modules/processing.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index eeccea743..d73c8bfc1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -882,17 +882,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: alphas_bar[-1] = 4.8973451890853435e-08 return alphas_bar - if hasattr(p.sd_model, 'alphas_cumprod') and not hasattr(p.sd_model, 'alphas_cumprod_original'): - p.sd_model.alphas_cumprod_original = p.sd_model.alphas_cumprod + if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'): + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) From dc1adeecdd02f3fb910481e808a6d60a77100fea Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:06:56 -0500 Subject: [PATCH 046/311] Create alphas_cumprod_original on full precision path --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index de80a493a..976c7d5be 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,6 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() + model.alphas_cumprod_original = alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: From 78acdcf677a96894651ff0d7d8287f2a994f3781 Mon Sep 17 00:00:00 2001 From: drhead Date: Sat, 2 Dec 2023 14:09:18 -0500 Subject: [PATCH 047/311] fix variable --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 976c7d5be..5a19a00a5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -374,7 +374,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() - model.alphas_cumprod_original = alphas_cumprod + model.alphas_cumprod_original = model.alphas_cumprod devices.dtype_unet = torch.float32 timer.record("apply float()") else: From 609dea36ea919aa7db42fd4233c416a45c74578b Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 18:56:49 -0700 Subject: [PATCH 048/311] Added utility functions related to processing masks. --- modules/images.py | 191 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/modules/images.py b/modules/images.py index eb6447338..b5a0cead6 100644 --- a/modules/images.py +++ b/modules/images.py @@ -776,3 +776,194 @@ def flatten(img, bgcolor): img = background return img.convert('RGB') + + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x /= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + From 73ab982d1b7394574d1cf2e0a151bc457eeed769 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:07:02 -0700 Subject: [PATCH 049/311] Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content. --- modules/processing.py | 119 ++++++++++++++++++++++++++++++++---------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92fdebadd..ad716e11f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field import torch import numpy as np -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageFilter import random import cv2 from skimage import exposure @@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + def apply_overlay(image, paste_loc, index, overlays): if overlays is None or index >= len(overlays): return image @@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays): overlay = overlays[index] if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -140,6 +146,7 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None + masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = 0 @@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim + # todo: generate masks the old fashioned way else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + # Generate the mask(s) based on similarity between the original and denoised latent vectors + if getattr(p, "image_mask", None) is not None: + # latent_mask = p.nmask[0].float().cpu() + + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_orig = p.init_latent + latent_proc = samples_ddim + latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, p.width, p.height) + converted_mask = create_binary_mask(converted_mask) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if p.paste_to is not None: + converted_mask = uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + p.paste_to) + + p.masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + p.overlay_images[i] = image_masked.convert('RGBA') + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, + target_device=devices.cpu, + check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, + use_main_prompt=use_main_prompt, index=index, + all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -923,19 +987,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. + original_denoised_image = image.copy() image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], + p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') if opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") @@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) - mask_for_overlay: Image = field(default=None, init=False) init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1415,12 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1431,13 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) + self.masks_for_overlay = [] self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1459,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1486,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + if self.masks_for_overlay is not None: + self.masks_for_overlay = self.masks_for_overlay * self.batch_size + if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size From bb04d400c95df01d191ef6c1a43e66b95425fa33 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:08:26 -0700 Subject: [PATCH 050/311] Rewrote latent_blend() to use in-place operations and to aggressively "del" references with the intention of minimizing allocations and easing garbage collection. --- modules/sd_samplers_cfg_denoiser.py | 39 ++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index ceb612d79..efbe7a403 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -102,29 +102,44 @@ class CFGDenoiser(torch.nn.Module): The "detail_preservation" factor biases the magnitude interpolation towards the larger of the two magnitudes. """ - # Record the original latent vector magnitudes. - # We bring them to a power so that larger magnitudes are favored over smaller ones. - # 64-bit operations are used here to allow large exponents. - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64) ** self.inpaint_detail_preservation + # NOTE: We use inplace operations wherever possible. one_minus_t = 1 - t - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - interp_magnitude = (a_magnitude * one_minus_t + b_magnitude * t) ** (1 / self.inpaint_detail_preservation) - # Linearly interpolate the image vectors. - image_interp = a * one_minus_t + b * t + a_scaled = a * one_minus_t + b_scaled = b * t + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - image_interp_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64) + 0.0001 + current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. - image_interp *= (interp_magnitude / image_interp_magnitude).to(image_interp.dtype) + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor - return image_interp + image_interp_scaled = image_interp_scaled.to(result_type) + del result_type + + return image_interp_scaled def get_modified_nmask(nmask, _sigma): """ From 28a2b5b4aab43424733039c31d910e8b8dd507cd Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:20:20 -0700 Subject: [PATCH 051/311] Fixed a math mistake. --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 6648097e8..949534986 100644 --- a/modules/images.py +++ b/modules/images.py @@ -969,7 +969,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 x = gaussian(x) x -= gauss_zero - x /= gauss_kernel_scale + x *= gauss_kernel_scale x = max(0.0, x) return x From 552f8bc832cd21ee0338e08b6a701687d0d79fad Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sun, 3 Dec 2023 14:49:41 -0700 Subject: [PATCH 052/311] "Uncrop" the original denoised image for the composite step, fixing a "ValueError: Images do not match" *shudder* --- modules/processing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 66aaab831..cd7216f83 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -994,6 +994,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: # we need to keep the original image around # and use it in the composite step. original_denoised_image = image.copy() + + if p.paste_to is not None: + original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: From aaacf4823241450d88315af9d465d6815119fe0d Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 01:27:22 -0700 Subject: [PATCH 053/311] Organized the settings and UI of soft inpainting to allow for toggling the feature, and centralizes default values to reduce the amount of copy-pasta. --- modules/img2img.py | 14 +-- modules/processing.py | 5 +- modules/sd_samplers_cfg_denoiser.py | 35 +++++--- modules/sd_samplers_common.py | 4 +- modules/soft_inpainting.py | 133 ++++++++++++++++++++++++++++ modules/ui.py | 17 ++-- scripts/outpainting_mk_2.py | 15 ++-- scripts/poor_mans_outpainting.py | 15 ++-- test/test_img2img.py | 8 +- 9 files changed, 197 insertions(+), 49 deletions(-) create mode 100644 modules/soft_inpainting.py diff --git a/modules/img2img.py b/modules/img2img.py index 596f741c1..3aa8a9cef 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -15,6 +15,7 @@ import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html import modules.scripts +import modules.soft_inpainting as si def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): @@ -162,6 +163,7 @@ def img2img(id_task: str, sampler_name: str, mask_blur: int, mask_alpha: float, + mask_blend_enabled: bool, mask_blend_power: float, mask_blend_scale: float, inpaint_detail_preservation: float, @@ -227,6 +229,9 @@ def img2img(id_task: str, assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None + p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, @@ -244,9 +249,7 @@ def img2img(id_task: str, init_images=[image], mask=mask, mask_blur=mask_blur, - mask_blend_power=mask_blend_power, - mask_blend_scale=mask_blend_scale, - inpaint_detail_preservation=inpaint_detail_preservation, + soft_inpainting=soft_inpainting, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -267,9 +270,8 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - p.extra_generation_params["Mask blending bias"] = mask_blend_power - p.extra_generation_params["Mask blending preservation"] = mask_blend_scale - p.extra_generation_params["Mask blending contrast boost"] = inpaint_detail_preservation + if soft_inpainting is not None: + soft_inpainting.add_generation_params(p.extra_generation_params) with closing(p): if is_batch: diff --git a/modules/processing.py b/modules/processing.py index cd7216f83..b209c84a3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,6 +30,7 @@ import modules.sd_models as sd_models import modules.sd_vae as sd_vae from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion +import modules.soft_inpainting as si from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -1425,9 +1426,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - mask_blend_power: float = 1 - mask_blend_scale: float = 0.5 - inpaint_detail_preservation: float = 4 + soft_inpainting: si.SoftInpaintingParameters = si.default inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index efbe7a403..0ee0b7dde 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,6 +6,7 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import modules.soft_inpainting as si def catenate_conds(conds): @@ -43,9 +44,7 @@ class CFGDenoiser(torch.nn.Module): self.model_wrap = None self.mask = None self.nmask = None - self.mask_blend_power = 1 - self.mask_blend_scale = 0.5 - self.inpaint_detail_preservation = 4 + self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -95,7 +94,8 @@ class CFGDenoiser(torch.nn.Module): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - def latent_blend(a, b, t): + def latent_blend(a, b, t, one_minus_t=None): + """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -104,7 +104,11 @@ class CFGDenoiser(torch.nn.Module): """ # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + if one_minus_t is None: + one_minus_t = 1 - t + + if self.soft_inpainting is None: + return a * one_minus_t + b * t # Linearly interpolate the image vectors. a_scaled = a * one_minus_t @@ -119,10 +123,10 @@ class CFGDenoiser(torch.nn.Module): current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / self.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation) del a_magnitude, b_magnitude, one_minus_t # Change the linearly interpolated image vectors' magnitudes to the value we want. @@ -156,7 +160,10 @@ class CFGDenoiser(torch.nn.Module): NOTE: "mask" is not used """ - return torch.pow(nmask, (_sigma ** self.mask_blend_power) * self.mask_blend_scale) + if self.soft_inpainting is None: + return nmask + + return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -176,7 +183,10 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + x = latent_blend(self.init_latent, x, self.nmask, self.mask) + else: + x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -279,7 +289,10 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) + if self.soft_inpainting is None: + denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask) + else: + denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index ecd8ab0a0..9682bee3d 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,9 +277,7 @@ class Sampler: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.mask_blend_power = p.mask_blend_power if hasattr(p, 'mask_blend_power') else None - self.model_wrap_cfg.mask_blend_scale = p.mask_blend_scale if hasattr(p, 'mask_blend_scale') else None - self.model_wrap_cfg.inpaint_detail_preservation = p.inpaint_detail_preservation if hasattr(p, 'inpaint_detail_preservation') else None + self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py new file mode 100644 index 000000000..259c36ec8 --- /dev/null +++ b/modules/soft_inpainting.py @@ -0,0 +1,133 @@ +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def get_paste_fields(self): + return [ + (self.mask_blend_power, gen_param_labels.mask_blend_power), + (self.mask_blend_scale, gen_param_labels.mask_blend_scale), + (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation), + ] + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +default = SoftInpaintingSettings(1, 0.5, 4) +ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost") + +ui_info = SoftInpaintingSettings( + mask_blend_power="Shifts when preservation of original content occurs during denoising.", + # "Below 1: Stronger preservation near the end (with low sigma)\n" + # "1: Balanced (proportional to sigma)\n" + # "Above 1: Stronger preservation in the beginning (with high sigma)", + mask_blend_scale="How strongly partially masked content should be preserved.", + # "Low values: Favors generated content.\n" + # "High values: Favors original content.", + inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost") +el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation") + + +def gradio_ui(): + import gradio as gr + from modules.ui_components import InputAccordion + """ + with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + """ + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + return ( + [ + soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation + ], + [ + (soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) + ] + ) diff --git a/modules/ui.py b/modules/ui.py index b13ed66cb..0e4fb17aa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,6 +29,7 @@ import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text +import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -678,9 +679,16 @@ def create_ui(): with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + soft_inpainting = si.gradio_ui() + + + """ mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") + """ with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -736,9 +744,7 @@ def create_ui(): sampler_name, mask_blur, mask_alpha, - mask_blend_power, - mask_blend_scale, - inpaint_detail_preservation, + *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -837,11 +843,10 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (mask_blend_power, "Mask blending bias"), - (mask_blend_scale, "Mask blending preservation"), - (inpaint_detail_preservation, "Mask blending contrast boost"), + *(soft_inpainting[1]), *scripts.scripts_img2img.infotext_fields ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index bd9cb61bf..f78886883 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -133,16 +134,14 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation] + return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +169,9 @@ class Script(scripts.Script): p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation + + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 5388f5db4..11f7f74a8 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,6 +7,7 @@ from PIL import Image, ImageDraw from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state +import modules.soft_inpainting as si class Script(scripts.Script): @@ -22,23 +23,19 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id=self.elem_id("mask_blend_power")) - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id=self.elem_id("mask_blend_scale")) - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id=self.elem_id("inpaint_detail_preservation")) + soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction] + return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.mask_blend_power = mask_blend_power - p.mask_blend_scale = mask_blend_scale - p.inpaint_detail_preservation = inpaint_detail_preservation - + p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ + if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 5cda2dbae..87bd85091 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,6 +1,7 @@ import pytest import requests +import modules.soft_inpainting as si @pytest.fixture() @@ -24,9 +25,10 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_power": 1, - "mask_blend_scale": 0.5, - "inpaint_detail_preservation": 4, + "mask_blend_enabled": True, + "mask_blend_power": si.default.mask_blend_power, + "mask_blend_scale": si.default.mask_blend_scale, + "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From 259d33c3c8e27557cb9bab9b3a1dd7fc7450d16c Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 01:57:21 -0700 Subject: [PATCH 054/311] Enables the original functionality to be toggled on and off. --- modules/processing.py | 91 +++++++++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b209c84a3..b40b1a40d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -88,9 +88,12 @@ def apply_overlay(image, paste_loc, index, overlays): return image -def create_binary_mask(image): +def create_binary_mask(image, round=True): if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): - image = image.split()[-1].convert("L") + if round: + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.split()[-1].convert("L") else: image = image.convert('L') return image @@ -316,7 +319,7 @@ class StableDiffusionProcessing: c_adm = torch.cat((c_adm, noise_level_emb), 1) return c_adm - def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -327,6 +330,11 @@ class StableDiffusionProcessing: conditioning_mask = np.array(image_mask.convert("L")) conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + if round_image_mask: + # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:]) @@ -350,7 +358,7 @@ class StableDiffusionProcessing: return image_conditioning - def img2img_image_conditioning(self, source_image, latent_image, image_mask=None): + def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True): source_image = devices.cond_cast_float(source_image) # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely @@ -362,7 +370,10 @@ class StableDiffusionProcessing: return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image, + latent_image, + image_mask=image_mask, + round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) @@ -878,8 +889,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method + # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None: + if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: # latent_mask = p.nmask[0].float().cpu() # convert the original mask into a form we use to scale distances for thresholding @@ -911,7 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: converted_mask = converted_mask.astype(np.uint8) converted_mask = Image.fromarray(converted_mask) converted_mask = images.resize_image(2, converted_mask, p.width, p.height) - converted_mask = create_binary_mask(converted_mask) + converted_mask = create_binary_mask(converted_mask, round=False) # Remove aliasing artifacts using a gaussian blur. converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) @@ -1010,23 +1022,33 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.masks_for_overlay[i].convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') + if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay: + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') + elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: + image_mask = p.mask_for_overlay.convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + else: + image_mask = None + image_mask_composite = None - if opts.save_mask: - images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if image_mask is not None and image_mask_composite is not None: + if opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") - if opts.save_mask_composite: - images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + if opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") - if opts.return_mask: - output_images.append(image_mask) + if opts.return_mask: + output_images.append(image_mask) - if opts.return_mask_composite: - output_images.append(image_mask_composite) + if opts.return_mask_composite: + output_images.append(image_mask_composite) del x_samples_ddim @@ -1439,6 +1461,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) + mask_for_overlay: Image = field(default=None, init=False) init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1471,7 +1494,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask) + image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1489,6 +1512,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: + self.mask_for_overlay = image_mask if self.soft_inpainting is None else None mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1500,7 +1524,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - self.masks_for_overlay = [] + if self.soft_inpainting is None: + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) + + self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1522,8 +1551,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - self.overlay_images.append(image) - self.masks_for_overlay.append(image_mask) + if self.soft_inpainting is not None: + # We apply the masks AFTER to adjust mask based on changed content. + self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) + else: + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1576,6 +1612,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] + if self.soft_inpainting is None: + latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype) @@ -1587,7 +1625,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, + self.init_latent, + image_mask, + self.soft_inpainting is None) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() From 15322e1b1a9e31edcc2f7d72a32d02365058737d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 12:36:41 +0300 Subject: [PATCH 055/311] repair old handler for postprocessing API --- modules/postprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 3c85a74c1..d166f859b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -153,4 +153,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ }, }) - return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) From 883d6a2b34a2817304d23c2481a6f9fc56687a53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 13:11:00 +0300 Subject: [PATCH 056/311] repair old handler for postprocessing API in a way that doesn't break interface --- modules/postprocessing.py | 8 ++++++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index d166f859b..0c59fad48 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui from modules.shared import opts -def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin(job="extras") @@ -128,6 +128,10 @@ def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, out return outputs, ui_common.plaintext_to_html(infotext), '' +def run_postprocessing_webui(id_task, *args, **kwargs): + return run_postprocessing(*args, **kwargs) + + def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): """old handler for API""" @@ -153,4 +157,4 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ }, }) - return run_postprocessing("", extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) + return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index fbad0800a..13d888e48 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -35,7 +35,7 @@ def create_ui(): tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index]) submit.click( - fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), + fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']), _js="submit_extras", inputs=[ dummy_component, From 22e23dbf29b0bbc807daa57318c31145f8dd0774 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 15:56:03 +0300 Subject: [PATCH 057/311] add hypertile infotext --- .../hypertile/scripts/hypertile_script.py | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index d3ab60915..395d584b6 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -17,11 +17,42 @@ class ScriptHypertile(scripts.Script): configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + self.add_infotext(p) + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + # exclusive hypertile seed for the second pass - if not shared.opts.hypertile_enable_unet: + if enable: hypertile.set_hypertile_seed(p.all_seeds[0]) - configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') def configure_hypertile(width, height, enable_unet=True): @@ -57,16 +88,16 @@ def on_ui_settings(): benefit. """), - "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), - "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), - "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), - "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), - "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), } for name, opt in options.items(): From 854f8c318c2610c76259056ab02739176aa849e8 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 5 Dec 2023 04:40:12 +0900 Subject: [PATCH 058/311] remove clean_text() --- modules/styles.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 4d218cd7e..7fb6c2e11 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -2,7 +2,6 @@ import csv import fnmatch import os import os.path -import re import typing import shutil @@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple): path: str = None -def clean_text(text: str) -> str: - """ - Iterating through a list of regular expressions and replacement strings, we - clean up the prompt and style text to make it easier to match against each - other. - """ - re_list = [ - ("multiple commas", re.compile("(,+\s+)+,?"), ", "), - ("multiple spaces", re.compile("\s{2,}"), " "), - ] - for _, regex, replace in re_list: - text = regex.sub(replace, text) - - return text.strip(", ") - - def merge_prompts(style_prompt: str, prompt: str) -> str: if "{prompt}" in style_prompt: res = style_prompt.replace("{prompt}", prompt) @@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles): for style in styles: prompt = merge_prompts(style, prompt) - return clean_text(prompt) + return prompt def unwrap_style_text_from_prompt(style_text, prompt): @@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt): Note that the "cleaned" version of the style text is only used for matching purposes here. It isn't returned; the original style text is not modified. """ - stripped_prompt = clean_text(prompt) - stripped_style_text = clean_text(style_text) + stripped_prompt = prompt + stripped_style_text = style_text if "{prompt}" in stripped_style_text: # Work out whether the prompt is wrapped in the style text. If so, we # return True and the "inner" prompt text that isn't part of the style. From 976c1053efeb5054692ed3cfa294cf79196f3946 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 16:06:58 -0700 Subject: [PATCH 059/311] Cleaned up code, moved main code contributions into soft_inpainting.py --- modules/processing.py | 56 ++------- modules/sd_samplers_cfg_denoiser.py | 84 ++----------- modules/soft_inpainting.py | 175 +++++++++++++++++++++++++--- modules/ui.py | 7 -- 4 files changed, 173 insertions(+), 149 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b40b1a40d..0b3603875 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -892,55 +892,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: # Generate the mask(s) based on similarity between the original and denoised latent vectors if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - # latent_mask = p.nmask[0].float().cpu() - - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_orig = p.init_latent - latent_proc = samples_ddim - latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. * converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, p.width, p.height) - converted_mask = create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if p.paste_to is not None: - converted_mask = uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - p.paste_to) - - p.masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - p.overlay_images[i] = image_masked.convert('RGBA') + si.generate_adaptive_masks(latent_orig=p.init_latent, + latent_processed=samples_ddim, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 0ee0b7dde..a700e6922 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -94,76 +94,6 @@ class CFGDenoiser(torch.nn.Module): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - def latent_blend(a, b, t, one_minus_t=None): - - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - # NOTE: We use inplace operations wherever possible. - - if one_minus_t is None: - one_minus_t = 1 - t - - if self.soft_inpainting is None: - return a * one_minus_t + b * t - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t - b_scaled = b * t - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, one_minus_t - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - - image_interp_scaled = image_interp_scaled.to(result_type) - del result_type - - return image_interp_scaled - - def get_modified_nmask(nmask, _sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - if self.soft_inpainting is None: - return nmask - - return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale) if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -184,9 +114,12 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: if self.soft_inpainting is None: - x = latent_blend(self.init_latent, x, self.nmask, self.mask) + x = self.init_latent * self.mask + self.nmask * x else: - x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma)) + x = si.latent_blend(self.soft_inpainting, + self.init_latent, + x, + si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -290,9 +223,12 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: if self.soft_inpainting is None: - denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask) + denoised = self.init_latent * self.mask + self.nmask * denoised else: - denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma)) + denoised = si.latent_blend(self.soft_inpainting, + self.init_latent, + denoised, + si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index 259c36ec8..b81c8dd95 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -4,13 +4,6 @@ class SoftInpaintingSettings: self.mask_blend_scale = mask_blend_scale self.inpaint_detail_preservation = inpaint_detail_preservation - def get_paste_fields(self): - return [ - (self.mask_blend_power, gen_param_labels.mask_blend_power), - (self.mask_blend_scale, gen_param_labels.mask_blend_scale), - (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation), - ] - def add_generation_params(self, dest): dest[enabled_gen_param_label] = True dest[gen_param_labels.mask_blend_power] = self.mask_blend_power @@ -18,25 +11,169 @@ class SoftInpaintingSettings: dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + one_minus_t = 1 - t + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t + b_scaled = b * t + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t + b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, one_minus_t + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def generate_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + masks_for_overlay, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc. uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4) + enabled_ui_label = "Soft inpainting" enabled_gen_param_label = "Soft inpainting enabled" enabled_el_id = "soft_inpainting_enabled" -default = SoftInpaintingSettings(1, 0.5, 4) -ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost") +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost") ui_info = SoftInpaintingSettings( - mask_blend_power="Shifts when preservation of original content occurs during denoising.", - # "Below 1: Stronger preservation near the end (with low sigma)\n" - # "1: Balanced (proportional to sigma)\n" - # "Above 1: Stronger preservation in the beginning (with high sigma)", - mask_blend_scale="How strongly partially masked content should be preserved.", - # "Low values: Favors generated content.\n" - # "High values: Favors original content.", - inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.") + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.") -gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost") -el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation") +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation") + + +# ------------------- UI ------------------- def gradio_ui(): diff --git a/modules/ui.py b/modules/ui.py index 0e4fb17aa..4f1265a3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -683,13 +683,6 @@ def create_ui(): with FormRow(): soft_inpainting = si.gradio_ui() - - """ - mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power") - mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale") - inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset") - """ - with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") From 1455159cf44cd8c21656818463f6095eae887540 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 16:43:57 -0700 Subject: [PATCH 060/311] Fixed issue with whitespace, removed commented out code that was meant to be used as a reference. --- modules/soft_inpainting.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index b81c8dd95..56a877746 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -179,15 +179,7 @@ el_ids = SoftInpaintingSettings( def gradio_ui(): import gradio as gr from modules.ui_components import InputAccordion - """ - with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: - with gr.Row(): - refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") - create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) - refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") - - """ with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: with gr.Group(): gr.Markdown( @@ -223,11 +215,11 @@ def gradio_ui(): gr.Markdown( f""" ### {ui_labels.mask_blend_power} - + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - + - **Below 1**: Stronger preservation near the end (with low sigma) - **1**: Balanced (proportional to sigma) - **Above 1**: Stronger preservation in the beginning (with high sigma) @@ -235,21 +227,21 @@ def gradio_ui(): gr.Markdown( f""" ### {ui_labels.mask_blend_scale} - + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - + - **Low values**: Favors generated content. - **High values**: Favors original content. """) gr.Markdown( f""" ### {ui_labels.inpaint_detail_preservation} - + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. This can prevent the loss of contrast that occurs with linear interpolation. - + - **Low values**: Softer blending, details may fade. - **High values**: Stronger contrast, may over-saturate colors. """) From 57f29bd61dc30f1a8c94ead9b780f4655f7d7d6d Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:41:18 -0700 Subject: [PATCH 061/311] Re-introduce latent blending step from the vanilla inpainting procedure. --- modules/processing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 0b3603875..c8dc4d934 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1597,6 +1597,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + if self.mask is not None and self.soft_inpainting is None: + samples = samples * self.nmask + self.init_latent * self.mask + del x devices.torch_gc() From 60c602232fd760fb548fb0b3d18b5297f8823c2a Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:41:51 -0700 Subject: [PATCH 062/311] Restored original formatting. --- modules/processing.py | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c8dc4d934..90ae249a4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,10 +370,7 @@ class StableDiffusionProcessing: return self.edit_image_conditioning(source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, - latent_image, - image_mask=image_mask, - round_image_mask=round_image_mask) + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask) if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) @@ -885,7 +882,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim - # todo: generate masks the old fashioned way + # todo: generate adaptive masks based on pixel differences. + # if p.masks_for_overlay is used, it will already be populated with masks else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method @@ -900,9 +898,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: height=p.height, paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, - target_device=devices.cpu, - check_for_nans=True) + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -927,9 +923,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, - use_main_prompt=use_main_prompt, index=index, - all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -972,8 +966,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], - p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) @@ -983,14 +976,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if save_samples and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): if hasattr(p, 'masks_for_overlay') and p.masks_for_overlay: image_mask = p.masks_for_overlay[i].convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') elif hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite( - original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), - images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') else: image_mask = None image_mask_composite = None @@ -1515,8 +1504,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.masks_for_overlay.append(image_mask) else: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res @@ -1583,10 +1572,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, - self.init_latent, - image_mask, - self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() From b32a334e3da7b06d82441beaa08a673b4f55bca1 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 17:57:10 -0700 Subject: [PATCH 063/311] Applies a convert('RGBA') operation early to mimic previous behaviour. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 90ae249a4..7fc282cfd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1500,7 +1500,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: if self.soft_inpainting is not None: # We apply the masks AFTER to adjust mask based on changed content. - self.overlay_images.append(image) + self.overlay_images.append(image.convert('RGBA')) self.masks_for_overlay.append(image_mask) else: image_masked = Image.new('RGBa', (image.width, image.height)) From 6fc12428e3c5f903584ca7986e0c441f80fa2807 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 19:42:59 -0700 Subject: [PATCH 064/311] Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' decoded case should also be handled correctly (tested indirectly). --- modules/processing.py | 23 ++++++++----- modules/soft_inpainting.py | 66 ++++++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7fc282cfd..71bb056a2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -883,20 +883,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim # todo: generate adaptive masks based on pixel differences. - # if p.masks_for_overlay is used, it will already be populated with masks + if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: + si.apply_masks(soft_inpainting=p.soft_inpainting, + nmask=p.nmask, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method # Generate the mask(s) based on similarity between the original and denoised latent vectors if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.generate_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) + si.apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=samples_ddim, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index 56a877746..b36ac8fa1 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t): # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t - b_scaled = b * t + a_scaled = a * one_minus_t2 + b_scaled = b * t2 image_interp = a_scaled image_interp.add_(b_scaled) result_type = image_interp.dtype - del a_scaled, b_scaled + del a_scaled, b_scaled, t2, one_minus_t2 # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, one_minus_t + del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. @@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + # todo: Why is sigma 2D? Both values are the same. + return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) -def generate_adaptive_masks( +def apply_adaptive_masks( latent_orig, latent_processed, overlay_images, @@ -142,6 +149,45 @@ def generate_adaptive_masks( overlay_images[i] = image_masked.convert('RGBA') +def apply_masks( + soft_inpainting, + nmask, + overlay_images, + masks_for_overlay, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + # ------------------- Constants ------------------- From 49bbf1140731036875573bb7c44aa7e74623c856 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 19:47:40 -0700 Subject: [PATCH 065/311] Fixed unused import. --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 71bb056a2..e1823ac33 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field import torch import numpy as np -from PIL import Image, ImageOps, ImageFilter +from PIL import Image, ImageOps import random import cv2 from skimage import exposure From 895456c4a2e87f5fe3ee23b4482e68fce317a1ca Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Tue, 5 Dec 2023 18:00:48 -0600 Subject: [PATCH 066/311] change state dict comparison to ref compare --- modules/sd_disable_initialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 8863107ae..273a7edd8 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper): would be on the meta device. """ - if state_dict == sd: + if state_dict is sd: state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} original(module, state_dict, strict=strict) From 672dc4efa8e0da38426b121e7c7216d0a8e465fd Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:16:10 +0800 Subject: [PATCH 067/311] Fix forced reload --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index dcf816b3a..d0046f88c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -801,7 +801,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if check_fp8(sd_model) != devices.fp8: # load from state dict again to prevent extra numerical errors forced_reload = True - elif sd_model.sd_model_checkpoint == checkpoint_info.filename: + elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) From 4d56383025f2cbd00dc6296161e31a896624ab75 Mon Sep 17 00:00:00 2001 From: "fuchen.ljl" Date: Wed, 6 Dec 2023 20:23:56 +0800 Subject: [PATCH 068/311] Long distance memory overflow issue Problem: The memory will slowly increase with the drawing until restarting. Observation: GC analysis shows that no occupation has occurred, so it is suspected to be a problem with the underlying allocator. Reason: Under Linux, glibc is used to allocate memory. glibc uses brk and mmap to allocate memory, and the memory allocated by brk cannot be released until the high-address memory is released. That is to say, if you apply for two pieces of memory A and B through brk, it is impossible to release A before B is released, and it is still occupied by the process. Check the suspected "memory leak" through TOP. So I replaced TCMalloc, but found that libtcmalloc_minimal could not find ptthread_Key_Create. After analysis, it was found that pthread was not entered during compilation. --- webui.sh | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed..081624c4e 100755 --- a/webui.sh +++ b/webui.sh @@ -222,13 +222,29 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" - if [[ ! -z "${TCMALLOC}" ]]; then - echo "Using TCMalloc: ${TCMALLOC}" - export LD_PRELOAD="${TCMALLOC}" - else - printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" - fi + # Define Tcmalloc Libs arrays + TCMALLOC_LIBS=("libtcmalloc(_minimal|)\.so\.\d" "libtcmalloc\.so\.\d") + + # Traversal array + for lib in "${TCMALLOC_LIBS[@]}" + do + #Determine which type of tcmalloc library the library supports + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)" + TC_INFO=(${TCMALLOC//=>/}) + if [[ ! -z "${TC_INFO}" ]]; then + echo "Using TCMalloc: ${TC_INFO}" + #Determine if the library is linked to libptthread and resolve undefined symbol: ptthread_Key_Create + if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then + echo "$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO}" + export LD_PRELOAD="${TC_INFO}" + break + else + echo "$TC_INFO is not linked with libpthreadand will trigger undefined symbol: ptthread_Key_Create error" + fi + else + printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n" + fi + done fi } From 746783f7a47f38f728f221cc26fe04035d3ca66b Mon Sep 17 00:00:00 2001 From: Nuullll Date: Wed, 6 Dec 2023 20:55:42 +0800 Subject: [PATCH 069/311] [IPEX] Fix embedding Cast `torch.bmm` args into same `dtype`. Fixes the following error when using Text Inversion embedding (#14224): ``` RuntimeError: could not create a primitive descriptor for a matmul primitive ``` --- modules/xpu_specific.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d933c7903..ec1ad100a 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -48,3 +48,6 @@ if has_xpu: CondFunc('torch.nn.modules.conv.Conv2d.forward', lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), lambda orig_func, self, input: input.dtype != self.weight.data.dtype) + CondFunc('torch.bmm', + lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), + lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) From 9d2cbf8e97832662e446145d3961c39e78919d3d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 6 Dec 2023 23:06:32 +0900 Subject: [PATCH 070/311] add option: Live preview in full page image viewer make #13459 "show the preview image in the modal view if available" optional --- javascript/imageviewer.js | 2 +- modules/shared_options.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index e4dae91bc..625c5d148 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -34,7 +34,7 @@ function updateOnBackgroundChange() { if (modalImage && modalImage.offsetParent) { let currentButton = selected_gallery_button(); let preview = gradioApp().querySelectorAll('.livePreview > img'); - if (preview.length > 0) { + if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) { // show preview image if available modalImage.src = preview[preview.length - 1].src; } else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) { diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018..88cfddede 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -330,6 +330,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), { "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), + "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From e90d4334ad37024a802f4ef27069b625a6508f72 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 16:54:42 -0700 Subject: [PATCH 071/311] A custom blending function can be provided by p, replacing the use of soft_inpainting. --- modules/sd_samplers_cfg_denoiser.py | 34 ++++++++++++++--------------- modules/sd_samplers_common.py | 1 - 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a700e6922..f13e8dcc5 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,7 +6,6 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback -import modules.soft_inpainting as si def catenate_conds(conds): @@ -44,7 +43,6 @@ class CFGDenoiser(torch.nn.Module): self.model_wrap = None self.mask = None self.nmask = None - self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -94,7 +92,6 @@ class CFGDenoiser(torch.nn.Module): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -111,15 +108,24 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(latent): + if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): + return self.p.denoiser_masked_blend_function( + self, + # Using an argument dictionary so that arguments can be added without breaking extensions. + args= + { + "denoiser": self, + "current_latent": latent, + "sigma": sigma + }) + else: + return self.init_latent * self.mask + self.nmask * latent + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - x = self.init_latent * self.mask + self.nmask * x - else: - x = si.latent_blend(self.soft_inpainting, - self.init_latent, - x, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + x = apply_blend(x) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -222,13 +228,7 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - denoised = self.init_latent * self.mask + self.nmask * denoised - else: - denoised = si.latent_blend(self.soft_inpainting, - self.init_latent, - denoised, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + denoised = apply_blend(denoised) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 9682bee3d..58efcad23 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -277,7 +277,6 @@ class Sampler: self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.soft_inpainting = p.soft_inpainting if hasattr(p, 'soft_inpainting') else None self.model_wrap_cfg.step = 0 self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) From 4608f6236fc24d937f89500b2c9bf48484537cf9 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 18:11:17 -0700 Subject: [PATCH 072/311] Removed changes in some scripts since the arguments for soft painting are no longer passed through the same path as "mask_blur". --- modules/img2img.py | 50 +------------------------------- modules/ui.py | 7 ----- scripts/outpainting_mk_2.py | 9 ++---- scripts/poor_mans_outpainting.py | 8 ++--- test/test_img2img.py | 5 ---- 5 files changed, 5 insertions(+), 74 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 3aa8a9cef..c583290a0 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -15,7 +15,6 @@ import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html import modules.scripts -import modules.soft_inpainting as si def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): @@ -147,48 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal return batch_results -def img2img(id_task: str, - mode: int, - prompt: str, - negative_prompt: str, - prompt_styles, - init_img, - sketch, - init_img_with_mask, - inpaint_color_sketch, - inpaint_color_sketch_orig, - init_img_inpaint, - init_mask_inpaint, - steps: int, - sampler_name: str, - mask_blur: int, - mask_alpha: float, - mask_blend_enabled: bool, - mask_blend_power: float, - mask_blend_scale: float, - inpaint_detail_preservation: float, - inpainting_fill: int, - n_iter: int, - batch_size: int, - cfg_scale: float, - image_cfg_scale: float, - denoising_strength: float, - selected_scale_tab: int, - height: int, - width: int, - scale_by: float, - resize_mode: int, - inpaint_full_res: bool, - inpaint_full_res_padding: int, - inpainting_mask_invert: int, - img2img_batch_input_dir: str, - img2img_batch_output_dir: str, - img2img_batch_inpaint_mask_dir: str, - override_settings_texts, - img2img_batch_use_png_info: bool, - img2img_batch_png_info_props: list, - img2img_batch_png_info_dir: str, - request: gr.Request, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -229,9 +187,6 @@ def img2img(id_task: str, assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - p = StableDiffusionProcessingImg2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples, @@ -249,7 +204,6 @@ def img2img(id_task: str, init_images=[image], mask=mask, mask_blur=mask_blur, - soft_inpainting=soft_inpainting, inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, @@ -270,8 +224,6 @@ def img2img(id_task: str, if mask: p.extra_generation_params["Mask blur"] = mask_blur - if soft_inpainting is not None: - soft_inpainting.add_generation_params(p.extra_generation_params) with closing(p): if is_batch: diff --git a/modules/ui.py b/modules/ui.py index bd2091e1f..d80486dd4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -29,7 +29,6 @@ import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.generation_parameters_copypaste import image_from_url_text -import modules.soft_inpainting as si create_setting_component = ui_settings.create_setting_component @@ -680,9 +679,6 @@ def create_ui(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - with FormRow(): - soft_inpainting = si.gradio_ui() - with FormRow(): inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") @@ -737,7 +733,6 @@ def create_ui(): sampler_name, mask_blur, mask_alpha, - *(soft_inpainting[0]), inpainting_fill, batch_count, batch_size, @@ -836,10 +831,8 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - *(soft_inpainting[1]), *scripts.scripts_img2img.infotext_fields ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index f78886883..c98ab4809 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from PIL import Image, ImageDraw from modules import images from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -134,14 +133,13 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) - return [info, pixels, mask_blur, *soft_inpainting, direction, noise_q, color_variation] + return [info, pixels, mask_blur, direction, noise_q, color_variation] - def run(self, p, _, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, direction, noise_q, color_variation): + def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation): initial_seed_and_info = [None, None] process_width = p.width @@ -170,9 +168,6 @@ class Script(scripts.Script): p.mask_blur_x = mask_blur_x*4 p.mask_blur_y = mask_blur_y*4 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None - init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 target_h = math.ceil((init_img.height + up + down) / 64) * 64 diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 11f7f74a8..ea0632b68 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,7 +7,6 @@ from PIL import Image, ImageDraw from modules import images, devices from modules.processing import Processed, process_images from modules.shared import opts, state -import modules.soft_inpainting as si class Script(scripts.Script): @@ -23,19 +22,16 @@ class Script(scripts.Script): pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) - soft_inpainting = si.gradio_ui()[0] inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) - return [pixels, mask_blur, *soft_inpainting, inpainting_fill, direction] + return [pixels, mask_blur, inpainting_fill, direction] - def run(self, p, pixels, mask_blur, mask_blend_enabled, mask_blend_power, mask_blend_scale, inpaint_detail_preservation, inpainting_fill, direction): + def run(self, p, pixels, mask_blur, inpainting_fill, direction): initial_seed = None initial_info = None p.mask_blur = mask_blur * 2 - p.soft_inpainting = si.SoftInpaintingSettings(mask_blend_power, mask_blend_scale, inpaint_detail_preservation) \ - if mask_blend_enabled else None p.inpainting_fill = inpainting_fill p.inpaint_full_res = False diff --git a/test/test_img2img.py b/test/test_img2img.py index 87bd85091..117d2d1eb 100644 --- a/test/test_img2img.py +++ b/test/test_img2img.py @@ -1,7 +1,6 @@ import pytest import requests -import modules.soft_inpainting as si @pytest.fixture() @@ -25,10 +24,6 @@ def simple_img2img_request(img2img_basic_image_base64): "inpainting_mask_invert": False, "mask": None, "mask_blur": 4, - "mask_blend_enabled": True, - "mask_blend_power": si.default.mask_blend_power, - "mask_blend_scale": si.default.mask_blend_scale, - "inpaint_detail_preservation": si.default.inpaint_detail_preservation, "n_iter": 1, "negative_prompt": "", "override_settings": {}, From ac4578912395627731f2cd8529f87a95df1f7644 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 21:16:27 -0700 Subject: [PATCH 073/311] Removed soft inpainting, added hooks for softpainting to work instead. --- modules/processing.py | 94 ++++++++++++----------------- modules/scripts.py | 70 +++++++++++++++++++++ modules/sd_samplers_cfg_denoiser.py | 23 +++---- 3 files changed, 118 insertions(+), 69 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7d46949fa..5a1a90afe 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,7 +30,6 @@ import modules.sd_models as sd_models import modules.sd_vae as sd_vae from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion -import modules.soft_inpainting as si from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType @@ -73,12 +72,10 @@ def uncrop(image, dest_size, paste_loc): return image -def apply_overlay(image, paste_loc, index, overlays): - if overlays is None or index >= len(overlays): +def apply_overlay(image, paste_loc, overlay): + if overlay is None: return image - overlay = overlays[index] - if paste_loc is not None: image = uncrop(image, (overlay.width, overlay.height), paste_loc) @@ -150,7 +147,6 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None - masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = None @@ -880,31 +876,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = pp.samples + if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim - # todo: generate adaptive masks based on pixel differences. - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_masks(soft_inpainting=p.soft_inpainting, - nmask=p.nmask, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - # Generate the mask(s) based on similarity between the original and denoised latent vectors - if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.apply_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -955,9 +937,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image + + mask_for_overlay = p.mask_for_overlay + overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: - image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) + image_without_cc = apply_overlay(image, p.paste_to, overlay_image) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) @@ -968,9 +959,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_images[i].width, p.overlay_images[i].height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) - image = apply_overlay(image, p.paste_to, i, p.overlay_images) + image = apply_overlay(image, p.paste_to, overlay_image) if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) @@ -981,13 +972,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image.info["parameters"] = text output_images.append(image) - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay: - mask_for_overlay = p.mask_for_overlay - elif hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and p.masks_for_overlay[i]: - mask_for_overlay = p.masks_for_overlay[i] - else: - mask_for_overlay = None - if mask_for_overlay is not None: if opts.return_mask or opts.save_mask: image_mask = mask_for_overlay.convert('RGB') @@ -1401,7 +1385,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask_blur_x: int = 4 mask_blur_y: int = 4 mask_blur: int = None - soft_inpainting: si.SoftInpaintingParameters = si.default + mask_round: bool = True inpainting_fill: int = 0 inpaint_full_res: bool = True inpaint_full_res_padding: int = 0 @@ -1447,7 +1431,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if image_mask is not None: # image_mask is passed in as RGBA by Gradio to support alpha masks, # but we still want to support binary masks. - image_mask = create_binary_mask(image_mask, round=(self.soft_inpainting is None)) + image_mask = create_binary_mask(image_mask, round=self.mask_round) if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) @@ -1465,7 +1449,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - self.mask_for_overlay = image_mask if self.soft_inpainting is None else None + self.mask_for_overlay = image_mask mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1476,13 +1460,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) - if self.soft_inpainting is None: - np_mask = np.array(image_mask) - np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) - - self.masks_for_overlay = [] if self.soft_inpainting is not None else None self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1504,15 +1485,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - if self.soft_inpainting is not None: - # We apply the masks AFTER to adjust mask based on changed content. - self.overlay_images.append(image.convert('RGBA')) - self.masks_for_overlay.append(image_mask) - else: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image_masked.convert('RGBA')) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1565,7 +1541,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2])) latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255 latmask = latmask[0] - if self.soft_inpainting is None: + if self.mask_round: latmask = np.around(latmask) latmask = np.tile(latmask[None], (4, 1, 1)) @@ -1578,7 +1554,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask - self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.soft_inpainting is None) + self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = self.rng.next() @@ -1589,8 +1565,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - if self.mask is not None and self.soft_inpainting is None: - samples = samples * self.nmask + self.init_latent * self.mask + blended_samples = samples * self.nmask + self.init_latent * self.mask + + if self.scripts is not None: + mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent + + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 7f9454eb5..92a07c564 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() +class MaskBlendArgs: + def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + self.current_latent = current_latent + self.nmask = nmask + self.init_latent = init_latent + self.mask = mask + self.blended_samples = blended_samples + + self.denoiser = denoiser + self.is_final_blend = denoiser is None + self.sigma = sigma + +class PostSampleArgs: + def __init__(self, samples): + self.samples = samples class PostprocessImageArgs: def __init__(self, image): self.image = image +class PostProcessMaskOverlayArgs: + def __init__(self, index, mask_for_overlay, overlay_image): + self.index = index + self.mask_for_overlay = mask_for_overlay + self.overlay_image = overlay_image class PostprocessBatchListArgs: def __init__(self, images): @@ -206,6 +226,25 @@ class Script: pass + def on_mask_blend(self, p, mba: MaskBlendArgs, *args): + """ + Called in inpainting mode when the original content is blended with the inpainted content. + This is called at every step in the denoising process and once at the end. + If is_final_blend is true, this is called for the final blending stage. + Otherwise, denoiser and sigma are defined and may be used to inform the procedure. + """ + + pass + + def post_sample(self, p, ps: PostSampleArgs, *args): + """ + Called after the samples have been generated, + but before they have been decoded by the VAE, if applicable. + Check getattr(samples, 'already_decoded', False) to test if the images are decoded. + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -213,6 +252,13 @@ class Script: pass + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args): + """ + Called for every image after it has been generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -767,6 +813,22 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def post_sample(self, p, ps: PostSampleArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.post_sample(p, ps, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + + def on_mask_blend(self, p, mba: MaskBlendArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.on_mask_blend(p, mba, *script_args) + except Exception: + errors.report(f"Error running post_sample: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: @@ -775,6 +837,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_maskoverlay(p, ppmo, *script_args) + except Exception: + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f13e8dcc5..eb9d5dafa 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -109,19 +109,16 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" # If we use masks, blending between the denoised and original latent images occurs here. - def apply_blend(latent): - if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): - return self.p.denoiser_masked_blend_function( - self, - # Using an argument dictionary so that arguments can be added without breaking extensions. - args= - { - "denoiser": self, - "current_latent": latent, - "sigma": sigma - }) - else: - return self.init_latent * self.mask + self.nmask * latent + def apply_blend(current_latent): + blended_latent = current_latent * self.nmask + self.init_latent * self.mask + + if self.p.scripts is not None: + from modules import scripts + mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma) + self.p.scripts.on_mask_blend(self.p, mba) + blended_latent = mba.blended_latent + + return blended_latent # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: From 2abc417834d752e43a283f8603bfddfb1c80b30f Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 22:25:53 -0700 Subject: [PATCH 074/311] Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to. --- modules/processing.py | 23 +-- modules/scripts.py | 4 +- modules/soft_inpainting.py | 308 ---------------------------- scripts/soft_inpainting.py | 401 +++++++++++++++++++++++++++++++++++++ 4 files changed, 413 insertions(+), 323 deletions(-) delete mode 100644 modules/soft_inpainting.py create mode 100644 scripts/soft_inpainting.py diff --git a/modules/processing.py b/modules/processing.py index 5a1a90afe..f8d85bdf5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: ps = scripts.PostSampleArgs(samples_ddim) p.scripts.post_sample(p, ps) - samples_ddim = pp.samples + samples_ddim = ps.samples if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) p.scripts.postprocess_maskoverlay(p, ppmo) - mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image if p.color_corrections is not None and i < len(p.color_corrections): if save_samples and opts.save_images_before_color_correction: @@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: original_denoised_image = image.copy() if p.paste_to is not None: - original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) + original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to) image = apply_overlay(image, p.paste_to, overlay_image) @@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size - if self.masks_for_overlay is not None: - self.masks_for_overlay = self.masks_for_overlay * self.batch_size - if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size @@ -1565,14 +1561,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) - blended_samples = samples * self.nmask + self.init_latent * self.mask + if self.mask is not None: + blended_samples = samples * self.nmask + self.init_latent * self.mask - if self.scripts is not None: - mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) - self.scripts.on_mask_blend(self, mba) - blended_samples = mba.blended_latent + if self.scripts is not None: + mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples) + self.scripts.on_mask_blend(self, mba) + blended_samples = mba.blended_latent - samples = blended_samples + samples = blended_samples del x devices.torch_gc() diff --git a/modules/scripts.py b/modules/scripts.py index 92a07c564..b6fcf96e0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, AlwaysVisible = object() class MaskBlendArgs: - def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): + def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None): self.current_latent = current_latent self.nmask = nmask self.init_latent = init_latent self.mask = mask - self.blended_samples = blended_samples + self.blended_latent = blended_latent self.denoiser = denoiser self.is_final_blend = denoiser is None diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py deleted file mode 100644 index b36ac8fa1..000000000 --- a/modules/soft_inpainting.py +++ /dev/null @@ -1,308 +0,0 @@ -class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): - self.mask_blend_power = mask_blend_power - self.mask_blend_scale = mask_blend_scale - self.inpaint_detail_preservation = inpaint_detail_preservation - - def add_generation_params(self, dest): - dest[enabled_gen_param_label] = True - dest[gen_param_labels.mask_blend_power] = self.mask_blend_power - dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale - dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation - - -# ------------------- Methods ------------------- - - -def latent_blend(soft_inpainting, a, b, t): - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - import torch - - # NOTE: We use inplace operations wherever possible. - - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) - - one_minus_t2 = 1 - t2 - one_minus_t3 = 1 - t3 - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t2 - b_scaled = b * t2 - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled, t2, one_minus_t2 - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, t3, one_minus_t3 - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - del result_type - - return image_interp_scaled - - -def get_modified_nmask(soft_inpainting, nmask, sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - import torch - # todo: Why is sigma 2D? Both values are the same. - return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) - - -def apply_adaptive_masks( - latent_orig, - latent_processed, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() - # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() - - latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. * converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc. uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - paste_to) - - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - -def apply_masks( - soft_inpainting, - nmask, - overlay_images, - masks_for_overlay, - width, height, - paste_to): - import torch - import numpy as np - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) - converted_mask = 255. * converted_mask - converted_mask = converted_mask.cpu().numpy().astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (width, height), - paste_to) - - for i, overlay_image in enumerate(overlay_images): - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - -# ------------------- Constants ------------------- - - -default = SoftInpaintingSettings(1, 0.5, 4) - -enabled_ui_label = "Soft inpainting" -enabled_gen_param_label = "Soft inpainting enabled" -enabled_el_id = "soft_inpainting_enabled" - -ui_labels = SoftInpaintingSettings( - "Schedule bias", - "Preservation strength", - "Transition contrast boost") - -ui_info = SoftInpaintingSettings( - "Shifts when preservation of original content occurs during denoising.", - "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") - -gen_param_labels = SoftInpaintingSettings( - "Soft inpainting schedule bias", - "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") - -el_ids = SoftInpaintingSettings( - "mask_blend_power", - "mask_blend_scale", - "inpaint_detail_preservation") - - -# ------------------- UI ------------------- - - -def gradio_ui(): - import gradio as gr - from modules.ui_components import InputAccordion - - with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: - with gr.Group(): - gr.Markdown( - """ - Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. - **High _Mask blur_** values are recommended! - """) - - result = SoftInpaintingSettings( - gr.Slider(label=ui_labels.mask_blend_power, - info=ui_info.mask_blend_power, - minimum=0, - maximum=8, - step=0.1, - value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), - gr.Slider(label=ui_labels.mask_blend_scale, - info=ui_info.mask_blend_scale, - minimum=0, - maximum=8, - step=0.05, - value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), - gr.Slider(label=ui_labels.inpaint_detail_preservation, - info=ui_info.inpaint_detail_preservation, - minimum=1, - maximum=32, - step=0.5, - value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) - - with gr.Accordion("Help", open=False): - gr.Markdown( - f""" - ### {ui_labels.mask_blend_power} - - The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). - This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. - This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - - - **Below 1**: Stronger preservation near the end (with low sigma) - - **1**: Balanced (proportional to sigma) - - **Above 1**: Stronger preservation in the beginning (with high sigma) - """) - gr.Markdown( - f""" - ### {ui_labels.mask_blend_scale} - - Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. - This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - - - **Low values**: Favors generated content. - - **High values**: Favors original content. - """) - gr.Markdown( - f""" - ### {ui_labels.inpaint_detail_preservation} - - This parameter controls how the original latent vectors and denoised latent vectors are interpolated. - With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. - This can prevent the loss of contrast that occurs with linear interpolation. - - - **Low values**: Softer blending, details may fade. - - **High values**: Stronger contrast, may over-saturate colors. - """) - - return ( - [ - soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation - ], - [ - (soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation) - ] - ) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py new file mode 100644 index 000000000..47e0269bf --- /dev/null +++ b/scripts/soft_inpainting.py @@ -0,0 +1,401 @@ +import gradio as gr +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + + +# ------------------- Methods ------------------- + + +def latent_blend(soft_inpainting, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + soft_inpainting.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(soft_inpainting, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + + +def apply_adaptive_masks( + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + # latent_mask = p.nmask[0].float().cpu() + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + soft_inpainting, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation") + + +class Script(scripts.Script): + + def __init__(self): + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + result = SoftInpaintingSettings( + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power), + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale), + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation)) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (result.mask_blend_power, gen_param_labels.mask_blend_power), + (result.mask_blend_scale, gen_param_labels.mask_blend_scale), + (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + result.mask_blend_power, + result.mask_blend_scale, + result.inpaint_detail_preservation] + + def process(self, p, enabled, power, scale, detail_preservation): + if not enabled: + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + if mba.sigma is None: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation) + + from modules import images + from modules.shared import opts + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(soft_inpainting=settings, + nmask=p.nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + if not enabled: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file From 8dbacc7d018774a3bc801cc57617795274a15087 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:30:30 -0700 Subject: [PATCH 075/311] Fixed "No newline at end of file". --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 47e0269bf..6d0cf8479 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -398,4 +398,4 @@ class Script(scripts.Script): return ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] - ppmo.overlay_image = self.overlay_images[ppmo.index] \ No newline at end of file + ppmo.overlay_image = self.overlay_images[ppmo.index] From 56604f08a18588e8e6b57d7c3f9c61d6624846f8 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:53:44 -0700 Subject: [PATCH 076/311] Moved image filters used by soft inpainting into soft_inpainting.py from images.py --- modules/images.py | 190 ---------------------------------- scripts/soft_inpainting.py | 205 +++++++++++++++++++++++++++++++++++-- 2 files changed, 199 insertions(+), 196 deletions(-) diff --git a/modules/images.py b/modules/images.py index 949534986..16f9ae7cc 100644 --- a/modules/images.py +++ b/modules/images.py @@ -792,193 +792,3 @@ def flatten(img, bgcolor): return img.convert('RGB') - -def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): - """ - Generalization convolution filter capable of applying - weighted mean, median, maximum, and minimum filters - parametrically using an arbitrary kernel. - - Args: - img (nparray): - The image, a 2-D array of floats, to which the filter is being applied. - kernel (nparray): - The kernel, a 2-D array of floats. - kernel_center (nparray): - The kernel center coordinate, a 1-D array with two elements. - percentile_min (float): - The lower bound of the histogram window used by the filter, - from 0 to 1. - percentile_max (float): - The upper bound of the histogram window used by the filter, - from 0 to 1. - min_width (float): - The minimum size of the histogram window bounds, in weight units. - Must be greater than 0. - - Returns: - (nparray): A filtered copy of the input image "img", a 2-D array of floats. - """ - - # Converts an index tuple into a vector. - def vec(x): - return np.array(x) - - kernel_min = -kernel_center - kernel_max = vec(kernel.shape) - kernel_center - - def weighted_histogram_filter_single(idx): - idx = vec(idx) - min_index = np.maximum(0, idx + kernel_min) - max_index = np.minimum(vec(img.shape), idx + kernel_max) - window_shape = max_index - min_index - - class WeightedElement: - """ - An element of the histogram, its weight - and bounds. - """ - def __init__(self, value, weight): - self.value: float = value - self.weight: float = weight - self.window_min: float = 0.0 - self.window_max: float = 1.0 - - # Collect the values in the image as WeightedElements, - # weighted by their corresponding kernel values. - values = [] - for window_tup in np.ndindex(tuple(window_shape)): - window_index = vec(window_tup) - image_index = window_index + min_index - centered_kernel_index = image_index - idx - kernel_index = centered_kernel_index + kernel_center - element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) - values.append(element) - - def sort_key(x: WeightedElement): - return x.value - - values.sort(key=sort_key) - - # Calculate the height of the stack (sum) - # and each sample's range they occupy in the stack - sum = 0 - for i in range(len(values)): - values[i].window_min = sum - sum += values[i].weight - values[i].window_max = sum - - # Calculate what range of this stack ("window") - # we want to get the weighted average across. - window_min = sum * percentile_min - window_max = sum * percentile_max - window_width = window_max - window_min - - # Ensure the window is within the stack and at least a certain size. - if window_width < min_width: - window_center = (window_min + window_max) / 2 - window_min = window_center - min_width / 2 - window_max = window_center + min_width / 2 - - if window_max > sum: - window_max = sum - window_min = sum - min_width - - if window_min < 0: - window_min = 0 - window_max = min_width - - value = 0 - value_weight = 0 - - # Get the weighted average of all the samples - # that overlap with the window, weighted - # by the size of their overlap. - for i in range(len(values)): - if window_min >= values[i].window_max: - continue - if window_max <= values[i].window_min: - break - - s = max(window_min, values[i].window_min) - e = min(window_max, values[i].window_max) - w = e - s - - value += values[i].value * w - value_weight += w - - return value / value_weight if value_weight != 0 else 0 - - img_out = img.copy() - - # Apply the kernel operation over each pixel. - for index in np.ndindex(img.shape): - img_out[index] = weighted_histogram_filter_single(index) - - return img_out - -def smoothstep(x): - """ - The smoothstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * (3 - 2 * x) - -def smootherstep(x): - """ - The smootherstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * x * (x * (6 * x - 15) + 10) - - -def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): - """ - Creates a Gaussian kernel with thresholded edges. - - Args: - stddev_radius (float): - Standard deviation of the gaussian kernel, in pixels. - max_radius (int): - The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. - The kernel is thresholded so that any values one pixel beyond this radius - is weighted at 0. - - Returns: - (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) - """ - # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. - def gaussian(sqr_mag): - return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) - - # Helper function for converting a tuple to an array. - def vec(x): - return np.array(x) - - """ - Since a gaussian is unbounded, we need to limit ourselves - to a finite range. - We taper the ends off at the end of that range so they equal zero - while preserving the maximum value of 1 at the mean. - """ - zero_radius = max_radius + 1.0 - gauss_zero = gaussian(zero_radius * zero_radius) - gauss_kernel_scale = 1 / (1 - gauss_zero) - - def gaussian_kernel_func(coordinate): - x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 - x = gaussian(x) - x -= gauss_zero - x *= gauss_kernel_scale - x = max(0.0, x) - return x - - size = max_radius * 2 + 1 - kernel_center = max_radius - kernel = np.zeros((size, size)) - - for index in np.ndindex(kernel.shape): - kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) - - return kernel, kernel_center - diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6d0cf8479..1f451b553 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -1,4 +1,6 @@ +import numpy as np import gradio as gr +import math from modules.ui_components import InputAccordion import modules.scripts as scripts @@ -101,7 +103,6 @@ def apply_adaptive_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -115,15 +116,15 @@ def apply_adaptive_masks( latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) masks_for_overlay = [] for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% @@ -131,7 +132,7 @@ def apply_adaptive_masks( # converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** 2) - converted_mask = images.smootherstep(converted_mask) + converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. * converted_mask converted_mask = converted_mask.astype(np.uint8) @@ -166,7 +167,6 @@ def apply_masks( width, height, paste_to): import torch - import numpy as np import modules.processing as proc import modules.images as images from PIL import Image, ImageOps, ImageFilter @@ -202,6 +202,196 @@ def apply_masks( return masks_for_overlay +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + # ------------------- Constants ------------------- @@ -232,6 +422,9 @@ el_ids = SoftInpaintingSettings( "inpaint_detail_preservation") +# ----- + + class Script(scripts.Script): def __init__(self): From 0ef4a4cb2365051b1e308f0136a0d8c01d071569 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 14:54:26 -0700 Subject: [PATCH 077/311] Fixed error that occurs when using vanilla samplers (somehow). --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f8d85bdf5..bea01ec68 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -937,8 +937,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.scripts.postprocess_image(p, pp) image = pp.image - mask_for_overlay = p.mask_for_overlay - overlay_image = p.overlay_images[i] if p.overlay_images is not None and i < len(p.overlay_images) else None + mask_for_overlay = getattr(p, "mask_for_overlay", None) + overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None if p.scripts is not None: ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) From f284ae23bcdfa212cf4763659c06e124ec5b1456 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:19:35 -0700 Subject: [PATCH 078/311] Added parameters for the composite stage, fixed batched generation. --- scripts/soft_inpainting.py | 196 +++++++++++++++++++++++++++++-------- 1 file changed, 154 insertions(+), 42 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1f451b553..1b21aee9d 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -6,22 +6,34 @@ import modules.scripts as scripts class SoftInpaintingSettings: - def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation): + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): self.mask_blend_power = mask_blend_power self.mask_blend_scale = mask_blend_scale self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast def add_generation_params(self, dest): dest[enabled_gen_param_label] = True dest[gen_param_labels.mask_blend_power] = self.mask_blend_power dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast # ------------------- Methods ------------------- -def latent_blend(soft_inpainting, a, b, t): +def latent_blend(settings, a, b, t): """ Interpolates two latent image representations according to the parameter t, where the interpolated vectors' magnitudes are also interpolated separately. @@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t): # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + settings.inpaint_detail_preservation) * one_minus_t3 b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - soft_inpainting.inpaint_detail_preservation) * t3 + settings.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. @@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t): return image_interp_scaled -def get_modified_nmask(soft_inpainting, nmask, sigma): +def get_modified_nmask(settings, nmask, sigma): """ Converts a negative mask representing the transparency of the original latent vectors being overlayed to a mask that is scaled according to the denoising strength for this step. @@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) def apply_adaptive_masks( + settings:SoftInpaintingSettings, + nmask, latent_orig, latent_processed, overlay_images, @@ -108,11 +122,13 @@ def apply_adaptive_masks( from PIL import Image, ImageOps, ImageFilter # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - # latent_mask = p.nmask[0].float().cpu() + latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) - # mask_scalar = mask_scalar / (1.00001-mask_scalar) - # mask_scalar = mask_scalar.numpy() + mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1-settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -128,10 +144,10 @@ def apply_adaptive_masks( percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% - # half_weighted_distance = 1 # * mask_scalar - # converted_mask = converted_mask / half_weighted_distance + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance - converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) converted_mask = smootherstep(converted_mask) converted_mask = 1 - converted_mask converted_mask = 255. * converted_mask @@ -161,7 +177,7 @@ def apply_adaptive_masks( def apply_masks( - soft_inpainting, + settings, nmask, overlay_images, width, height, @@ -172,7 +188,7 @@ def apply_masks( from PIL import Image, ImageOps, ImageFilter converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) converted_mask = 255. * converted_mask converted_mask = converted_mask.cpu().numpy().astype(np.uint8) converted_mask = Image.fromarray(converted_mask) @@ -395,7 +411,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): # ------------------- Constants ------------------- -default = SoftInpaintingSettings(1, 0.5, 4) +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) enabled_ui_label = "Soft inpainting" enabled_gen_param_label = "Soft inpainting enabled" @@ -404,25 +420,37 @@ enabled_el_id = "soft_inpainting_enabled" ui_labels = SoftInpaintingSettings( "Schedule bias", "Preservation strength", - "Transition contrast boost") + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") ui_info = SoftInpaintingSettings( "Shifts when preservation of original content occurs during denoising.", "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.") + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") gen_param_labels = SoftInpaintingSettings( "Soft inpainting schedule bias", "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost") + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") el_ids = SoftInpaintingSettings( "mask_blend_power", "mask_blend_scale", - "inpaint_detail_preservation") + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") -# ----- +# ------------------- Script ------------------- class Script(scripts.Script): @@ -449,28 +477,62 @@ class Script(scripts.Script): **High _Mask blur_** values are recommended! """) - result = SoftInpaintingSettings( + power = \ gr.Slider(label=ui_labels.mask_blend_power, info=ui_info.mask_blend_power, minimum=0, maximum=8, step=0.1, value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power), + elem_id=el_ids.mask_blend_power) + scale = \ gr.Slider(label=ui_labels.mask_blend_scale, info=ui_info.mask_blend_scale, minimum=0, maximum=8, step=0.05, value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale), + elem_id=el_ids.mask_blend_scale) + detail = \ gr.Slider(label=ui_labels.inpaint_detail_preservation, info=ui_info.inpaint_detail_preservation, minimum=1, maximum=32, step=0.5, value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation)) + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) with gr.Accordion("Help", open=False): gr.Markdown( @@ -507,41 +569,86 @@ class Script(scripts.Script): - **High values**: Stronger contrast, may over-saturate colors. """) + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), - (result.mask_blend_power, gen_param_labels.mask_blend_power), - (result.mask_blend_scale, gen_param_labels.mask_blend_scale), - (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)] + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] self.paste_field_names = [] for _, field_name in self.infotext_fields: self.paste_field_names.append(field_name) return [soft_inpainting_enabled, - result.mask_blend_power, - result.mask_blend_scale, - result.inpaint_detail_preservation] + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] - def process(self, p, enabled, power, scale, detail_preservation): + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return # Shut off the rounding it normally does. p.mask_round = False - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - if mba.sigma is None: + if mba.is_final_blend: mba.blended_latent = mba.current_latent return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) # todo: Why is sigma 2D? Both values are the same. mba.blended_latent = latent_blend(settings, @@ -549,11 +656,11 @@ class Script(scripts.Script): mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation) + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) from modules import images from modules.shared import opts @@ -570,15 +677,20 @@ class Script(scripts.Script): self.overlay_images.append(image.convert('RGBA')) + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + if getattr(ps.samples, 'already_decoded', False): - self.masks_for_overlay = apply_masks(soft_inpainting=settings, + self.masks_for_overlay = apply_masks(settings=settings, nmask=p.nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: - self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent, + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=p.nmask, + latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, width=p.width, @@ -586,7 +698,7 @@ class Script(scripts.Script): paste_to=p.paste_to) - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From fc3e246c0f4f292c33b181a902cd934629ff0d7a Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 20:28:38 -0700 Subject: [PATCH 079/311] Fixed complaint about whitespace, updated help section for a parameter. --- scripts/soft_inpainting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 1b21aee9d..6fb5cfbd0 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -572,7 +572,7 @@ class Script(scripts.Script): gr.Markdown( """ ## Pixel Composite Settings - + Masks are generated based on how much a part of the image changed after denoising. These masks are used to blend the original and final images together. If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. @@ -602,10 +602,10 @@ class Script(scripts.Script): f""" ### {ui_labels.composite_difference_contrast} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the contrast between the opacity of the original and inpainted content. - - **Low values**: Two images patches must be almost the same in order to retain original pixels. - - **High values**: Two images patches can be very different and still retain original pixels. + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. """) self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), From 659f62e120b210e3043712ff928e8b7b6cd6cf61 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Thu, 7 Dec 2023 21:39:54 -0700 Subject: [PATCH 080/311] Fixed grammar error. --- scripts/soft_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 6fb5cfbd0..51f9ca2fe 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -592,7 +592,7 @@ class Script(scripts.Script): f""" ### {ui_labels.composite_difference_threshold} - This value represents the difference at which the opacity of the original pixels will have less than 50% opacity. + This value represents the difference at which the original pixels will have less than 50% opacity. - **Low values**: Two images patches must be almost the same in order to retain original pixels. - **High values**: Two images patches can be very different and still retain original pixels. From 16bdcce92d5b482d50cdc32a8f308040d320b6c9 Mon Sep 17 00:00:00 2001 From: Rene Kroon Date: Fri, 8 Dec 2023 21:19:29 +0100 Subject: [PATCH 081/311] #13354: solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706a..629bf8537 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) From b2414476ef164ba55cff2508c58b73d23bbc3000 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:32:41 -0700 Subject: [PATCH 082/311] soft_inpainting now appears in the "inpaint" section, and will not activate unless inpainting is activated. --- scripts/soft_inpainting.py | 43 ++++++++++++++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index 51f9ca2fe..f10a1e562 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -32,6 +32,19 @@ class SoftInpaintingSettings: # ------------------- Methods ------------------- +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + def latent_blend(settings, a, b, t): """ @@ -454,8 +467,8 @@ el_ids = SoftInpaintingSettings( class Script(scripts.Script): - def __init__(self): + self.section = "inpaint" self.masks_for_overlay = None self.overlay_images = None @@ -632,6 +645,9 @@ class Script(scripts.Script): if not enabled: return + if not processing_uses_inpainting(p): + return + # Shut off the rounding it normally does. p.mask_round = False @@ -644,6 +660,9 @@ class Script(scripts.Script): if not enabled: return + if not processing_uses_inpainting(p): + return + if mba.is_final_blend: mba.blended_latent = mba.current_latent return @@ -660,11 +679,18 @@ class Script(scripts.Script): if not enabled: return - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return from modules import images from modules.shared import opts + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + # since the original code puts holes in the existing overlay images, # we have to rebuild them. self.overlay_images = [] @@ -682,14 +708,14 @@ class Script(scripts.Script): if getattr(ps.samples, 'already_decoded', False): self.masks_for_overlay = apply_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, overlay_images=self.overlay_images, width=p.width, height=p.height, paste_to=p.paste_to) else: self.masks_for_overlay = apply_adaptive_masks(settings=settings, - nmask=p.nmask, + nmask=nmask, latent_orig=p.init_latent, latent_processed=ps.samples, overlay_images=self.overlay_images, @@ -702,5 +728,14 @@ class Script(scripts.Script): if not enabled: return + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] ppmo.overlay_image = self.overlay_images[ppmo.index] From f1ff932cafa2bf34fa35f41072f21a8ea5474d84 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Fri, 8 Dec 2023 17:33:11 -0700 Subject: [PATCH 083/311] Formatted soft_inpainting. --- scripts/soft_inpainting.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py index f10a1e562..d90243442 100644 --- a/scripts/soft_inpainting.py +++ b/scripts/soft_inpainting.py @@ -122,7 +122,7 @@ def get_modified_nmask(settings, nmask, sigma): def apply_adaptive_masks( - settings:SoftInpaintingSettings, + settings: SoftInpaintingSettings, nmask, latent_orig, latent_processed, @@ -137,10 +137,10 @@ def apply_adaptive_masks( # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. latent_mask = nmask[0].float() # convert the original mask into a form we use to scale distances for thresholding - mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) - mask_scalar = (0.5 * (1-settings.composite_mask_influence) + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + mask_scalar * settings.composite_mask_influence) - mask_scalar = mask_scalar / (1.00001-mask_scalar) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) mask_scalar = mask_scalar.cpu().numpy() latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) @@ -152,9 +152,9 @@ def apply_adaptive_masks( for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): converted_mask = distance_map.float().cpu().numpy() converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) + percentile_min=0.9, percentile_max=1, min_width=1) converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) + percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% half_weighted_distance = settings.composite_difference_threshold * mask_scalar @@ -276,6 +276,7 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe An element of the histogram, its weight and bounds. """ + def __init__(self, value, weight): self.value: float = value self.weight: float = weight @@ -355,6 +356,7 @@ def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, pe return img_out + def smoothstep(x): """ The smoothstep function, input should be clamped to 0-1 range. @@ -362,6 +364,7 @@ def smoothstep(x): """ return x * x * (3 - 2 * x) + def smootherstep(x): """ The smootherstep function, input should be clamped to 0-1 range. @@ -385,6 +388,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): Returns: (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) """ + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. def gaussian(sqr_mag): return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) @@ -656,7 +660,8 @@ class Script(scripts.Script): # p.extra_generation_params["Mask rounding"] = False settings.add_generation_params(p.extra_generation_params) - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -675,7 +680,8 @@ class Script(scripts.Script): mba.current_latent, get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): if not enabled: return @@ -723,8 +729,8 @@ class Script(scripts.Script): height=p.height, paste_to=p.paste_to) - - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): if not enabled: return From 59429793440fb3cb1624ddcc702c6f9807373203 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:09:45 +0800 Subject: [PATCH 084/311] Fix ControlNet --- modules/xpu_specific.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index ec1ad100a..9bb0a5615 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -51,3 +51,9 @@ if has_xpu: CondFunc('torch.bmm', lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out), lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype) + CondFunc('torch.cat', + lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), + lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) + CondFunc('torch.nn.functional.scaled_dot_product_attention', + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file From 049d5642e58d572ee8657ac754e72d019eea0e6c Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 9 Dec 2023 18:11:26 +0800 Subject: [PATCH 085/311] Fix format --- modules/xpu_specific.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 9bb0a5615..d8da94a0e 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -56,4 +56,4 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) \ No newline at end of file + lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) From 39ec4cfea9040bc94e639eb4aa8ab8ed37a68f01 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 19:12:59 +0600 Subject: [PATCH 086/311] Re-add setting lost as part of e294e46 --- modules/shared_options.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018..acb6e2d48 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -256,6 +256,7 @@ options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "keyedit_precision_extra": OptionInfo(0.05, "Precision for when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"), "keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) From 9c201550ddae0b33367adfb99bcbb57ba9b207a9 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 21:04:45 +0600 Subject: [PATCH 087/311] Add keyboard shortcuts for generation (Removed Alt+Enter) Ctrl+Enter to start/restart generation (New) Alt/Option+Enter to skip generation (New) Ctrl+Alt/Option+Enter to interrupt generation --- modules/ui_toprow.py | 4 ++-- script.js | 23 +++++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 88838f977..c3865e3d9 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ class Toprow: def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff --git a/script.js b/script.js index c0e678ea7..69598f454 100644 --- a/script.js +++ b/script.js @@ -121,16 +121,21 @@ document.addEventListener("DOMContentLoaded", function() { }); /** - * Add a ctrl+enter as a shortcut to start a generation + * Add keyboard shortcuts: + * Ctrl+Enter to start/restart a generation + * Alt/Option+Enter to skip a generation + * Alt/Option+Ctrl+Enter to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; - const isModifierKey = e.metaKey || e.ctrlKey || e.altKey; + const isCtrlKey = e.metaKey || e.ctrlKey; + const isAltKey = e.altKey; - const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); + const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); + const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isEnter && isModifierKey) { + if (isCtrlKey && isEnter && !isAltKey) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -150,6 +155,16 @@ document.addEventListener('keydown', function(e) { } e.preventDefault(); } + + if (isAltKey && isEnter && !isCtrlKey) { + skipButton.click(); + e.preventDefault(); + } + + if (isAltKey && isCtrlKey && isEnter) { + interruptButton.click(); + e.preventDefault(); + } }); /** From 1a79a5049bdfef285235e83f37b201e39dd54f81 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 22:35:31 +0600 Subject: [PATCH 088/311] Assign id for "extra_options". Replace numeric field with slider in Settings. --- .../extra-options-section/scripts/extra_options_section.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df625..b9867fe63 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 5381405eaa1e809e5cfb97522bd4c19d3c946079 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 9 Dec 2023 14:09:28 -0500 Subject: [PATCH 089/311] re-derive sqrt alpha bar and sqrt one minus alphabar This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent. --- modules/sd_samplers_timesteps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c..c4bd5c127 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): self.inner_model = model def predict_eps_from_z_and_v(self, x_t, t, v): - return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t def forward(self, input, timesteps, **kwargs): model_output = self.inner_model.apply_model(input, timesteps, **kwargs) From 23a0e60b9bf90a80f8af9732cc6495fbfce2ea21 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 14:03:41 +0900 Subject: [PATCH 090/311] fix save styles --- modules/styles.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 7fb6c2e11..07588945e 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -155,10 +155,8 @@ class StyleDatabase: row["name"], prompt, negative_prompt, path ) - def get_style_paths(self) -> list(): - """ - Returns a list of all distinct paths, including the default path, of - files that styles are loaded from.""" + def get_style_paths(self) -> set: + """Returns a set of all distinct paths of files that styles are loaded from.""" # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: @@ -172,9 +170,9 @@ class StyleDatabase: style_paths.add(style.path) # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths.discard("do_not_save") - return list(style_paths) + return style_paths def get_style_prompts(self, styles): return [self.styles.get(x, self.no_style).prompt for x in styles] @@ -196,20 +194,7 @@ class StyleDatabase: # The path argument is deprecated, but kept for backwards compatibility _ = path - # Update any styles without a path to the default path - for style in list(self.styles.values()): - if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) - - # Create a list of all distinct paths, including the default path - style_paths = set() - style_paths.add(self.default_path) - for _, style in self.styles.items(): - if style.path: - style_paths.add(style.path) - - # Remove any paths for styles that are just list dividers - style_paths.remove("do_not_save") + style_paths = self.get_style_paths() csv_names = [os.path.split(path)[1].lower() for path in style_paths] From 8b74389e76a7678e972583ef16100e90e1519e55 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 10 Dec 2023 15:48:16 +0900 Subject: [PATCH 091/311] fix styles.csv filename --- modules/styles.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 07588945e..81d9800d1 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -98,10 +98,8 @@ class StyleDatabase: self.path = path folder, file = os.path.split(self.path) - self.default_file = file.split("*")[0] + ".csv" - if self.default_file == ".csv": - self.default_file = "styles.csv" - self.default_path = os.path.join(folder, self.default_file) + filename, _, ext = file.partition('*') + self.default_path = os.path.join(folder, filename + ext) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] From 6b8143a84e112f029ee1868b6ab98b1d2c773ead Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 15:35:06 +0600 Subject: [PATCH 092/311] Number of columns slider: max count set to 20, add description info --- .../extra-options-section/scripts/extra_options_section.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index b9867fe63..ac2c3de46 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -71,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) From 1d42babd324b933bae317cb427fe0513138954f4 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 16:28:56 +0600 Subject: [PATCH 093/311] Replace Ctrl+Alt+Enter with Esc --- modules/ui_toprow.py | 4 ++-- script.js | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index c3865e3d9..9caf8faa2 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -79,11 +79,11 @@ class Toprow: def create_prompts(self): with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6): with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]): - self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False) with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]): - self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Ctrl+Alt+Enter to interrupt)", elem_classes=["prompt"]) + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)", elem_classes=["prompt"]) self.prompt_img.change( fn=modules.images.image_data, diff --git a/script.js b/script.js index 69598f454..44950090a 100644 --- a/script.js +++ b/script.js @@ -124,18 +124,19 @@ document.addEventListener("DOMContentLoaded", function() { * Add keyboard shortcuts: * Ctrl+Enter to start/restart a generation * Alt/Option+Enter to skip a generation - * Alt/Option+Ctrl+Enter to interrupt a generation + * Esc to interrupt a generation */ document.addEventListener('keydown', function(e) { const isEnter = e.key === 'Enter' || e.keyCode === 13; const isCtrlKey = e.metaKey || e.ctrlKey; const isAltKey = e.altKey; + const isEsc = e.key === 'Escape'; const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]'); const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]'); - if (isCtrlKey && isEnter && !isAltKey) { + if (isCtrlKey && isEnter) { if (interruptButton.style.display === 'block') { interruptButton.click(); const callback = (mutationList) => { @@ -156,14 +157,18 @@ document.addEventListener('keydown', function(e) { e.preventDefault(); } - if (isAltKey && isEnter && !isCtrlKey) { + if (isAltKey && isEnter) { skipButton.click(); e.preventDefault(); } - if (isAltKey && isCtrlKey && isEnter) { - interruptButton.click(); - e.preventDefault(); + if (isEsc) { + if (!globalPopup || globalPopup.style.display === "none") { + interruptButton.click(); + e.preventDefault(); + } else { + closePopup(); + } } }); From cee1a4065162982e18f32761259d9107538c2d93 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 17:06:12 +0600 Subject: [PATCH 094/311] Fix linter issues --- script.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/script.js b/script.js index 44950090a..354154b01 100644 --- a/script.js +++ b/script.js @@ -163,11 +163,13 @@ document.addEventListener('keydown', function(e) { } if (isEsc) { + const globalPopup = document.querySelector('.global-popup'); if (!globalPopup || globalPopup.style.display === "none") { interruptButton.click(); e.preventDefault(); } else { - closePopup(); + if (!globalPopup) return; + globalPopup.style.display = "none"; } } }); From 6513470f0db1aed1b0a5200634e8e02f7c05e932 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Mon, 11 Dec 2023 18:06:08 +0600 Subject: [PATCH 095/311] Remove unnecessary 'else', add 'lightboxModal' check --- script.js | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/script.js b/script.js index 354154b01..be1bc317e 100644 --- a/script.js +++ b/script.js @@ -164,12 +164,11 @@ document.addEventListener('keydown', function(e) { if (isEsc) { const globalPopup = document.querySelector('.global-popup'); - if (!globalPopup || globalPopup.style.display === "none") { + const lightboxModal = document.querySelector('#lightboxModal'); + if (!globalPopup || globalPopup.style.display === 'none') { + if (document.activeElement === lightboxModal) return; interruptButton.click(); e.preventDefault(); - } else { - if (!globalPopup) return; - globalPopup.style.display = "none"; } } }); From cc41cc4349514bbfeb9f37445c931a050b076bd6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 13 Dec 2023 02:06:56 +0900 Subject: [PATCH 096/311] on mouse hover show / hide modal image viewer icons --- style.css | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/style.css b/style.css index ee39a57b7..ec449bdea 100644 --- a/style.css +++ b/style.css @@ -749,6 +749,22 @@ table.popup-table .link{ display: none; } +@media (pointer: fine) { + .modalPrev:hover, + .modalNext:hover, + .modalControls:hover ~ .modalPrev, + .modalControls:hover ~ .modalNext, + .modalControls:hover .cursor { + opacity: 1; + } + + .modalPrev, + .modalNext, + .modalControls .cursor { + opacity: 0; + } +} + /* context menu (ie for the generate button) */ #context-menu{ From bda86f0fd9653657c146f7c1128f92771d16ad4e Mon Sep 17 00:00:00 2001 From: Hina <102651522+HinaHyugaHime@users.noreply.github.com> Date: Tue, 12 Dec 2023 19:39:14 -0600 Subject: [PATCH 097/311] Update webui.sh --- webui.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/webui.sh b/webui.sh index 3d0f87eed..046ecf9d0 100755 --- a/webui.sh +++ b/webui.sh @@ -131,7 +131,7 @@ case "$gpu_info" in if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] then # Navi users will still use torch 1.13 because 2.0 does not seem to work. - export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" else printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" exit 1 @@ -141,9 +141,8 @@ case "$gpu_info" in *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ - export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6" - # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5) - # so switch to nightly rocm5.6 without explicit versions this time + export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7" + ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" From 89cfbc3bbe401fe1655afb07edbae34ec6af7aca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 13 Dec 2023 12:22:13 +0200 Subject: [PATCH 098/311] Allow pasting in WIDTHxHEIGHT strings into the width/height fields --- javascript/ui.js | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/javascript/ui.js b/javascript/ui.js index 410fc44e3..18c9f891a 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -215,9 +215,33 @@ function restoreProgressImg2img() { } +/** + * Configure the width and height elements on `tabname` to accept + * pasting of resolutions in the form of "width x height". + */ +function setupResolutionPasting(tabname) { + var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`); + var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`); + for (const el of [width, height]) { + el.addEventListener('paste', function(event) { + var pasteData = event.clipboardData.getData('text/plain'); + var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/); + if (parsed) { + width.value = parsed[1]; + height.value = parsed[2]; + updateInput(width); + updateInput(height); + event.preventDefault(); + } + }); + } +} + onUiLoaded(function() { showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); showRestoreProgressButton('img2img', localGet("img2img_task_id")); + setupResolutionPasting('txt2img'); + setupResolutionPasting('img2img'); }); From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: [PATCH 099/311] Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c378118..44465f7aa 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ class NetworkModuleOFT(network.NetworkModule): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: [PATCH 100/311] Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7aa..e3ae61a22 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. - return updown * self.multiplier(), ex_bias From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: [PATCH 101/311] remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a22..ff4eb59b1 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: [PATCH 102/311] better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b1..fa647020f 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) From 3c0c27757944ae17a7fa4c2323ee9ae2d434dbce Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 14 Dec 2023 19:36:17 +0900 Subject: [PATCH 103/311] default False js_live_preview_in_modal_lightbox --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 41097d8e6..d2e86ff10 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -331,7 +331,7 @@ options_templates.update(options_section(('ui', "Live previews", "ui"), { "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), - "js_live_preview_in_modal_lightbox": OptionInfo(True, "Show Live preview in full page image viewer"), + "js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), { From 0c5427960b3a4ffe6d673c28e8e135b26f015717 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 15 Dec 2023 17:11:59 +0900 Subject: [PATCH 104/311] make modal toolbar and icon opacity adjustable --- modules/shared_gradio_themes.py | 4 ++++ modules/shared_options.py | 2 ++ style.css | 4 ++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py index 822db0a95..b6dc31450 100644 --- a/modules/shared_gradio_themes.py +++ b/modules/shared_gradio_themes.py @@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None): except Exception as e: errors.display(e, "changing gradio theme") shared.gradio_theme = gr.themes.Default(**default_theme_args) + + # append additional values gradio_theme + shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity + shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity diff --git a/modules/shared_options.py b/modules/shared_options.py index e5de0d018..86e7636cc 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -266,6 +266,8 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"), "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"), "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"), + "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), + "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(), "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(), })) diff --git a/style.css b/style.css index ec449bdea..6d4c8a0d5 100644 --- a/style.css +++ b/style.css @@ -679,7 +679,7 @@ table.popup-table .link{ transition: 0.2s ease background-color; } .modalControls:hover { - background-color:rgba(0,0,0,0.9); + background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity)); } .modalClose { margin-left: auto; @@ -761,7 +761,7 @@ table.popup-table .link{ .modalPrev, .modalNext, .modalControls .cursor { - opacity: 0; + opacity: var(--sd-webui-modal-lightbox-icon-opacity); } } From 1242ba08e19f3d317bdc5924db2b73d0c9569a7f Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 16:57:17 +0800 Subject: [PATCH 105/311] add allow specify the task id and get the location of task in the queue of pending task --- modules/api/api.py | 20 ++++++++++++++++++-- modules/api/models.py | 2 ++ modules/processing.py | 2 ++ modules/progress.py | 21 +++++++++++++++++++-- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index e6edffe71..5d000ae8b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -33,7 +33,7 @@ from typing import Dict, List, Any import piexif import piexif.helper from contextlib import closing - +from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task def script_name_to_index(name, scripts): try: @@ -337,6 +337,10 @@ class Api: return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): + task_id = create_task_id("text2img") + if txt2imgreq.force_task_id != None: + task_id = txt2imgreq.force_task_id + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -363,6 +367,8 @@ class Api: send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: p.is_api = True @@ -372,12 +378,14 @@ class Api: try: shared.state.begin(job="scripts_txt2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -387,6 +395,10 @@ class Api: return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): + task_id = create_task_id("img2img") + if img2imgreq.force_task_id != None: + task_id = img2imgreq.force_task_id + init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -423,6 +435,8 @@ class Api: send_images = args.pop('send_images', True) args.pop('save_images', None) + add_task_to_queue(task_id) + with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] @@ -433,12 +447,14 @@ class Api: try: shared.state.begin(job="scripts_img2img") + start_task(task_id) if selectable_scripts is not None: p.script_args = script_args processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here else: p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) + finish_task(task_id) finally: shared.state.end() shared.total_tqdm.clear() @@ -514,7 +530,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) + return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task) def interrogateapi(self, interrogatereq: models.InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 6a574771c..7b7f17738 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -109,6 +109,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "send_images", "type": bool, "default": True}, {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, + {"key": "force_task_id", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/processing.py b/modules/processing.py index e124e7f0d..657cacfca 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1023,6 +1023,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_sampler_name: str = None hr_prompt: str = '' hr_negative_prompt: str = '' + force_task_id: str = None cached_hr_uc = [None, None] cached_hr_c = [None, None] @@ -1358,6 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None + force_task_id: string = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 69921de72..553866db1 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -8,10 +8,13 @@ from pydantic import BaseModel, Field from modules.shared import opts import modules.shared as shared - +from collections import OrderedDict +import string +import random +from typing import List current_task = None -pending_tasks = {} +pending_tasks = OrderedDict() finished_tasks = [] recorded_results = [] recorded_results_limit = 2 @@ -34,6 +37,11 @@ def finish_task(id_task): if len(finished_tasks) > 16: finished_tasks.pop(0) +def create_task_id(task_type): + N = 7 + res = ''.join(random.choices(string.ascii_uppercase + + string.digits, k=N)) + return f"task({task_type}-{res})" def record_results(id_task, res): recorded_results.append((id_task, res)) @@ -44,6 +52,9 @@ def record_results(id_task, res): def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +class PendingTasksResponse(BaseModel): + size: int = Field(title="Pending task size") + tasks: List[str] = Field(title="Pending task ids") class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") @@ -63,8 +74,14 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): + app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) +def get_pending_tasks(): + pending_tasks_ids = [x for x in pending_tasks] + pending_len = len(pending_tasks_ids) + return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) + def progressapi(req: ProgressRequest): active = req.id_task == current_task From d859de37d9ec10cb6c804226328a11c87c444852 Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:48:20 +0800 Subject: [PATCH 106/311] fix the problem of ruff of github --- modules/api/api.py | 4 ++-- modules/processing.py | 2 +- modules/progress.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 5d000ae8b..1f4648067 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -338,7 +338,7 @@ class Api: def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = create_task_id("text2img") - if txt2imgreq.force_task_id != None: + if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id script_runner = scripts.scripts_txt2img @@ -396,7 +396,7 @@ class Api: def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): task_id = create_task_id("img2img") - if img2imgreq.force_task_id != None: + if img2imgreq.force_task_id is None: task_id = img2imgreq.force_task_id init_images = img2imgreq.init_images diff --git a/modules/processing.py b/modules/processing.py index 657cacfca..5added65e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1359,7 +1359,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): inpainting_mask_invert: int = 0 initial_noise_multiplier: float = None latent_mask: Image = None - force_task_id: string = None + force_task_id: str = None image_mask: Any = field(default=None, init=False) diff --git a/modules/progress.py b/modules/progress.py index 553866db1..6946fb1bc 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -78,7 +78,7 @@ def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) def get_pending_tasks(): - pending_tasks_ids = [x for x in pending_tasks] + pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids) From da45e73b4ffde2e2a85b64a3e3258a0625bd307e Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 17:57:58 +0800 Subject: [PATCH 107/311] fix the problem of ruff of github --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1f4648067..9fac7e604 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,7 @@ class Api: task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id - + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From 6d7e57ba6a4d686d515518b5f90e91b32fa01caf Mon Sep 17 00:00:00 2001 From: gayshub Date: Fri, 15 Dec 2023 18:03:14 +0800 Subject: [PATCH 108/311] fix the problem of ruff of github --- modules/api/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9fac7e604..8d8e70a46 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -340,7 +340,6 @@ class Api: task_id = create_task_id("text2img") if txt2imgreq.force_task_id is None: task_id = txt2imgreq.force_task_id - script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) From ea272152e0b50dbb2bd675ec020607f3d50c37d0 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 15:08:08 +0800 Subject: [PATCH 109/311] Add FP8 settings into PNG info --- modules/generation_parameters_copypaste.py | 6 ++++++ modules/processing.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4efe53e0c..dbffe4949 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -314,6 +314,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "VAE Decoder" not in res: res["VAE Decoder"] = "Full" + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/processing.py b/modules/processing.py index bea01ec68..179f2c0fc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -688,6 +688,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, "Model": p.sd_model_name if opts.add_model_name_to_info else None, + "FP8 weight": opts.fp8_storage if devices.fp8 else None, + "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None, "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None, "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), From 7745db6fc02faf19117838c1e7bcc8a60b5f5e90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:15:08 +0300 Subject: [PATCH 110/311] torch 2.1.2 --- modules/errors.py | 4 ++-- modules/launch_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index c534a5d66..48aa13a17 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -107,8 +107,8 @@ def check_versions(): import torch import gradio - expected_torch_version = "2.1.0" - expected_xformers_version = "0.0.22.post7" + expected_torch_version = "2.1.2" + expected_xformers_version = "0.0.23.post1" expected_gradio_version = "3.41.2" if version.parse(torch.__version__) < version.parse(expected_torch_version): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 2c54e2a02..dabef0f53 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -315,7 +315,7 @@ def requirements_met(requirements_file): def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121") - torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}") + torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}") if args.use_ipex: if platform.system() == "Windows": # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main @@ -338,7 +338,7 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.22.post7') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1') clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") From cd9ce2e31c4a264d7cde17c54d24f8ad94c9cf2c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 10:40:20 +0300 Subject: [PATCH 111/311] Use radio for FP8 mode selection --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index d470eb8ff..fa542ba8c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -206,7 +206,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd" "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), - "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."), })) From 93eae69895c34361a71dbed17348bcfd132fbc6a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:00:42 +0300 Subject: [PATCH 112/311] move soft inpainting to a built-in extension --- .../soft-inpainting/scripts}/soft_inpainting.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {scripts => extensions-builtin/soft-inpainting/scripts}/soft_inpainting.py (100%) diff --git a/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py similarity index 100% rename from scripts/soft_inpainting.py rename to extensions-builtin/soft-inpainting/scripts/soft_inpainting.py From 86b3aa94e2d36a4f9d5ef1bb7c6ec995ff8eb517 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:04:59 +0300 Subject: [PATCH 113/311] rename pending tasks api endpoint to be more in line with others --- modules/progress.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/progress.py b/modules/progress.py index 6946fb1bc..85255e821 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -74,9 +74,10 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): - app.add_api_route("/internal/pendingTasks", get_pending_tasks, methods=["GET"]) + app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"]) return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + def get_pending_tasks(): pending_tasks_ids = list(pending_tasks) pending_len = len(pending_tasks_ids) From a97832033427096072d5ea914adac3662cda4fd1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 16 Dec 2023 19:39:43 +0800 Subject: [PATCH 114/311] Let fp8-related settings to invalidate cond_cache --- modules/processing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dd97b4eee..9351e3fb0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -431,6 +431,8 @@ class StableDiffusionProcessing: opts.sdxl_crop_top, self.width, self.height, + opts.fp8_storage, + opts.cache_fp16_weight, ) def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): From 98c5fa92015837706adfd9975d5f345ab74f1c99 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 16 Dec 2023 22:14:39 +0900 Subject: [PATCH 115/311] fix extras caption BLIP #14328 --- scripts/postprocessing_caption.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 243e3ad9c..9482a03ca 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -25,6 +25,6 @@ class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): captions.append(deepbooru.model.tag(pp.image)) if "BLIP" in option: - captions.append(shared.interrogator.generate_caption(pp.image)) + captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) pp.caption = ", ".join([x for x in captions if x]) From de03882d6ca56bc81058f5120f028678a6a54aaa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 08:55:35 +0300 Subject: [PATCH 116/311] make task ids for API work without force_task_id --- modules/api/api.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9637cb81b..7154c9d5b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -336,9 +336,8 @@ class Api: return script_args def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): - task_id = create_task_id("text2img") - if txt2imgreq.force_task_id is None: - task_id = txt2imgreq.force_task_id + task_id = txt2imgreq.force_task_id or create_task_id("txt2img") + script_runner = scripts.scripts_txt2img if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -393,9 +392,7 @@ class Api: return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI): - task_id = create_task_id("img2img") - if img2imgreq.force_task_id is None: - task_id = img2imgreq.force_task_id + task_id = img2imgreq.force_task_id or create_task_id("img2img") init_images = img2imgreq.init_images if init_images is None: From 10945aa41a158ee03727c5ea77d4ffff6b5370f0 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:27:41 +0900 Subject: [PATCH 117/311] only rewrite ui-config when there is change and a typo --- modules/ui.py | 4 +++- modules/ui_loadsave.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index d80486dd4..f02b55116 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1086,6 +1086,7 @@ def create_ui(): ) loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) + ui_settings_from_file = loadsave.ui_settings.copy() settings = ui_settings.UiSettings() settings.create_ui(loadsave, dummy_component) @@ -1146,7 +1147,8 @@ def create_ui(): modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint']) - loadsave.dump_defaults() + if ui_settings_from_file != loadsave.ui_settings: + loadsave.dump_defaults() demo.ui_loadsave = loadsave return demo diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 7826786cc..693ff75c5 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -144,7 +144,7 @@ class UiLoadsave: json.dump(current_ui_settings, file, indent=4, ensure_ascii=False) def dump_defaults(self): - """saves default values to a file unless tjhe file is present and there was an error loading default values at start""" + """saves default values to a file unless the file is present and there was an error loading default values at start""" if self.error_loading and os.path.exists(self.filename): return From e4b4a9c4acf0ca375a8603f7f52fde8467b2d266 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 18:00:01 +0800 Subject: [PATCH 118/311] [IPEX] Slice SDPA into smaller chunks --- modules/xpu_specific.py | 66 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index d8da94a0e..0ebdd5964 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -27,6 +27,68 @@ def torch_xpu_gc(): has_xpu = check_for_xpu() + +# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627 +# Here we implement a slicing algorithm to split large batch size into smaller chunks, +# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. +# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, +# which is the best trade-off between VRAM usage and performance. +ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention +def torch_xpu_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs +): + # cast to same dtype first + key = key.to(query.dtype) + value = value.to(query.dtype) + + N = query.shape[:-2] # Batch size + L = query.size(-2) # Target sequence length + E = query.size(-1) # Embedding dimension of the query and key + S = key.size(-2) # Source sequence length + Ev = value.size(-1) # Embedding dimension of the value + + total_batch_size = torch.numel(torch.empty(N)) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + + if total_batch_size <= batch_size_limit: + return orig_sdp_attn_func( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + *args, **kwargs + ) + + query = torch.reshape(query, (-1, L, E)) + key = torch.reshape(key, (-1, S, E)) + value = torch.reshape(value, (-1, S, Ev)) + if attn_mask is not None: + attn_mask = attn_mask.view(-1, L, S) + chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit + outputs = [] + for i in range(chunk_count): + attn_mask_chunk = ( + None + if attn_mask is None + else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :] + ) + chunk_output = orig_sdp_attn_func( + query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :], + attn_mask_chunk, + dropout_p, + is_causal, + *args, **kwargs + ) + outputs.append(chunk_output) + result = torch.cat(outputs, dim=0) + return torch.reshape(result, (*N, L, Ev)) + + if has_xpu: # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device CondFunc('torch.Generator', @@ -55,5 +117,5 @@ if has_xpu: lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out), lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors)) CondFunc('torch.nn.functional.scaled_dot_product_attention', - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal), - lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype) + lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs), + lambda orig_func, query, *args, **kwargs: query.is_xpu) From f586f4973a0f715e30b42242bb0e6b3f88c37d90 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Mon, 18 Dec 2023 19:44:52 +0800 Subject: [PATCH 119/311] Fix device id --- modules/xpu_specific.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 0ebdd5964..f7687a66c 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -33,7 +33,7 @@ has_xpu = check_for_xpu() # so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT. # The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G, # which is the best trade-off between VRAM usage and performance. -ARC_SINGLE_ALLOCATION_LIMIT = min(torch.xpu.get_device_properties(shared.cmd_opts.device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) +ARC_SINGLE_ALLOCATION_LIMIT = {} orig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention def torch_xpu_scaled_dot_product_attention( query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs @@ -49,7 +49,10 @@ def torch_xpu_scaled_dot_product_attention( Ev = value.size(-1) # Embedding dimension of the value total_batch_size = torch.numel(torch.empty(N)) - batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT // (L * S * query.element_size())) + device_id = query.device.index + if device_id not in ARC_SINGLE_ALLOCATION_LIMIT: + ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024) + batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size())) if total_batch_size <= batch_size_limit: return orig_sdp_attn_func( From fe4d084390d35de49baf83c3319a72d71f540aee Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:50:00 +0500 Subject: [PATCH 120/311] Update webui.py Added (Fixed) IPV6 Functionality When there is No Webui Argument Passed --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 9ed20b306..3b587dc41 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 0d5941edbc4c602b760d4200bd76e044c65a0e40 Mon Sep 17 00:00:00 2001 From: Muhammad Rehan Aslam <19831661+ranareehanaslam@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:50:38 +0500 Subject: [PATCH 121/311] Update webui.py Co-authored-by: Aarni Koskela --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3b587dc41..3f5b67a30 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name = cmd_opts.server_name if cmd_opts.server_name else ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), + server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From 3e068de0dcec811d515402fc184f70709a785e4f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:48:49 +0900 Subject: [PATCH 122/311] reorder training preprocessing modules in extras tab using the order from before the rework 11d23e8ca55c097ecfa255a05b63f194e25f08be --- scripts/postprocessing_caption.py | 2 +- scripts/postprocessing_create_flipped_copies.py | 2 +- scripts/postprocessing_focal_crop.py | 2 +- scripts/processing_autosized_crop.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/postprocessing_caption.py b/scripts/postprocessing_caption.py index 9482a03ca..5592a8987 100644 --- a/scripts/postprocessing_caption.py +++ b/scripts/postprocessing_caption.py @@ -4,7 +4,7 @@ import gradio as gr class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing): name = "Caption" - order = 4000 + order = 4040 def ui(self): with ui_components.InputAccordion(False, label="Caption") as enable: diff --git a/scripts/postprocessing_create_flipped_copies.py b/scripts/postprocessing_create_flipped_copies.py index 3425571dc..b673003b6 100644 --- a/scripts/postprocessing_create_flipped_copies.py +++ b/scripts/postprocessing_create_flipped_copies.py @@ -6,7 +6,7 @@ import gradio as gr class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): name = "Create flipped copies" - order = 4000 + order = 4030 def ui(self): with ui_components.InputAccordion(False, label="Create flipped copies") as enable: diff --git a/scripts/postprocessing_focal_crop.py b/scripts/postprocessing_focal_crop.py index d3baf2987..cff1dbc54 100644 --- a/scripts/postprocessing_focal_crop.py +++ b/scripts/postprocessing_focal_crop.py @@ -7,7 +7,7 @@ from modules.textual_inversion import autocrop class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto focal point crop" - order = 4000 + order = 4010 def ui(self): with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: diff --git a/scripts/processing_autosized_crop.py b/scripts/processing_autosized_crop.py index c09802264..7e6749898 100644 --- a/scripts/processing_autosized_crop.py +++ b/scripts/processing_autosized_crop.py @@ -28,7 +28,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing): name = "Auto-sized crop" - order = 4000 + order = 4020 def ui(self): with ui_components.InputAccordion(False, label="Auto-sized crop") as enable: From 9feb034e343d6d7ef63395821658fb3774b30a24 Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Thu, 21 Dec 2023 20:15:51 +0800 Subject: [PATCH 123/311] support for sdxl-inpaint model --- configs/sd_xl_inpaint.yaml | 98 +++++++++++++++++++++++++++++++++++++ modules/processing.py | 19 +++++++ modules/sd_models_config.py | 6 ++- modules/sd_models_xl.py | 5 ++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 configs/sd_xl_inpaint.yaml diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 000000000..3bad37218 --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f5..159548dba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -362,6 +376,11 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index deab2f6e2..b38137eb5 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename): sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - return config_sdxl + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 011233216..d8a9a73bc 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + return self.model(x, t, cond) From de1809bd14450cfc41623b6021c7087fb385ab6f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 22 Dec 2023 00:08:35 +0900 Subject: [PATCH 124/311] handle axis_type is None --- scripts/xyz_grid.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index b2250c04d..e5083874a 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -476,6 +476,8 @@ class Script(scripts.Script): fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown]) def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode): + axis_type = axis_type or 0 # if axle type is None set to 0 + choices = self.current_axis_options[axis_type].choices has_choices = choices is not None @@ -526,6 +528,8 @@ class Script(scripts.Script): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode] def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 + if not no_fixed_seeds: modules.processing.fix_seed(p) From edfae95d90a49ea95394b772817a59dde4175222 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 23 Dec 2023 01:21:00 +0900 Subject: [PATCH 125/311] prevent crash due to Script __init__ exception --- modules/scripts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index b6fcf96e0..3a7669118 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -566,7 +566,12 @@ class ScriptRunner: auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() for script_data in auto_processing_scripts + scripts_data: - script = script_data.script_class() + try: + script = script_data.script_class() + except Exception: + errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True) + continue + script.filename = script_data.path script.is_txt2img = not is_img2img script.is_img2img = is_img2img From 00d4a4d4ac75903d8224e9beb1136584dd66fcd8 Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Tue, 26 Dec 2023 14:46:29 +0800 Subject: [PATCH 126/311] move thread-unsafe code to __init__ --- modules/api/api.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7154c9d5b..f0a68c672 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,6 +251,15 @@ class Api: self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] + script_runner = scripts.scripts_img2img + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(script_runner) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) + def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs) @@ -339,11 +348,6 @@ class Api: task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -403,11 +407,6 @@ class Api: mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From bfe418a58d39c69ca2672e7d8a1fd7ad2b34869b Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Wed, 27 Dec 2023 10:20:56 +0800 Subject: [PATCH 127/311] add some codes for robust --- modules/processing.py | 24 +++++++++++++----------- modules/sd_models_xl.py | 5 +++-- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 159548dba..c05e608ab 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -108,17 +108,18 @@ def txt2img_image_conditioning(sd_model, x, width, height): else: sd = sd_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - # The "masked-image" in this case will just be all 0.5 since the entire image is masked. - image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 - image_conditioning = images_tensor_to_samples(image_conditioning, - approximation_indexes.get(opts.sd_vae_encode_method)) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. @@ -378,8 +379,9 @@ class StableDiffusionProcessing: sd = self.sampler.model_wrap.inner_model.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d8a9a73bc..162d0fee8 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -36,8 +36,9 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): sd = self.model.state_dict() diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) - if diffusion_model_input.shape[1] == 9: - x = torch.cat([x] + cond['c_concat'], dim=1) + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) return self.model(x, t, cond) From de04573438bc111f137359b8f4998780bf315275 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:22:51 +0900 Subject: [PATCH 128/311] create utility truncate_path utli.truncate_path(target_path, base_path) return the target_path relative to base_path if target_path is a sub path of base_path else return the absolute path --- modules/util.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/modules/util.py b/modules/util.py index 60afc0670..4861bcb08 100644 --- a/modules/util.py +++ b/modules/util.py @@ -2,7 +2,7 @@ import os import re from modules import shared -from modules.paths_internal import script_path +from modules.paths_internal import script_path, cwd def natural_sort_key(s, regex=re.compile('([0-9]+)')): @@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs): return print(*args, **kwargs) + + +def truncate_path(target_path, base_path=cwd): + abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path) + try: + if os.path.commonpath([abs_target, abs_base]) == abs_base: + return os.path.relpath(abs_target, abs_base) + except ValueError: + pass + return abs_target From af2951ed53da6d357aea9232538f9ea7e1cdc648 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:52:33 +0900 Subject: [PATCH 129/311] base default image output on data_path Co-Authored-By: Alberto Cano <34340962+canoalberto@users.noreply.github.com> --- modules/paths_internal.py | 1 + modules/shared_options.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 89131a54f..b86ecd7f1 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models") extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") config_states_dir = os.path.join(script_path, "config_states") +default_output_dir = os.path.join(data_path, "output") roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') diff --git a/modules/shared_options.py b/modules/shared_options.py index fa542ba8c..752a4f125 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -1,7 +1,8 @@ +import os import gradio as gr -from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes -from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 +from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util +from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401 from modules.shared_cmd_options import cmd_opts from modules.options import options_section, OptionInfo, OptionHTML, categories @@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), { "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), - "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs), "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), - "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), - "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs), + "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs), })) options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), { From 892e703b59b2f867d8a202a52fab1db89882ef86 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 28 Dec 2023 06:52:41 +0900 Subject: [PATCH 130/311] webpath use truncate_path --- modules/ui_gradio_extensions.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index 0d368f8b2..a86c368ef 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -1,17 +1,12 @@ import os import gradio as gr -from modules import localization, shared, scripts -from modules.paths import script_path, data_path, cwd +from modules import localization, shared, scripts, util +from modules.paths import script_path, data_path def webpath(fn): - if fn.startswith(cwd): - web_path = os.path.relpath(fn, cwd) - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' + return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}' def javascript_html(): From dc57ec0296e768ee91290e16ab262404837c566d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 29 Dec 2023 01:56:48 +0900 Subject: [PATCH 131/311] save info of init image --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb0..141f2f111 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1482,7 +1482,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # Save init image if opts.save_init_img: self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() - images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False) + images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) image = images.flatten(img, opts.img2img_background_color) From bb07cb6a0df60a96827125ffc09ea182a1ed272c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 17 Dec 2023 10:22:03 +0300 Subject: [PATCH 132/311] a --- modules/api/api.py | 27 +++++++++++++ modules/api/models.py | 2 + modules/generation_parameters_copypaste.py | 19 +++++++++ modules/processing.py | 2 +- modules/processing_scripts/refiner.py | 7 ++-- modules/processing_scripts/seed.py | 13 +++--- modules/ui.py | 46 +++++++++++----------- 7 files changed, 83 insertions(+), 33 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 7154c9d5b..b3d709404 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -335,6 +335,29 @@ class Api: script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args + def apply_infotext(self, request, tabname): + if not request.infotext: + return {} + + params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + + for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: + if not field.api: + continue + + value = field.function(params) if field.function else params.get(field.label) + target_type = request.__fields__[field.api].type_ + + if value is None: + continue + + if not isinstance(value, target_type): + value = target_type(value) + + setattr(request, field.api, value) + + return params + def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): task_id = txt2imgreq.force_task_id or create_task_id("txt2img") @@ -342,6 +365,9 @@ class Api: if not script_runner.scripts: script_runner.initialize_scripts(False) ui.create_ui() + + infotext_params = self.apply_infotext(txt2imgreq, "txt2img") + if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) @@ -358,6 +384,7 @@ class Api: args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) diff --git a/modules/api/models.py b/modules/api/models.py index 58083a34f..16edf11cf 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -108,6 +108,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() @@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( {"key": "save_images", "type": bool, "default": False}, {"key": "alwayson_scripts", "type": dict, "default": {}}, {"key": "force_task_id", "type": str, "default": None}, + {"key": "infotext", "type": str, "default": None}, ] ).generate_model() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index dbffe4949..4b4727c44 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -28,6 +28,19 @@ class ParamBinding: self.paste_field_names = paste_field_names or [] +class PasteField(tuple): + def __new__(cls, component, target, *, api=None): + return super().__new__(cls, (component, target)) + + def __init__(self, component, target, *, api=None): + super().__init__() + + self.api = api + self.component = component + self.label = target if isinstance(target, str) else None + self.function = target if callable(target) else None + + paste_fields: dict[str, dict] = {} registered_param_bindings: list[ParamBinding] = [] @@ -84,6 +97,12 @@ def image_from_url_text(filedata): def add_paste_fields(tabname, init_img, fields, override_settings_component=None): + + if fields: + for i in range(len(fields)): + if not isinstance(fields[i], PasteField): + fields[i] = PasteField(*fields[i]) + paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} # backwards compatibility for existing extensions diff --git a/modules/processing.py b/modules/processing.py index 9351e3fb0..ee2ccf460 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1135,7 +1135,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if self.hr_checkpoint_name: + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) if self.hr_checkpoint_info is None: diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 29ccb78f9..cefad32b7 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,6 +1,7 @@ import gradio as gr from modules import scripts, sd_models +from modules.generation_parameters_copypaste import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion @@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI): return None if info is None else info.title self.infotext_fields = [ - (enable_refiner, lambda d: 'Refiner' in d), - (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), - (refiner_switch_at, 'Refiner switch at'), + PasteField(enable_refiner, lambda d: 'Refiner' in d), + PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), + PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), ] return enable_refiner, refiner_checkpoint, refiner_switch_at diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index dc9c2da50..a3e16a12e 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,6 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors +from modules.generation_parameters_copypaste import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton @@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI): seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras]) self.infotext_fields = [ - (self.seed, "Seed"), - (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), + PasteField(self.seed, "Seed", api="seed"), + PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), + PasteField(subseed, "Variation seed", api="subseed"), + PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"), + PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"), + PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"), ] self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}') diff --git a/modules/ui.py b/modules/ui.py index d80486dd4..9db2407ed 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -28,7 +28,7 @@ import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text +from modules.generation_parameters_copypaste import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component @@ -436,28 +436,28 @@ def create_ui(): ) txt2img_paste_fields = [ - (toprow.prompt, "Prompt"), - (toprow.negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_name, "Sampler"), - (cfg_scale, "CFG scale"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_name, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), - (hr_prompt, "Hires prompt"), - (hr_negative_prompt, "Hires negative prompt"), - (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + PasteField(toprow.prompt, "Prompt", api="prompt"), + PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"), + PasteField(steps, "Steps", api="steps"), + PasteField(sampler_name, "Sampler", api="sampler_name"), + PasteField(cfg_scale, "CFG scale", api="cfg_scale"), + PasteField(width, "Size-1", api="width"), + PasteField(height, "Size-2", api="height"), + PasteField(batch_size, "Batch size", api="batch_size"), + PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"), + PasteField(denoising_strength, "Denoising strength", api="denoising_strength"), + PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"), + PasteField(hr_scale, "Hires upscale", api="hr_scale"), + PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"), + PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"), + PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"), + PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"), + PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"), + PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"), + PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), + PasteField(hr_prompt, "Hires prompt", api="hr_prompt"), + PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"), + PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) From 59d060fd5ea93fcc3fdbfbd13b6e20fda06ecf94 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:11:03 +0900 Subject: [PATCH 133/311] More lora not found warning --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 985b2753b..72ebd6241 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c5..1518f7e5c 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) From ba92135a2ba9e210ce5370715e2defcb43df70d1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 12:11:09 +0300 Subject: [PATCH 134/311] add override_settings support for infotext API --- modules/api/api.py | 10 ++++ modules/generation_parameters_copypaste.py | 66 ++++++++++++++-------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index b3d709404..fb108486f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,7 @@ class Api: params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + handled_fields = {} for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: if not field.api: continue @@ -355,6 +356,15 @@ class Api: value = target_type(value) setattr(request, field.api, value) + handled_fields[field.label] = 1 + + if request.override_settings is None: + request.override_settings = {} + + overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) + for infotext_text, setting_name, value in overriden_settings: + if setting_name not in request.override_settings: + request.override_settings[setting_name] = value return params diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4b4727c44..86a36c327 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -390,6 +390,48 @@ def create_override_settings_dict(text_pairs): return res +def get_override_settings(params, *, skip_fields=None): + """Returns a list of settings overrides from the infotext parameters dictionary. + + This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns + a list of tuples containing the parameter name, setting name, and new value cast to correct type. + + It checks for conditions before adding an override: + - ignores settings that match the current value + - ignores parameter keys present in skip_fields argument. + + Example input: + {"Clip skip": "2"} + + Example output: + [("Clip skip", "CLIP_stop_at_last_layers", 2)] + """ + + res = [] + + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in (skip_fields or {}): + continue + + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + res.append((param_name, setting_name, v)) + + return res + + def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: @@ -431,29 +473,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, already_handled_fields = {key: 1 for _, key in paste_fields} def paste_settings(params): - vals = {} + vals = get_override_settings(params, skip_fields=already_handled_fields) - mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] - for param_name, setting_name in mapping + infotext_to_setting_name_mapping: - if param_name in already_handled_fields: - continue - - v = params.get(param_name, None) - if v is None: - continue - - if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: - continue - - v = shared.opts.cast_value(setting_name, v) - current_value = getattr(shared.opts, setting_name, None) - - if v == current_value: - continue - - vals[param_name] = v - - vals_pairs = [f"{k}: {v}" for k, v in vals.items()] + vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals] return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) From 8b08b78c03f09898455d54cf099225ed5f8de1ee Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 12:27:23 +0300 Subject: [PATCH 135/311] make it so that if an option from infotext conflicts with an argument from API, the latter overrides the former --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index fb108486f..cabccb4c0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -339,6 +339,7 @@ class Api: if not request.infotext: return {} + set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) handled_fields = {} @@ -346,6 +347,9 @@ class Api: if not field.api: continue + if field.api in set_fields: + continue + value = field.function(params) if field.function else params.get(field.label) target_type = request.__fields__[field.api].type_ @@ -376,7 +380,7 @@ class Api: script_runner.initialize_scripts(False) ui.create_ui() - infotext_params = self.apply_infotext(txt2imgreq, "txt2img") + self.apply_infotext(txt2imgreq, "txt2img") if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) From 0aacd4c72b4008d7153e747301fe8c5ffca57f85 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:33:18 +0300 Subject: [PATCH 136/311] add support for alwayson scripts for infotext API --- modules/api/api.py | 61 +++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index cabccb4c0..946cfe4a9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -312,8 +312,13 @@ class Api: script_args[script.args_from:script.args_to] = ui_default_values return script_args - def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner): + def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None): script_args = default_script_args.copy() + + if input_script_args is not None: + for index, value in input_script_args.items(): + script_args[index] = value + # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run() if selectable_scripts: script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args @@ -335,41 +340,58 @@ class Api: script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] return script_args - def apply_infotext(self, request, tabname): + def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): if not request.infotext: return {} + possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) - handled_fields = {} - for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]: + def get_field_value(field, params): + value = field.function(params) if field.function else params.get(field.label) + if value is None: + return None + + if field.api in request.__fields__: + target_type = request.__fields__[field.api].type_ + else: + target_type = type(field.component.value) + + if target_type == type(None): + return None + + if not isinstance(value, target_type): + value = target_type(value) + + return value + + for field in possible_fields: if not field.api: continue if field.api in set_fields: continue - value = field.function(params) if field.function else params.get(field.label) - target_type = request.__fields__[field.api].type_ - - if value is None: - continue - - if not isinstance(value, target_type): - value = target_type(value) - - setattr(request, field.api, value) - handled_fields[field.label] = 1 + value = get_field_value(field, params) + if value is not None: + setattr(request, field.api, value) if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields) - for infotext_text, setting_name, value in overriden_settings: + overriden_settings = generation_parameters_copypaste.get_override_settings(params) + for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value + if script_runner is not None and mentioned_script_args is not None: + indexes = {v: i for i, v in enumerate(script_runner.inputs)} + script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes) + + for field, index in script_fields: + mentioned_script_args[index] = get_field_value(field, params) + return params def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI): @@ -380,7 +402,8 @@ class Api: script_runner.initialize_scripts(False) ui.create_ui() - self.apply_infotext(txt2imgreq, "txt2img") + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_txt2img: self.default_script_arg_txt2img = self.init_default_script_args(script_runner) @@ -400,7 +423,7 @@ class Api: args.pop('alwayson_scripts', None) args.pop('infotext', None) - script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 11a435b4697c2d735a117f31944c4ebe59c2504c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:34:46 +0300 Subject: [PATCH 137/311] img2img support for infotext API --- modules/api/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 946cfe4a9..2c8dc2a0f 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -470,6 +470,10 @@ class Api: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() + + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) @@ -489,7 +493,7 @@ class Api: args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) - script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner) + script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) send_images = args.pop('send_images', True) args.pop('save_images', None) From 8f1826375943718463cec3af97a37886249bdb44 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 13:48:25 +0300 Subject: [PATCH 138/311] fix bad values read from infotext for API, add comment --- modules/api/api.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2c8dc2a0f..2918f7857 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -341,6 +341,13 @@ class Api: return script_args def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None): + """Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext. + + If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored. + + Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext. + """ + if not request.infotext: return {} @@ -361,7 +368,10 @@ class Api: if target_type == type(None): return None - if not isinstance(value, target_type): + if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value + value = value.get('value') + + if value is not None and not isinstance(value, target_type): value = target_type(value) return value @@ -390,7 +400,12 @@ class Api: script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes) for field, index in script_fields: - mentioned_script_args[index] = get_field_value(field, params) + value = get_field_value(field, params) + + if value is None: + continue + + mentioned_script_args[index] = value return params From f0e2e8b930115012976f7c5bae00e243a7ebbf79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 15:12:48 +0300 Subject: [PATCH 139/311] update #14354 --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 3f5b67a30..2c417168a 100644 --- a/webui.py +++ b/webui.py @@ -39,7 +39,7 @@ def api_only(): print(f"Startup time: {startup_timer.summary()}.") api.launch( - server_name=cmd_opts.server_name or ("0.0.0.0" if cmd_opts.listen else "127.0.0.1"), + server_name=initialize_util.gradio_server_name(), port=cmd_opts.port if cmd_opts.port else 7861, root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) From c069c2c5628728c9506dd034ef98e6335fd5bb34 Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sat, 30 Dec 2023 21:32:22 +0800 Subject: [PATCH 140/311] add locks to ensure init args are thread-safe --- modules/api/api.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index f0a68c672..45c5c5077 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,14 +251,10 @@ class Api: self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - script_runner = scripts.scripts_img2img - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) + self.txt2img_script_arg_init_lock = Lock() + self.img2img_script_arg_init_lock = Lock() + + def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -348,6 +344,12 @@ class Api: task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img + with self.txt2img_script_arg_init_lock: + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -407,6 +409,12 @@ class Api: mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img + with self.img2img_script_arg_init_lock: + if not script_runner.scripts: + script_runner.initialize_scripts(True) + ui.create_ui() + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 31992eff9b9714c158b12cec16dfe66c76270dfa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 16:51:02 +0300 Subject: [PATCH 141/311] make it possible again to extract styles that have whitespace at the end. --- modules/styles.py | 47 +++++++++++++++++------------------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/modules/styles.py b/modules/styles.py index 81d9800d1..026c43001 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -30,38 +30,29 @@ def apply_styles_to_prompt(prompt, styles): return prompt -def unwrap_style_text_from_prompt(style_text, prompt): - """ - Checks the prompt to see if the style text is wrapped around it. If so, - returns True plus the prompt text without the style text. Otherwise, returns - False with the original prompt. +def extract_style_text_from_prompt(style_text, prompt): + """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt. - Note that the "cleaned" version of the style text is only used for matching - purposes here. It isn't returned; the original style text is not modified. + extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg") + extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg") """ - stripped_prompt = prompt - stripped_style_text = style_text + + stripped_prompt = prompt.strip() + stripped_style_text = style_text.strip() + if "{prompt}" in stripped_style_text: - # Work out whether the prompt is wrapped in the style text. If so, we - # return True and the "inner" prompt text that isn't part of the style. - try: - left, right = stripped_style_text.split("{prompt}", 2) - except ValueError as e: - # If the style text has multple "{prompt}"s, we can't split it into - # two parts. This is an error, but we can't do anything about it. - print(f"Unable to compare style text to prompt:\n{style_text}") - print(f"Error: {e}") - return False, prompt + left, right = stripped_style_text.split("{prompt}", 2) if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): - prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] + prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] return True, prompt else: - # Work out whether the given prompt ends with the style text. If so, we - # return True and the prompt text up to where the style text starts. if stripped_prompt.endswith(stripped_style_text): - prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] - if prompt.endswith(", "): + prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] + + if prompt.endswith(', '): prompt = prompt[:-2] + return True, prompt return False, prompt @@ -76,15 +67,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt - match_positive, extracted_positive = unwrap_style_text_from_prompt( - style.prompt, prompt - ) + match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) if not match_positive: return False, prompt, negative_prompt - match_negative, extracted_negative = unwrap_style_text_from_prompt( - style.negative_prompt, negative_prompt - ) + match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) if not match_negative: return False, prompt, negative_prompt From 7aa27b000a3087dcb5cc7254600064bf70cacd3e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:44:15 +0200 Subject: [PATCH 142/311] Add types to split_grid --- modules/images.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index 16f9ae7cc..d30e8865d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -64,9 +64,8 @@ def image_grid(imgs, batch_size=1, rows=None): Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) -def split_grid(image, tile_w=512, tile_h=512, overlap=64): - w = image.width - h = image.height +def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: + w, h = image.size non_overlap_width = tile_w - overlap non_overlap_height = tile_h - overlap From 12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:01:45 +0200 Subject: [PATCH 143/311] Add tile_count property to Grid --- modules/images.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index d30e8865d..87a7bf221 100644 --- a/modules/images.py +++ b/modules/images.py @@ -61,7 +61,13 @@ def image_grid(imgs, batch_size=1, rows=None): return grid -Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"]) +class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])): + @property + def tile_count(self) -> int: + """ + The total number of tiles in the grid. + """ + return sum(len(row[2]) for row in self.tiles) def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid: From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: [PATCH 144/311] Refactor esrgan_upscale to more generic upscale_with_model --- modules/esrgan_model.py | 47 +++++----------------------- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 39 deletions(-) create mode 100644 modules/upscaler_utils.py diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d2..c0d22a992 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,13 +1,12 @@ import sys -import numpy as np import torch -from PIL import Image import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model def mod2normal(state_dict): @@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler): return model -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') - - def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 000000000..8bdda51c4 --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: [PATCH 145/311] Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- .../ScuNET/scripts/scunet_model.py | 13 +- .../ScuNET/scunet_model_arch.py | 268 ----- .../SwinIR/scripts/swinir_model.py | 126 +- .../SwinIR/swinir_model_arch.py | 867 -------------- .../SwinIR/swinir_model_arch_v2.py | 1017 ----------------- modules/codeformer/codeformer_arch.py | 276 ----- modules/codeformer/vqgan_arch.py | 435 ------- modules/codeformer_model.py | 223 ++-- modules/esrgan_model.py | 153 +-- modules/esrgan_model_arch.py | 465 -------- modules/gfpgan_model.py | 13 +- modules/launch_utils.py | 7 - modules/modelloader.py | 16 + modules/paths.py | 1 - modules/realesrgan_model.py | 153 +-- modules/sysinfo.py | 2 - modules/upscaler.py | 3 + requirements.txt | 3 +- requirements_versions.txt | 4 +- 19 files changed, 277 insertions(+), 3768 deletions(-) delete mode 100644 extensions-builtin/ScuNET/scunet_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch_v2.py delete mode 100644 modules/codeformer/codeformer_arch.py delete mode 100644 modules/codeformer/vqgan_arch.py delete mode 100644 modules/esrgan_model_arch.py diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64b..18cf8e1a0 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? - filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device) def on_ui_settings(): diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a88062..000000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. - Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a8..85c18b9e9 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,5 +1,5 @@ +import logging import sys -import platform import numpy as np import torch @@ -8,13 +8,11 @@ from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: current_config = (model_file, opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + device = self._get_device() + + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if use_compile: - model = torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + device=device, + ) devices.torch_gc() return img @@ -69,69 +70,54 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) + model = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + dtype=devices.dtype, + ) + if getattr(opts, 'SWIN_torch_compile', False): + try: + model = torch.compile(model) + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) return model + def _get_device(self): + return devices.get_device_for('swinir') + def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, + img, + model, + *, + tile: int, + tile_overlap: int, + window_size=8, + scale=4, + device, ): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + img = img.unsqueeze(0).to(device, dtype=devices.dtype) with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) + output = inference( + img, + model, + tile=tile, + tile_overlap=tile_overlap, + window_size=window_size, + scale=scale, + device=device, + ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: @@ -142,7 +128,16 @@ def upscale( return Image.fromarray(output, "RGB") -def inference(img, model, tile, tile_overlap, window_size, scale): +def inference( + img, + model, + *, + tile: int, + tile_overlap: int, + window_size: int, + scale: int, + device, +): # test the image tile by tile b, c, h, w = img.size() tile = min(tile, h, w) @@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) + E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: @@ -185,8 +180,7 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b932747..000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. -# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca2..000000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. -# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=(0, 0)): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py deleted file mode 100644 index 12db68142..000000000 --- a/modules/codeformer/codeformer_arch.py +++ /dev/null @@ -1,276 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -import math -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from typing import Optional - -from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils.registry import ARCH_REGISTRY - -def calc_mean_std(feat, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - - -def adaptive_instance_normalization(content_feat, style_feat): - """Adaptive instance normalization. - - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. - """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, x, mask=None): - if mask is None: - mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") - - -class TransformerSALayer(nn.Module): - def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): - super().__init__() - self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) - # Implementation of Feedforward model - MLP - self.linear1 = nn.Linear(embed_dim, dim_mlp) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_mlp, embed_dim) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward(self, tgt, - tgt_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - - # self attention - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - - # ffn - tgt2 = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout2(tgt2) - return tgt - -class Fuse_sft_block(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.encode_enc = ResBlock(2*in_ch, out_ch) - - self.scale = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - self.shift = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - def forward(self, enc_feat, dec_feat, w=1): - enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) - scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out - - -@ARCH_REGISTRY.register() -class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, - codebook_size=1024, latent_size=256, - connect_list=('32', '64', '128', '256'), - fix_modules=('quantize', 'generator')): - super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False - - self.connect_list = connect_list - self.n_layers = n_layers - self.dim_embd = dim_embd - self.dim_mlp = dim_embd*2 - - self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) - self.feat_emb = nn.Linear(256, self.dim_embd) - - # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) - for _ in range(self.n_layers)]) - - # logits_predict head - self.idx_pred_layer = nn.Sequential( - nn.LayerNorm(dim_embd), - nn.Linear(dim_embd, codebook_size, bias=False)) - - self.channels = { - '16': 512, - '32': 256, - '64': 256, - '128': 128, - '256': 128, - '512': 64, - } - - # after second residual block for > 16, before attn layer for ==16 - self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} - # after first residual block for > 16, before attn layer for ==16 - self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} - - # fuse_convs_dict - self.fuse_convs_dict = nn.ModuleDict() - for f_size in self.connect_list: - in_ch = self.channels[f_size] - self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): - # ################### Encoder ##################### - enc_feat_dict = {} - out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] - for i, block in enumerate(self.encoder.blocks): - x = block(x) - if i in out_list: - enc_feat_dict[str(x.shape[-1])] = x.clone() - - lq_feat = x - # ################# Transformer ################### - # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) - pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) - # BCHW -> BC(HW) -> (HW)BC - feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) - query_emb = feat_emb - # Transformer encoder - for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) - - # output logits - logits = self.idx_pred_layer(query_emb) # (hw)bn - logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n - - if code_only: # for training stage II - # logits doesn't need softmax before cross_entropy loss - return logits, lq_feat - - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ - soft_one_hot = F.softmax(logits, dim=2) - _, top_idx = torch.topk(soft_one_hot, 1, dim=2) - quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() - - if detach_16: - quant_feat = quant_feat.detach() # for training stage III - if adain: - quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) - - # ################## Generator #################### - x = quant_feat - fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] - - for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block - f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) - out = x - # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py deleted file mode 100644 index 09ee6660d..000000000 --- a/modules/codeformer/vqgan_arch.py +++ /dev/null @@ -1,435 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -''' -VQGAN code, adapted from the original created by the Unleashing Transformers authors: -https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py - -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from basicsr.utils import get_root_logger -from basicsr.utils.registry import ARCH_REGISTRY - -def normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -@torch.jit.script -def swish(x): - return x*torch.sigmoid(x) - - -# Define VQVAE classes -class VectorQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, beta): - super(VectorQuantizer, self).__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) - self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.emb_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) - - mean_distance = torch.mean(d) - # find closest encodings - # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) - # [0-1], higher score, higher confidence - min_encoding_scores = torch.exp(-min_encoding_scores/10) - - min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - # compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, { - "perplexity": perplexity, - "min_encodings": min_encodings, - "min_encoding_indices": min_encoding_indices, - "min_encoding_scores": min_encoding_scores, - "mean_distance": mean_distance - } - - def get_codebook_feat(self, indices, shape): - # input indices: batch*token_num -> (batch*token_num)*1 - # shape: batch, height, width, channel - indices = indices.view(-1,1) - min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) - min_encodings.scatter_(1, indices, 1) - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: # reshape back to match original input shape - z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): - super().__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits - self.embed = nn.Embedding(codebook_size, emb_dim) - - def forward(self, z): - hard = self.straight_through if self.training else True - - logits = self.proj(z) - - soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) - - z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() - min_encoding_indices = soft_one_hot.argmax(dim=1) - - return z_q, diff, { - "min_encoding_indices": min_encoding_indices - } - - -class Downsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - - return x - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels=None): - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = normalize(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x_in): - x = x_in - x = self.norm1(x) - x = swish(x) - x = self.conv1(x) - x = self.norm2(x) - x = swish(x) - x = self.conv2(x) - if self.in_channels != self.out_channels: - x_in = self.conv_out(x_in) - - return x + x_in - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c)**(-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Encoder(nn.Module): - def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): - super().__init__() - self.nf = nf - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - curr_res = self.resolution - in_ch_mult = (1,)+tuple(ch_mult) - - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - - # residual and downsampling blocks, with attention on smaller res (16x16) - for i in range(self.num_resolutions): - block_in_ch = nf * in_ch_mult[i] - block_out_ch = nf * ch_mult[i] - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - if curr_res in attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != self.num_resolutions - 1: - blocks.append(Downsample(block_in_ch)) - curr_res = curr_res // 2 - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - # normalise and convert to latent size - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) - self.blocks = nn.ModuleList(blocks) - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -class Generator(nn.Module): - def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): - super().__init__() - self.nf = nf - self.ch_mult = ch_mult - self.num_resolutions = len(self.ch_mult) - self.num_res_blocks = res_blocks - self.resolution = img_size - self.attn_resolutions = attn_resolutions - self.in_channels = emb_dim - self.out_channels = 3 - block_in_ch = self.nf * self.ch_mult[-1] - curr_res = self.resolution // 2 ** (self.num_resolutions-1) - - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - for i in reversed(range(self.num_resolutions)): - block_out_ch = self.nf * self.ch_mult[i] - - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - - if curr_res in self.attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != 0: - blocks.append(Upsample(block_in_ch)) - curr_res = curr_res * 2 - - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) - - self.blocks = nn.ModuleList(blocks) - - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -@ARCH_REGISTRY.register() -class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, - beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): - super().__init__() - logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks - self.codebook_size = codebook_size - self.embed_dim = emb_dim - self.ch_mult = ch_mult - self.resolution = img_size - self.attn_resolutions = attn_resolutions or [16] - self.quantizer_type = quantizer - self.encoder = Encoder( - self.in_channels, - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - if self.quantizer_type == "nearest": - self.beta = beta #0.25 - self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) - elif self.quantizer_type == "gumbel": - self.gumbel_num_hiddens = emb_dim - self.straight_through = gumbel_straight_through - self.kl_weight = gumbel_kl_weight - self.quantize = GumbelQuantizer( - self.codebook_size, - self.embed_dim, - self.gumbel_num_hiddens, - self.straight_through, - self.kl_weight - ) - self.generator = Generator( - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_ema' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) - logger.info(f'vqgan is loaded from: {model_path} [params_ema]') - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - logger.info(f'vqgan is loaded from: {model_path} [params]') - else: - raise ValueError('Wrong params!') - - - def forward(self, x): - x = self.encoder(x) - quant, codebook_loss, quant_stats = self.quantize(x) - x = self.generator(quant) - return x, codebook_loss, quant_stats - - - -# patch based discriminator -@ARCH_REGISTRY.register() -class VQGANDiscriminator(nn.Module): - def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): - super().__init__() - - layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] - ndf_mult = 1 - ndf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n, 8) - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n_layers, 8) - - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - layers += [ - nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map - self.main = nn.Sequential(*layers) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_d' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - else: - raise ValueError('Wrong params!') - - def forward(self, x): - return self.main(x) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index da42b5e99..517eadfd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,9 +8,6 @@ import modules.shared from modules import shared, devices, modelloader, errors from modules.paths import models_path -# codeformer people made a choice to include modified basicsr library to their project which makes -# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. -# I am making a choice to include some files from codeformer to work around this issue. model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' @@ -18,115 +15,127 @@ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codef codeformer = None +class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): + def name(self): + return "CodeFormer" + + def __init__(self, dirname): + self.net = None + self.face_helper = None + self.cmd_dir = dirname + + def create_models(self): + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + + if self.net is not None and self.face_helper is not None: + self.net.to(devices.device_codeformer) + return self.net, self.face_helper + model_paths = modelloader.load_models( + model_path, + model_url, + self.cmd_dir, + download_name='codeformer-v0.1.0.pth', + ext_filter=['.pth'], + ) + + if len(model_paths) != 0: + ckpt_path = model_paths[0] + else: + print("Unable to load codeformer model.") + return None, None + net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) + + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=devices.device_codeformer, + ) + + self.net = net + self.face_helper = face_helper + + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def restore(self, np_image, w=None): + from torchvision.transforms.functional import normalize + from basicsr.utils import img2tensor, tensor2img + np_image = np_image[:, :, ::-1] + + original_resolution = np_image.shape[0:2] + + self.create_models() + if self.net is None or self.face_helper is None: + return np_image + + self.send_model_to(devices.device_codeformer) + + self.face_helper.clean_all() + self.face_helper.read_image(np_image) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() + + for cropped_face in self.face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) + if isinstance(res, tuple): + output = res[0] + else: + output = res + if not isinstance(res, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(res)}") + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + devices.torch_gc() + except Exception: + errors.report('Failed inference for CodeFormer', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) + + self.face_helper.get_inverse_affine(None) + + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] + + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize( + restored_img, + (0, 0), + fx=original_resolution[1]/restored_img.shape[1], + fy=original_resolution[0]/restored_img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + + self.face_helper.clean_all() + + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + return restored_img + + def setup_model(dirname): os.makedirs(model_path, exist_ok=True) - - path = modules.paths.paths.get("CodeFormer", None) - if path is None: - return - try: - from torchvision.transforms.functional import normalize - from modules.codeformer.codeformer_arch import CodeFormer - from basicsr.utils import img2tensor, tensor2img - from facelib.utils.face_restoration_helper import FaceRestoreHelper - from facelib.detection.retinaface import retinaface - - net_class = CodeFormer - - class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): - def name(self): - return "CodeFormer" - - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() - - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) - - self.net = net - self.face_helper = face_helper - - return net, face_helper - - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) - - def restore(self, np_image, w=None): - np_image = np_image[:, :, ::-1] - - original_resolution = np_image.shape[0:2] - - self.create_models() - if self.net is None or self.face_helper is None: - return np_image - - self.send_model_to(devices.device_codeformer) - - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() - - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - - try: - with torch.no_grad(): - output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - self.face_helper.get_inverse_affine(None) - - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) - except Exception: errors.report("Error setting up CodeFormer", exc_info=True) - - # sys.path = stored_sys_path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c0d22a992..a7c7c9e30 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,122 +1,9 @@ -import sys - -import torch - -import modules.esrgan_model_arch as arch -from modules import modelloader, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale - - class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + ) def esrgan_upscale(model, img): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py deleted file mode 100644 index 2b9888baf..000000000 --- a/modules/esrgan_model_arch.py +++ /dev/null @@ -1,465 +0,0 @@ -# this file is adapted from https://github.com/victorca25/iNNfer - -from collections import OrderedDict -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -#################### -# RRDBNet Generator -#################### - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, - act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', - finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.resrgan_scale = 0 - if in_nc % 16 == 0: - self.resrgan_scale = 1 - elif in_nc != 4 and in_nc % 4 == 0: - self.resrgan_scale = 2 - - fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] - LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) - - if upsample_mode == 'upconv': - upsample_block = upconv_block - elif upsample_mode == 'pixelshuffle': - upsample_block = pixelshuffle_block - else: - raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] - HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) - HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - - outact = act(finalact) if finalact else None - - self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), - *upsampler, HR_conv0, HR_conv1, outact) - - def forward(self, x, outm=None): - if self.resrgan_scale == 1: - feat = pixel_unshuffle(x, scale=4) - elif self.resrgan_scale == 2: - feat = pixel_unshuffle(x, scale=2) - else: - feat = x - - return self.model(feat) - - -class RRDB(nn.Module): - """ - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - """ - - def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() - # This is for backwards compatibility with existing models - if nr == 3: - self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - else: - RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] - self.RDBs = nn.Sequential(*RDB_list) - - def forward(self, x): - if hasattr(self, 'RDB1'): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - else: - out = self.RDBs(x) - return out * 0.2 + x - - -class ResidualDenseBlock_5C(nn.Module): - """ - Residual Dense Block - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - Modified options that can be used: - - "Partial Convolution based Padding" arXiv:1811.11718 - - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - {Rakotonirina} and A. {Rasoanaivo} - """ - - def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() - - self.noise = GaussianNoise() if gaussian_noise else None - self.conv1x1 = conv1x1(nf, gc) if plus else None - - self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - if self.conv1x1: - x2 = x2 + self.conv1x1(x) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - if self.conv1x1: - x4 = x4 + x2 - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - if self.noise: - return self.noise(x5.mul(0.2) + x) - else: - return x5 * 0.2 + x - - -#################### -# ESRGANplus -#################### - -class GaussianNoise(nn.Module): - def __init__(self, sigma=0.1, is_relative_detach=False): - super().__init__() - self.sigma = sigma - self.is_relative_detach = is_relative_detach - self.noise = torch.tensor(0, dtype=torch.float) - - def forward(self, x): - if self.training and self.sigma != 0: - self.noise = self.noise.to(x.device) - scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x - sampled_noise = self.noise.repeat(*x.size()).normal_() * scale - x = x + sampled_noise - return x - -def conv1x1(in_planes, out_planes, stride=1): - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -#################### -# SRVGGNetCompact -#################### - -class SRVGGNetCompact(nn.Module): - """A compact VGG-style network structure for super-resolution. - This class is copied from https://github.com/xinntao/Real-ESRGAN - """ - - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() - self.num_in_ch = num_in_ch - self.num_out_ch = num_out_ch - self.num_feat = num_feat - self.num_conv = num_conv - self.upscale = upscale - self.act_type = act_type - - self.body = nn.ModuleList() - # the first conv - self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) - # the first activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the body structure - for _ in range(num_conv): - self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) - # activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the last conv - self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) - # upsample - self.upsampler = nn.PixelShuffle(upscale) - - def forward(self, x): - out = x - for i in range(0, len(self.body)): - out = self.body[i](out) - - out = self.upsampler(out) - # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') - out += base - return out - - -#################### -# Upsampler -#################### - -class Upsample(nn.Module): - r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. - The input data is assumed to be of the form - `minibatch x channels x [optional depth] x [optional height] x width`. - """ - - def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() - if isinstance(scale_factor, tuple): - self.scale_factor = tuple(float(factor) for factor in scale_factor) - else: - self.scale_factor = float(scale_factor) if scale_factor else None - self.mode = mode - self.size = size - self.align_corners = align_corners - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - - def extra_repr(self): - if self.scale_factor is not None: - info = f'scale_factor={self.scale_factor}' - else: - info = f'size={self.size}' - info += f', mode={self.mode}' - return info - - -def pixel_unshuffle(x, scale): - """ Pixel unshuffle. - Args: - x (Tensor): Input feature with shape (b, c, hh, hw). - scale (int): Downsample ratio. - Returns: - Tensor: the pixel unshuffled feature. - """ - b, c, hh, hw = x.size() - out_channel = c * (scale**2) - assert hh % scale == 0 and hw % scale == 0 - h = hh // scale - w = hw // scale - x_view = x.view(b, c, h, scale, w, scale) - return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): - """ - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - """ - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): - """ Upconv layer """ - upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor - upsample = Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) - return sequential(upsample, conv) - - - - - - - - -#################### -# Basic blocks -#################### - - -def make_layer(basic_block, num_basic_block, **kwarg): - """Make layers by stacking the same blocks. - Args: - basic_block (nn.module): nn.module class for basic block. (block) - num_basic_block (int): number of blocks. (n_layers) - Returns: - nn.Sequential: Stacked blocks in nn.Sequential. - """ - layers = [] - for _ in range(num_basic_block): - layers.append(basic_block(**kwarg)) - return nn.Sequential(*layers) - - -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): - """ activation helper """ - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type in ('leakyrelu', 'lrelu'): - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - elif act_type == 'tanh': # [-1, 1] range output - layer = nn.Tanh() - elif act_type == 'sigmoid': # [0, 1] range output - layer = nn.Sigmoid() - else: - raise NotImplementedError(f'activation layer [{act_type}] is not found') - return layer - - -class Identity(nn.Module): - def __init__(self, *kwargs): - super(Identity, self).__init__() - - def forward(self, x, *kwargs): - return x - - -def norm(norm_type, nc): - """ Return a normalization layer """ - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'none': - def norm_layer(x): return Identity() - else: - raise NotImplementedError(f'normalization layer [{norm_type}] is not found') - return layer - - -def pad(pad_type, padding): - """ padding layer helper """ - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - elif pad_type == 'zero': - layer = nn.ZeroPad2d(padding) - else: - raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ShortcutBlock(nn.Module): - """ Elementwise sum the output of a submodule to its input """ - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = x + self.sub(x) - return output - - def __repr__(self): - return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') - - -def sequential(*args): - """ Flatten Sequential. It unwraps nn.Sequential. """ - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. - modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', - spectral_norm=False): - """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - if convtype=='PartialConv2D': - from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer - c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='DeformConv2D': - from torchvision.ops import DeformConv2d # not tested - c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='Conv3D': - c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - else: - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - - if spectral_norm: - c = nn.utils.spectral_norm(c) - - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 01d668ecd..6b6f17c43 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,8 +1,5 @@ import os -import facexlib -import gfpgan - import modules.face_restoration from modules import paths, shared, devices, modelloader, errors @@ -41,6 +38,8 @@ def gfpgann(): print("Unable to load gfpgan model!") return None + import facexlib.detection.retinaface + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan model_file_path = model_file @@ -81,8 +80,10 @@ gfpgan_constructor = None def setup_model(dirname): try: os.makedirs(model_path, exist_ok=True) - from gfpgan import GFPGANer - from facexlib import detection, parsing # noqa: F401 + import gfpgan + import facexlib.detection + import facexlib.parsing + global user_path global have_gfpgan global gfpgan_constructor @@ -111,7 +112,7 @@ def setup_model(dirname): facexlib.parsing.load_file_from_url = facex_load_file_from_url2 user_path = dirname have_gfpgan = True - gfpgan_constructor = GFPGANer + gfpgan_constructor = gfpgan.GFPGANer class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index dabef0f53..c2cbd8ce7 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -345,13 +345,11 @@ def prepare_environment(): stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: @@ -408,15 +406,10 @@ def prepare_environment(): git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) startup_timer.record("clone repositores") - if not is_installed("lpips"): - run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - startup_timer.record("install CodeFormer requirements") - if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb793..30116932a 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil import importlib @@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +181,15 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model(path, *, device, half: bool = False, dtype=None): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model diff --git a/modules/paths.py b/modules/paths.py index 187b94961..030646519 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,7 +38,6 @@ mute_sdxl_imports() path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 02841c302..332d8f4b1 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,6 @@ import os -import numpy as np -from PIL import Image -from realesrgan import RealESRGANer - +from modules.upscaler_utils import upscale_with_model from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader, errors @@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler): self.name = "RealESRGAN" self.user_path = path super().__init__() - try: - from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 - from realesrgan import RealESRGANer # noqa: F401 - from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401 - self.enable = True - self.scalers = [] - scalers = self.load_models(path) + self.enable = True + self.scalers = [] + scalers = get_realesrgan_models(self) - local_model_paths = self.find_models(ext_filter=[".pth"]) - for scaler in scalers: - if scaler.local_data_path.startswith("http"): - filename = modelloader.friendly_name(scaler.local_data_path) - local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] - if local_model_candidates: - scaler.local_data_path = local_model_candidates[0] + local_model_paths = self.find_models(ext_filter=[".pth"]) + for scaler in scalers: + if scaler.local_data_path.startswith("http"): + filename = modelloader.friendly_name(scaler.local_data_path) + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] - if scaler.name in opts.realesrgan_enabled_models: - self.scalers.append(scaler) - - except Exception: - errors.report("Error importing Real-ESRGAN", exc_info=True) - self.enable = False - self.scalers = [] + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) def do_upscale(self, img, path): if not self.enable: @@ -48,20 +36,18 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - upsampler = RealESRGANer( - scale=info.scale, - model_path=info.local_data_path, - model=info.model(), - half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, + mod = modelloader.load_spandrel_model( + info.local_data_path, device=self.device, + half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + ) + return upscale_with_model( + mod, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + # TODO: `outscale`? ) - - upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] - - image = Image.fromarray(upsampled) - return image def load_model(self, path): for scaler in self.scalers: @@ -76,58 +62,43 @@ class UpscalerRealESRGAN(Upscaler): return scaler raise ValueError(f"Unable to find model info: {path}") - def load_models(self, _): - return get_realesrgan_models(self) - -def get_realesrgan_models(scaler): - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - models = [ - UpscalerData( - name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN General WDN 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN AnimeVideo", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN 4x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 4x+ Anime6B", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 2x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - scale=2, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), - ] - return models - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) +def get_realesrgan_models(scaler: UpscalerRealESRGAN): + return [ + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, + ), + ] diff --git a/modules/sysinfo.py b/modules/sysinfo.py index b669edd0c..5abf616b7 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -26,11 +26,9 @@ environment_whitelist = { "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", - "CODEFORMER_REPO", "BLIP_REPO", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", - "CODEFORMER_COMMIT_HASH", "BLIP_COMMIT_HASH", "COMMANDLINE_ARGS", "IGNORE_CMD_ARGS_ERRORS", diff --git a/modules/upscaler.py b/modules/upscaler.py index b256e085b..3aee69db8 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -98,6 +98,9 @@ class UpscalerData: self.scale = scale self.model = model + def __repr__(self): + return f"" + class UpscalerNone(Upscaler): name = "None" diff --git a/requirements.txt b/requirements.txt index 80b438455..36f5674ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ basicsr blendmodes clean-fid einops +facexlib fastapi>=0.90.1 gfpgan gradio==3.41.2 @@ -20,13 +21,11 @@ open-clip-torch piexif psutil pytorch_lightning -realesrgan requests resize-right safetensors scikit-image>=0.19 -timm tomesd torch torchdiffeq diff --git a/requirements_versions.txt b/requirements_versions.txt index cb7403a9d..042fa708c 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,6 +5,7 @@ basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 +facexlib==0.3.0 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.41.2 @@ -19,11 +20,10 @@ open-clip-torch==2.20.0 piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 -realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 scikit-image==0.21.0 -timm==0.9.2 +spandrel==0.1.6 tomesd==0.1.3 torch torchdiffeq==0.2.3 From b621a63cf68c788487684250856707cb352b82d0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 23:01:02 +0200 Subject: [PATCH 146/311] Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN --- .github/workflows/run_tests.yaml | 8 ++ .gitignore | 1 + modules/codeformer_model.py | 158 ++++++++--------------------- modules/face_restoration_utils.py | 163 ++++++++++++++++++++++++++++++ modules/gfpgan_model.py | 154 +++++++++------------------- requirements.txt | 1 - requirements_versions.txt | 1 - test/conftest.py | 15 ++- test/test_face_restorers.py | 29 ++++++ test/test_files/two-faces.jpg | Bin 0 -> 14768 bytes test/test_outputs/.gitkeep | 0 11 files changed, 302 insertions(+), 228 deletions(-) create mode 100644 modules/face_restoration_utils.py create mode 100644 test/test_face_restorers.py create mode 100644 test/test_files/two-faces.jpg create mode 100644 test/test_outputs/.gitkeep diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 3dafaf8dc..cd5c3f868 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,6 +20,12 @@ jobs: cache-dependency-path: | **/requirements*txt launch.py + - name: Cache models + id: cache-models + uses: actions/cache@v3 + with: + path: models + key: "2023-12-30" - name: Install test dependencies run: pip install wait-for-it -r requirements-test.txt env: @@ -33,6 +39,8 @@ jobs: TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu WEBUI_LAUNCH_LIVE_OUTPUT: "1" PYTHONUNBUFFERED: "1" + - name: Print installed packages + run: pip freeze - name: Start test server run: > python -m coverage run diff --git a/.gitignore b/.gitignore index 09734267f..6790e9ee7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,4 @@ notification.mp3 /node_modules /package-lock.json /.coverage* +/test/test_outputs diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 517eadfd8..ceda4bab9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,140 +1,62 @@ -import os +from __future__ import annotations + +import logging -import cv2 import torch -import modules.face_restoration -import modules.shared -from modules import shared, devices, modelloader, errors -from modules.paths import models_path +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) + +logger = logging.getLogger(__name__) -model_dir = "Codeformer" -model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' +model_download_name = 'codeformer-v0.1.0.pth' -codeformer = None +# used by e.g. postprocessing_codeformer.py +codeformer: face_restoration.FaceRestoration | None = None -class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): +class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): def name(self): return "CodeFormer" - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - from facexlib.detection import retinaface - from facexlib.utils.face_restoration_helper import FaceRestoreHelper - - if self.net is not None and self.face_helper is not None: - self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models( - model_path, - model_url, - self.cmd_dir, - download_name='codeformer-v0.1.0.pth', + def load_net(self) -> torch.Module: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, ext_filter=['.pth'], - ) + ): + return modelloader.load_spandrel_model( + model_path, + device=devices.device_codeformer, + ).model + raise ValueError("No codeformer model found") - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) + def get_device(self): + return devices.device_codeformer - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer + def restore(self, np_image, w: float | None = None): + if w is None: + w = getattr(shared.opts, "code_former_weight", 0.5) - face_helper = FaceRestoreHelper( - upscale_factor=1, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - use_parse=True, - device=devices.device_codeformer, - ) + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, w=w, adain=True)[0] - self.net = net - self.face_helper = face_helper - - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) - - def restore(self, np_image, w=None): - from torchvision.transforms.functional import normalize - from basicsr.utils import img2tensor, tensor2img - np_image = np_image[:, :, ::-1] - - original_resolution = np_image.shape[0:2] - - self.create_models() - if self.net is None or self.face_helper is None: - return np_image - - self.send_model_to(devices.device_codeformer) - - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() - - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - - try: - with torch.no_grad(): - res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) - if isinstance(res, tuple): - output = res[0] - else: - output = res - if not isinstance(res, torch.Tensor): - raise TypeError(f"Expected torch.Tensor, got {type(res)}") - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - self.face_helper.get_inverse_affine(None) - - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize( - restored_img, - (0, 0), - fx=original_resolution[1]/restored_img.shape[1], - fy=original_resolution[0]/restored_img.shape[0], - interpolation=cv2.INTER_LINEAR, - ) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img + return self.restore_with_helper(np_image, restore_face) -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) +def setup_model(dirname: str) -> None: + global codeformer try: - global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) except Exception: diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py new file mode 100644 index 000000000..c65c85ef8 --- /dev/null +++ b/modules/face_restoration_utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import logging +import os +from functools import cached_property +from typing import TYPE_CHECKING, Callable + +import cv2 +import numpy as np +import torch + +from modules import devices, errors, face_restoration, shared + +if TYPE_CHECKING: + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +logger = logging.getLogger(__name__) + + +def create_face_helper(device) -> FaceRestoreHelper: + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + if hasattr(retinaface, 'device'): + retinaface.device = device + return FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=device, + ) + + +def restore_with_face_helper( + np_image: np.ndarray, + face_helper: FaceRestoreHelper, + restore_face: Callable[[np.ndarray], np.ndarray], +) -> np.ndarray: + """ + Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. + + `restore_face` should take a cropped face image and return a restored face image. + """ + from basicsr.utils import img2tensor, tensor2img + from torchvision.transforms.functional import normalize + np_image = np_image[:, :, ::-1] + original_resolution = np_image.shape[0:2] + + try: + logger.debug("Detecting faces...") + face_helper.clean_all() + face_helper.read_image(np_image) + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) + for cropped_face in face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + + try: + with torch.no_grad(): + restored_face = tensor2img( + restore_face(cropped_face_t), + rgb2bgr=True, + min_max=(-1, 1), + ) + devices.torch_gc() + except Exception: + errors.report('Failed face-restoration inference', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + logger.debug("Merging restored faces into image") + face_helper.get_inverse_affine(None) + img = face_helper.paste_faces_to_input_image() + img = img[:, :, ::-1] + if original_resolution != img.shape[0:2]: + img = cv2.resize( + img, + (0, 0), + fx=original_resolution[1] / img.shape[1], + fy=original_resolution[0] / img.shape[0], + interpolation=cv2.INTER_LINEAR, + ) + logger.debug("Face restoration complete") + finally: + face_helper.clean_all() + return img + + +class CommonFaceRestoration(face_restoration.FaceRestoration): + net: torch.Module | None + model_url: str + model_download_name: str + + def __init__(self, model_path: str): + super().__init__() + self.net = None + self.model_path = model_path + os.makedirs(model_path, exist_ok=True) + + @cached_property + def face_helper(self) -> FaceRestoreHelper: + return create_face_helper(self.get_device()) + + def send_model_to(self, device): + if self.net: + logger.debug("Sending %s to %s", self.net, device) + self.net.to(device) + if self.face_helper: + logger.debug("Sending face helper to %s", device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def get_device(self): + raise NotImplementedError("get_device must be implemented by subclasses") + + def load_net(self) -> torch.Module: + raise NotImplementedError("load_net must be implemented by subclasses") + + def restore_with_helper( + self, + np_image: np.ndarray, + restore_face: Callable[[np.ndarray], np.ndarray], + ) -> np.ndarray: + try: + if self.net is None: + self.net = self.load_net() + except Exception: + logger.warning("Unable to load face-restoration model", exc_info=True) + return np_image + + try: + self.send_model_to(self.get_device()) + return restore_with_face_helper(np_image, self.face_helper, restore_face) + finally: + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) + + +def patch_facexlib(dirname: str) -> None: + import facexlib.detection + import facexlib.parsing + + det_facex_load_file_from_url = facexlib.detection.load_file_from_url + par_facex_load_file_from_url = facexlib.parsing.load_file_from_url + + def update_kwargs(kwargs): + return dict(kwargs, save_dir=dirname, model_dir=None) + + def facex_load_file_from_url(**kwargs): + return det_facex_load_file_from_url(**update_kwargs(kwargs)) + + def facex_load_file_from_url2(**kwargs): + return par_facex_load_file_from_url(**update_kwargs(kwargs)) + + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6b6f17c43..a356b56fe 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,126 +1,68 @@ +from __future__ import annotations + +import logging import os -import modules.face_restoration -from modules import paths, shared, devices, modelloader, errors +from modules import ( + devices, + errors, + face_restoration, + face_restoration_utils, + modelloader, + shared, +) -model_dir = "GFPGAN" -user_path = None -model_path = os.path.join(paths.models_path, model_dir) -model_file_path = None +logger = logging.getLogger(__name__) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -have_gfpgan = False -loaded_gfpgan_model = None +model_download_name = "GFPGANv1.4.pth" +gfpgan_face_restorer: face_restoration.FaceRestoration | None = None -def gfpgann(): - global loaded_gfpgan_model - global model_path - global model_file_path - if loaded_gfpgan_model is not None: - loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) - return loaded_gfpgan_model +class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): + def name(self): + return "GFPGAN" - if gfpgan_constructor is None: - return None + def get_device(self): + return devices.device_gfpgan - models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth']) + def load_net(self) -> None: + for model_path in modelloader.load_models( + model_path=self.model_path, + model_url=model_url, + command_path=self.model_path, + download_name=model_download_name, + ext_filter=['.pth'], + ): + if 'GFPGAN' in os.path.basename(model_path): + net = modelloader.load_spandrel_model( + model_path, + device=self.get_device(), + ).model + net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return net + raise ValueError("No GFPGAN model found") - if len(models) == 1 and models[0].startswith("http"): - model_file = models[0] - elif len(models) != 0: - gfp_models = [] - for item in models: - if 'GFPGAN' in os.path.basename(item): - gfp_models.append(item) - latest_file = max(gfp_models, key=os.path.getctime) - model_file = latest_file - else: - print("Unable to load gfpgan model!") - return None + def restore(self, np_image): + def restore_face(cropped_face_t): + assert self.net is not None + return self.net(cropped_face_t, return_rgb=False)[0] - import facexlib.detection.retinaface - - if hasattr(facexlib.detection.retinaface, 'device'): - facexlib.detection.retinaface.device = devices.device_gfpgan - model_file_path = model_file - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) - loaded_gfpgan_model = model - - return model - - -def send_model_to(model, device): - model.gfpgan.to(device) - model.face_helper.face_det.to(device) - model.face_helper.face_parse.to(device) + return self.restore_with_helper(np_image, restore_face) def gfpgan_fix_faces(np_image): - model = gfpgann() - if model is None: - return np_image - - send_model_to(model, devices.device_gfpgan) - - np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) - np_image = gfpgan_output_bgr[:, :, ::-1] - - model.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - send_model_to(model, devices.cpu) - + if gfpgan_face_restorer: + return gfpgan_face_restorer.restore(np_image) + logger.warning("GFPGAN face restorer not set up") return np_image -gfpgan_constructor = None +def setup_model(dirname: str) -> None: + global gfpgan_face_restorer - -def setup_model(dirname): try: - os.makedirs(model_path, exist_ok=True) - import gfpgan - import facexlib.detection - import facexlib.parsing - - global user_path - global have_gfpgan - global gfpgan_constructor - global model_file_path - - facexlib_path = model_path - - if dirname is not None: - facexlib_path = dirname - - load_file_from_url_orig = gfpgan.utils.load_file_from_url - facex_load_file_from_url_orig = facexlib.detection.load_file_from_url - facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url - - def my_load_file_from_url(**kwargs): - return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path)) - - def facex_load_file_from_url(**kwargs): - return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - def facex_load_file_from_url2(**kwargs): - return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None)) - - gfpgan.utils.load_file_from_url = my_load_file_from_url - facexlib.detection.load_file_from_url = facex_load_file_from_url - facexlib.parsing.load_file_from_url = facex_load_file_from_url2 - user_path = dirname - have_gfpgan = True - gfpgan_constructor = gfpgan.GFPGANer - - class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): - def name(self): - return "GFPGAN" - - def restore(self, np_image): - return gfpgan_fix_faces(np_image) - - shared.face_restorers.append(FaceRestorerGFPGAN()) + face_restoration_utils.patch_facexlib(dirname) + gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname) + shared.face_restorers.append(gfpgan_face_restorer) except Exception: errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/requirements.txt b/requirements.txt index 36f5674ad..b1329c9e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,6 @@ clean-fid einops facexlib fastapi>=0.90.1 -gfpgan gradio==3.41.2 inflection jsonmerge diff --git a/requirements_versions.txt b/requirements_versions.txt index 042fa708c..edbb6db9e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -7,7 +7,6 @@ clean-fid==0.1.35 einops==0.4.1 facexlib==0.3.0 fastapi==0.94.0 -gfpgan==1.3.8 gradio==3.41.2 httpcore==0.15 inflection==0.5.1 diff --git a/test/conftest.py b/test/conftest.py index 31a5d9eaf..e4fc56785 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,10 +1,16 @@ +import base64 import os import pytest -import base64 - test_files_path = os.path.dirname(__file__) + "/test_files" +test_outputs_path = os.path.dirname(__file__) + "/test_outputs" + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") def file_to_base64(filename): @@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str: @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) + + +@pytest.fixture(scope="session") +def initialize() -> None: + import webui # noqa: F401 diff --git a/test/test_face_restorers.py b/test/test_face_restorers.py new file mode 100644 index 000000000..7760d51bf --- /dev/null +++ b/test/test_face_restorers.py @@ -0,0 +1,29 @@ +import os +from test.conftest import test_files_path, test_outputs_path + +import numpy as np +import pytest +from PIL import Image + + +@pytest.mark.usefixtures("initialize") +@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"]) +def test_face_restorers(restorer_name): + from modules import shared + + if restorer_name == "gfpgan": + from modules import gfpgan_model + gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path) + restorer = gfpgan_model.gfpgan_fix_faces + elif restorer_name == "codeformer": + from modules import codeformer_model + codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path) + restorer = codeformer_model.codeformer.restore + else: + raise NotImplementedError("...") + img = Image.open(os.path.join(test_files_path, "two-faces.jpg")) + np_img = np.array(img, dtype=np.uint8) + fixed_image = restorer(np_img) + assert fixed_image.shape == np_img.shape + assert not np.allclose(fixed_image, np_img) # should have visibly changed + Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png")) diff --git a/test/test_files/two-faces.jpg b/test/test_files/two-faces.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c9d1b01032a7298d76608c8b65cbb243463491c5 GIT binary patch literal 14768 zcmbVz^LHju)9n-6wr$(Cp4hfEu`#i2+t$QRCUz#~#7-u-dEfhe_Yb(${adZmeb(t! zyQ=oC{#yIm1t7~x%18o0KmY)c?+fs?2?zy%{yzf)0|f&E2m3Z~Nbvs{Bn$*3#PQS_fQ*KMf`*NSO+-Ws^8eOpBU(=V^Q~-H4M7#g@dO9zX4+!d2gNfw(CEqw0*5$;BaF${~6H~tc~o# z_=CaCf^y3T4B`tXfI!ujEhIeVg!1U-Bn4;cI5D`Ju6(vFJnmH}txN}~biq~yee(OMc zX_Jzrvr-8Bw`n3>f?;Cgig5eMy#^l4FnC!wf9R7c|hQBhx+$hVSM_`$oH z3~XPl^$2<}H+=zm8571hkCPGKCRb?{E$vDwNlQL-5iK&L{w#Z#3EA3TYvG*EMkU1A zk*f@MRjxZwZPcwl6=$wCx~_7voQc{@VAPv_1Mcv z`?M)V>uVmKNSu~-A1(SP)dkwC{A$JfZ7G=XUeT`6z|i2-qhh%}ISi2tHw>$lOqOhj z80$M#{939h#mRz#1mq9@yzVS2@>OrDJNzd{hNN*{6dH@~z%; z?_|t|JKV$;ai&gSEjH!3z?qngyNPx;f$?n*?QUUSDi&9I^k8C~cmHpV1`+%KDCB>J6^5GDC*R!s<)wn-EW>O=K2v@>)eR0r(CaP{?EQG%`~9QN0k3O&OD5n zv+O>Q0h3CtJ%_WCHSASGj^tjKYdaX8(@a6GE0!9gP@_-W9V_|D<>N*&PpqLjpp zWEAv+GrSV**8AFNbkaO0^FH6{68!3zx~{P%{TILa$ym*;B3zP?WGxg*{gtk*mtmmp-mix7hY4GYsC|j-3b9cuH>`3G?5;DxwaK1qAg2p zk8a@rb3dz0fy*sq(x?ljVJ>OT7eL@mCbhr9u`$wC$Yw~s@IsxUMdfN?u3VG<;aQ0G zS2l*kKwNjpZ9X#|st`#V)cI3d!V_$b$vs{X@I;`tK18k2>ylT;IBD23zamiA{*(U5l0p3NCg<+fZlkL+ghzHTA$I+6~3etKo*LhX2k28WLQ zPtwLRF`0J6?w5a?0(Q{OvN5?CDGR|&h@|VxmCFk)3;m-}F(#HV&AG?4gO2=E5sA#P zsPc4U82o9>P`Cyi_TgI&;PKHnhNmA-R2(9qQGfONaKDvbHV2~}ug|=&wH2$hD6`$3 zFQjzIxq8<{JXwmkNrf_#XA)JDz(fy`1Uk8YAN^9#E={NSy#8-l0)*qj0G3+aRf->6 zUUc^S7eL07CZ64MZ5^&YOG+f#6v90PSA8S5cSi=rHAlIo&JJp>ckZ?a!IIIyw&Y&U zLBVz-FS=6h@?RJ!Mw=(~>DE(cpsS3kXfXbb*HgO39vvZ7tZe+PJ3m9OU$TMPZ<1fj z&yMmWXa2r;988u&QX#9T2Ketp?@aH8FUrNST=ne6e2IsuuqJ9p(OT|XE9>a0jP|4uJ+r?97atpquh)gM_oaT9Ft zcPG=KTYk`BE0!xc{HJSk;2PBppVp+CcJix>YQ9Aud*-p-2$iGhs_dG)_E{WB@e9!Y z7wJS(YA?JLK}u@)9^W2A+_j+M!E0gdqlt2=b5gtToMp+i%i5VX9c03-hns&fj%Md>yZE&1nx4l*d{?ORjPx|pXltABl>?3+$hB!l$(G2hl<(` zZBlKse0pW#g-`bg#(pz->grd0))GLrOn5g$2(e&2c`-Pk#puFQ_go5^_ncMp5sC#J zs}F`RL(Wwkb^WP4X+cHMfepct2*70F~zf38+^r1FWXYI;duNz-}dbH%dF=X$Yx zqDsf1A-~)%c1lp2nkAP7@sqt-FkolJt^5-Isyao9V7HEMgR%HU+0POILSrL2zqFW} zaLdH0)c*2m8>)nNRD;Qw%_{Ma3)Rc%5|~zXTF-@C^L>`$&ZC1X{dzL0-^sRx#CdXU z>vmXjS}m#vY!s?sG(W#g)xHLn4H*Ujm^$a?&{UZSU8wi+b)O3vxFt=le2mPpf{BmR zdO=OMY9FP5&0YAu^m4TvL%+F1=os$#@E3rob#7&}YGw+(qMKqrqrj8Q1;*q}eZADq@js(w)d+ zhu#TcSAp=J-BFGdMA5)ow1tP>=X4223u+f#KHGsB6?0L_xO>^T-IH`O4%uc46ZZLQ z(MToV8@NdJ{F??#C$ossHx!B>13gZ~4SU_u~HBsK$?Iuz56+i*5&dP`br zS;yQ2>gs%Wcy+?^WPb!l)QI-uyUKpXwN&{^95wbFUFFEhGL_K`sQi3(A(HUJ>V@fHUg)r%wQAIEp%wy~kU;62 z)nY-IzuO7{p-yLmY$lre!Yq~%D{~j+VDG7LT*An-mUNd{-q2TU&=A{}&D^`lh%puI zS2QoHTyxNx^g@4h%JPVgi&|UB`_)8yC?RYUxEoY2^UUHF=rD#r1S9bkaEL&jx6WI} zXL799Gy`{*Ik}~LDAOI9vL8OM&a>Uv;W*}I;W)MNmvPElyO=gwaP_&!)I@aQ3hdoj zw>e@-;Br-HC$C^^PDgLt~at{&lZA4JAmCqC4++}HIh z+o9z!!p}t*hJXptjHd2w)dH!-^9Fl7n;2?FC+vqp6!fZQ2`J||^ma8ELE^nwc4&V- zCDVx{egZk)5d>x1Vl0`n!A-5L=iaKfcGNww=nEin*4bm#z-0X{Fh?eu@GBto&F5_} z0u@r*oAwL_xI0ecoU%5VJd2WwY&qXj4b@XGpH=erdmKN=bgQ7DF1dRKHajdF5C2w4 z<}q-nSsDo%6at4311$UvkHE|O_s0zWa!FhRNS{wJDCHl?^@l3v=wXvAnv$T8Gr%&p zL10U#5kkm~9F1MdY=2{1y7DG%2ADMPTkru#2WURYedE+b0>U7%sBz~DRsKgd z{2{?kDRVJ^m7M~nGC~#xB$pDgo0vdsY#PHJZlx2yq#`APvUIpM#EZj)ANoTotLSPP z6doT9C$SBe{3zWQKzoI4i)AIYTwh9E0h~TI5CDp~=+5?DWqQ1>w$#w1mG$=-_{1Gf zI13@?zRxD1E~jFXeq)>YZV@LFao_$4RLo!d%VTBB?iV{m@v?aeZ02YZpK6G19jZ_# ztn!EzQRB{qt?V07M+^hD7pDqoKqx{b&^iXv1_VjmKoIm)VPl^W8Lb5G`^#_9Oo;SFq9`L@%b%&+Rww%V@4?GZRS@cRevcwZ+FiD zuI?Gtj)yOmk|_wE4ejUrehRqw%Q22&a_NAwFNoL;wLb*NJBZpBAPpT;@R9=@SwTT% z_r>nqEA8Dij4WfOFDXq&x-fux_xU&Et9aw(`9%AH38<{MkVNTjc3 z%iKj7d&eaq5iEML1RO)QI=@t>B>OD~T*u{V8qS6K+`5!m_7UH2CmPOs*|^>g`BikY zf;cK6hrzD2nWxb@cIlF7$Q_4m?gZt;*8DM`X4(Nn5r!IK| z=f*>nX^Kr1pRFA~zKYuk!IrlIJz$nIe=3N+03rg8EaF`cM8FOng9py5qbG@buGw2p6ZJbxK56z!CUs z=<0Z8ZGbdp@paj=FlGU{yjmk#wqZv&q3&&KXCjDCx0mp=&=9@Z^-e8N2@mNwVW!siGT zDqC@l15j%eV3YOf+3c=QkUg5eK9yr&V!rJ#bJ6fZ8(?yCDIuZF`&;`6s;{alDUXHD zlwcDxkP#2IwrYe0w}y+j@;lm?pSG|N{LkB(2UEBoO~rqNVfccvxy^0|kZ(-02A2^1 z5gqdzg!T^%=A$A-kt%XW3cuyixm&%#pYt)4;7~%OwYo4{{>{JgNUh4t%Pq4nGwygy z{`{z7{Nk&~f=JfSL~Ap5AZwf$VSpe8mi>HMrRAa|B8AQ~oOk@(@v1E*cZuq>F()~N zoR3^r`SXoabWKYk+j#52fNg;4$A1Ro?PFjR$aaRGZ?y1?IAD00myXv&mH!G=qW4j`9Oxo8#i`%L z@D<&&`Q!w^76Q>=;cI!KitXVZE@+hA*bL63hntH5&Yz`$d80D+umrSGWUxI z8{(}iTqP|#j$#AfQq@(^8mSd^R+GtMd_dis>Iy2&F-ccBm0cp-Xf0->@2*`bqfv@) z`~)dN)dgYGG$e2mvya@D6^YNzonbSna&LGw1*I$~NB1mm9tPQ?pt<0+ zNw8E}XQL$#$OUk|3D($@J~U}@4qR`wce7*LMy;gO1BE;-iO}j-!53g2xT@CHmS*wu z;K;)vq-%t^9vS$1-;UK+HgkxbGCxDlpTOtAfZCm3sPdggD!e%)ahW@1NaO>>!M`%2 zIs5GyGas|QTRM(qpe%NJZ9TVFokD=lnr7FOthK2@3!H{*k{w``Z@6cZ*iDic{|}E5 z&%|i^dRnSL5(#eM_?~B0XdASfE>`&>WAXf_eaf4}ZWB zSC(PDaDyJdPZlN57XUx!>hl~rJ3SE0Ux%I!vp5O2>c{q^eGbTzvT^HNYo@*eo!)6A zQI_p34xNk3sFz!lUEFTR`OX9grdYvMgBZ=z7FFKiKLsYT*1ttVhp`EB2iV;J%<**s zpIq32vXm8B@}#nz-l_nUOR%LXd|oV+f01>hj@tNDEZ!5>I`j<)_a^8WEw>!1j&8}AJ%-q=E){WY1Do8U1;QdaGz0$zhdjW*y$TkTaUg3`&tyH@>xl~y| z49<`8ogW_004KKdt`|Aqx6^F!c~v>03g_}GM8kt}+FZjXvtVb*{e-$oYD4~Pe8(Z+tp;yde1 zp4y?KCMh|M)cz(qiR5xLtt8Ia!QD*8Ckq#F1RP{roIHMmP~)p*k5d^T z)9+66_E^_mp_yyr!K?m0I+V!8mBkT$Mod=5z{*SAjsswbe1+)trb;tV7=xamj#O2@uSwjm*JV%tjM<66>fsR%TMOG7`G!mC zxp{Qaoaqc;S|qM;ng6T$qjv{2%7t`ZNjv1eH-5>n{L4S;x#>}_oO7&RoYrZgz@t+# zpZq+Bw=7BQ+6XE@{D(4M1*79y*#vrGK3!Pf*#qn3sGCpr)4UAI^2F%zxy=~z=8+-? zo8r^m2wdcAw{Vk{H^e`c+@aK*i3Q%Pt=#l3pU9<}E7$OAVx$RBbN3de#lt1J583_c zf%6y@mO5Mn>CskFEa#gk9(Iy^lyCdb$e$Uyf-o(7BKe`L8eWJVhn*w5R8YJ*!x?Vc`K7cRUS zm+0EH_ace>k6p|Q@=vcnCwDUSC?kH2&%2!w>hV_N;RR1H<(`OPyJ0{IH>( zU_XCDysr=$HI3^?mgOCsD%(r2X)h)Dm*PcZ3*+hBuD!mGXKx>Xt!{&dz8Qq-?9nVIj6ylw@`i z*$@XjkWFAvW8 zv^*9f{-w3TZ%K(gn_IJ-M>>C-tNbJk`P}|5As;zU>Xq&9oy~KKDvI6jt|@J$vUpUa zLTJG*q}+(fbI$!F75+`Gp6E}}$r(&lgyc8$t7;%GeRut`7Jb9;KcKo#nm3eJ?vuFr zYmp1lkwyx;n~~M`r@SVMi~eRwDUiQe(f_FK|CQJOqppJhP)JxrRE)ow(ZWXLft&fg z{}qjdko75Ir;y=i6=sr|d#vs8ESIl%mC1RfK*DCnwL?A%GBi4jaQV&AfSN3+`VyrZ zg+0=eR<8~!Aqib^K}H=Yc+sjcPA!dM%!wFc~YSadN!a z?MLQuvPF}Nb;thu$jBH_>OK_Wzb=TQ{YN}lyCD%p(hwqCs;%{9JoN|I&uIW(fFNk; z&??v%RV!+f$V!7W^L@;@5k}Hsm-XNjJM7FA*~*`~6>=O-OPm|#^HympQ9q5-J4Gz$ z8AKR$Q>w058{}QnWpo;qD@Brs$qzKKyJiN)!)kXAbgWs{R8SHQNEbS9F>hQS`hNm# z2*xxclM8-|!Cx9z^n*TNp;*TRg$x(iBcl%47~z-A4_{i)g;SIxLn)dDnmZ|UrmyYE zLbj~{q;Ihy^@dfC0!uNA8D(V0M^VRK47Rp7HONwXVA|RmG!ueqDHc@PbgxQ@bNjo5 zsX7k_2q3RWgPELuZiN(=r5*U04iJz{Fi44Uol9a*sUE{)atuS$6hIs!N!zas+)-(g zbrTDRqk*UElA975vw?<7qN~bcb;wWl%BnV5vY4ss`J%}KMzAbd@*88fH=?qE8VtRo z5X@3|TqT7f1+<8Nr0ZKtup+GiLIXjUMCtX z$U>YAScf#70Ke4|QW<83Uqrk)IhYuXL$*4eH~-#ACOU8pqZaB&Ni?H~B(ZWr$5jnB zA#1BUz{LJ0b;iZEhKc8lm8uG4HGqhStf3Fhb?R0th<=KgCdhb!5nkLiM1nL8yUl@W z{TvnQP-qyBHpQ*DUC7(!?@E6;CwYTR@3g!}#KsKNDpvjibdxi!8|7VTw8vek!n6Gl zL=guVN%_+E_3vSEGDlPxQ8A7XBFPNX4Vm!7#|=wXd0>+lb3}Gl_H z5L1k$Bhx7&vAk|Bs7Kb7+8xo@cQjV2fK{lq2*DO{G*^FkSvVNDmt*}xTU%8a!b;P0 zsB0ex)EPY8=aHY(VP1wfyO2SoI5Ph|xXKt3dYZvi!lT*8EP4T=u0I^!R)b2w zYRlot*w#sT7xd8W&t#Wj)h4T#&zRGn)Hv6D>kc{4m<3-ZGA(kL`D)!skE)F{+>!^O z^nMNt?Eyqw;+KNlTk_L-)MOOqyS=!D2cjpf%kQnF+@Hf60uJm&@3^LCCUIZ^6)=dZ zB*%7A*naCI_Yq`NJI*t&29KxQ6pL>h)>BZj%F17%DuCEyeDj^miYTde-$fN9f^|$GDYntaZ@Ua+vPJIhFRtNREeZkY z$ZvRWN=hR#emy3!3b_?yb$<}@U-HIgsu^VlF48f)6-4`NB05R?gEkE3`ZTob{w700 z{-T|rov7=xAIP?{SbV|#N*z?EdSb)i0O2E;cAtVEtO(@gEu&K>&i-MLrkYqM3A6}~ zB>H_5Q*DR60fAFi%=V&P`*=u}WS|}4v|^saDD=JDj)3 zsB}2=0Tbge%FOqm>?>RlkNzY(BCPXH1TP-P4|ip-ze2SGL0Y<{`_1E-%2QN>pCQV_`!eVh53<9%kyJ);4C^z? zE)by9wlyM~8tbsWmf#;yLwWr#f+m z%Z8`&PjF>`iTr?EJGg+)q&dJG9)>Q0q}DgsZ`h;Io(HVI*;8~s$YHr8jieU&>sCas z)#PVO?rJa2Qgi&m-j_|+P?oD$1^KaLG6`TsnvX0+C7RCt9n4FM53@N=`NuloeJvNG z>t@DBOC+4lXp0ig6s%O>K8%jAP^Zul3wqHE%xqdsEoB-t(OREE2+9)Z-^S_}o}=kw zy9dE=y{Am2FSiXbAe1kSh|AnkzHM{yTX3?~GF*v!*g zy=6q?QmE845+xO4CWK$kP9ite%LHCj8Pe40Px6H`Q`GYw)nLL1IwV1(e2rKh2% z0bPn-ooV%Nhc&N{*m@t0YuKlP>IHu5U(UG`6b4tC^ zj*0-A%V;Qx80nkg_(Gjd7c=Ur4;XC}&a_^kqKm?NJXv+Aa-O9qUeVrW8?@%uMsH)x zYq~|jl;)s)uXHIQCn2XXijom+7HAe=coQE2Kf|f^^7|K_gH+;zL7}^8y+_M({LC}FInu9)Z_;a)k-?MPV%Jv-Cn!y?*!k%d%a*Z@Hfl)U z7GRM_+OWHVeRJU)#%)CA6&;X?PKc^rsu>SJ;YusniL|R1FXRa+2%f&{i-Wc@JoN}$ zTC38@){YHkh9s<_Kr(_hsV8&l5irsTVA%&P1U#gB;)r@X4T*PHNnSihhMONo%6e4C%yBoI3lQwvFBmYuv@~fJx)8Yx1aKnNIjl;=a zqL#BuxbT6aHPXyoJWwQ!Xi#TbegtU|SIJlDTamywlFfk}rD+J(Hs-=aFp@7vuVD=_ zWnZYQ856b@M7p3n;qfOk8ZfYUQjlW-IOKfG8w7r6=VIL`Y(aFPuNpH4)7&2G6#a zxQw6z=+iJhQe}L)B}Ido9x)&Q1FVu(5U&q$iP4ohhVkOGX4b~vaWFZ`tJY!BO@WH~ zCR5zXZ2fAmJOEFv_?$>m6)MIE+R7CM@BOvRgR%{&OxUqSYrLtabAsefMw3QRfK8j|s+0eWg%;-q)z|jn`RCj5Usd#Gd?=@lCgpymXMiDq?T& z^9We0P5v(a=zt!inyXEdv1@UF{Q@YDtcM?=I?FgA&OtT9Lp5OrBqNt5je)WRXX%?p zKGVg$^o}{D9tIc#=$edbryl{u9j_&$G=n1ld?J6t>VuR)G{_A`2PAtz<+0|B-1lvx6orHoD5qNWJO&#U0aqI~r;s%nHm}g`u?a2zQ8=Bh zW>GXC3{{KsDYtK^Zi!b-tF9Cu6|Wdsb`9-g)F2uvcNiTw1C8b5?2S!W=dI+A3&HhH zhs{i8NtYF`ZU$;Ff8Rf5nzey!_fIFKMmkaV(BFtyxF$fc+}iyu8sxOlXdN|mYDk|M z5L|h`!sY~MI)&n=Jpx2`YN)S2a~tJWuFhyF=H=}au=}W?oq`7k`xok=*yx-*qhT^D zJ}WSbG?Pt+`xaDSG?OAcN-|d30?IJT*VRoa53I~$*O-{;Jz273WkGdEz{mLfPd=49 zq*$TuQFI$D-lK6Cfj7=*L$EF7T|p}NnF6d$UZIZgSExmQO$aO&ZkcwAH?7|>aEFlr z$AUnxbL4eR$;l2xnOX>=mj%Z~HZo%obwr74ClY;?nIlHL2#|H%(u6re?7*j)+3~Zi zEHgQNVaQ4aFF(JEuq+NwE!)>%X5HDqVVy~kBhbB7$<;$p=ow#A6H$B-7H)IMqo9;g zx~Nye{22!kLula;jIIG>OXz}lr!7KEIF3lS=I@S;4dr_r8VnTTd$tAQKT|E=9UCYZ z02~FC1On~5ZzDB!hAbpw6%B$yCwEE2R849e_;3FP3IZYo$iz!?;Tn+wV-~cr(wFHk zSNwQe{)N}mMxx&t2?^l%EbMe26vpUztp4dx?qJ&Jn71?%e_Rs6khe5XtyeKnVP-}g z6cL&9d5cXmK0{^%@Q$*M9!5tPkzy?BimxioV3+;76xEa=eFsSZA2fN@~74Uy&{S14FfisiGbDbfiWQICe$gqp1P-)arz9~%HgIIzt znWBM=9okf!^!B(H63k(6*%Wv{<*VXr}Fe;)Wx|y>!-#7$1$h5!%$^e)~-Wj#J>uC5)zAu zRH_7*vP>-p8WW_XZd{fWFm;A*`F?^ha#9XuKGj>98`dhC>#}6xAW>^y^&CfMW~jT^ z8EJ7^rJ$@Z^tZ#9=G+Mm>ab*Z_7_3A@X2OAGGuK29+gOv$qBvzfpLl!cUW5Z}kz@JNFViIVnqxl8lqXZxM>xp?qj^e?3#nYxRwI4cn zl({*;VaV>)x0m3L{w=X2;0&Wob%JvaYeG%wZ4nxTL37Ru^HkrG>q)d&Ln@cB!SAOA z>xIwYlmd<+h`9$L;BRR4_occk;@fPd*FUI0aPGWqUfPRE6<-V1or*5)RM6$VQaFr^3bE3CGAT70rWE@BBSS4;g5*ipAvUMNYjIQ`KOY)#6Mkvfiaee+ysDG|pWuNWgLMvknHuOF(x`jrIHN zu~izt8qfjmB;mjqWS^*geCka!eySU`j;oencrZFNO;6dAbl6=@q%5obPOWrcT!%`I zVw! zn+@SOR_W&(n?IlX3leE;9UsdP;RFN45Aln>9xHS&Mz%7P!msVz%3~D^E_6S?0EjQ; zNSb+r61{&#I0Hsv}uSxXIbts;uiZ4i4Zc(m< z>QXK=N;bXu63saWc);QFCF-X;1>H?}^V1&YI+z!g&ic|)_q}ft8dnC|5R8LhHiy$6 z0!=BHmD;_IgRmQldqNa90n3QxGHP$d>mBe0e4PD+2IEHJ3?pJ!#xJ`$zp4v6Wj{`@yd}bo-njyD(j*& zBVBc_60V*OuKo+LC0xOpsG8601S3Y$>aB$Qi8RMJj-RNHuFIzyI9Jl9N*5c-IJ8#x zm|MYeum3lc^!Kpceu6pNVLY#AB`?>daKZx0l3V>Z(zE{;>Hjms2nzRqe~I~Dq-TK; zK_gXhhD2vIMh+q)XA@O*`6lZI<|+OQ_20waLI4era9R`oTIs!TItrNwjynr;eHifMZP&ZU{(%*Z%LoiQ{Mp1 zbZJpr28Fn?lpLNQRt!#lO((i6(?!35tsjF*ECtjk1MPGOgCBKss>qFYFz7DkR{t01 zmPrH^JCY5x`>g>EzpDl&3@Y5uQQcIv%|A&zRO7KYjq4I}c4w2fGwyKY?x>^e z0liwA~t@uA2bM6Iv{2}kJ)d$BV(E#u1;1b0yq+ zuYT!KW+c9w?jFxUi+iM{*qw?Q56SUAyA&%QGbADK4Thauk+8uT^QP^uVfoM=>n3{D zxM8RUKlNtGUdwV5l!cGAi+-C1G1nocw_KmFdUV# z3@YUSyQrVV;K1hn3jm?cLp&#zvtKB-U$(v@l&U>|%Zmg1646L(e=2w!J$)S(gZ-Hc zZU(%wG6XLoIe>;-f2xjL&O|#tYTvCU21mINssE`wWNlH)*lt0k!F~vGrjN>f^(G60 zGdVn66S?8KxZ4 zB%|;0nz@o{ZX$#H$Hc(RXpL&=;e1+j*%-POPtqx-@u}!jX{vIM&G2gz_9`$!VmBo= z6H;hIzt6(kl3Yb4<~YyJKKFyB+=UqrY|dUJn+M)iAD&`NDQd5ujjk!}sj_wO{8Wht z+y#M3r6r5GU7v$c-eHUw?OeiLu^yw~M{Oc_d8kO3th%LT;sm|#aNgmHHa2S#j-m+0 z0`FNy2^wlUvE4!V*b4{4$l<|LNyXI|2SO*j{w9OIjes0_ROg)MSg8WH*QN@|1Qk5= z$)8Aa8_?#q{nmW_Gef`AR$n@9PEY1MRlQGnS}Bd{zpL#(9~xIA2b;9c$9e^PP4aA88&T>*|=^ z^tYZ=JJJVT^8BSQTs5{W{z%TVW~{wk{aR_(AF(vD-NRfJxFG*fsg8H&{sOqut3o>U zR#!>w>h>i4I>SFxQ7$`s!i|-Lj~nSwwhr8UTUBhoWg-(5d_>(7ZAS-O_E#PmIkk{v zd4eNEkY7OfYE*dl40zHe2Ky<}Tt1?*TuX-8+3r>LFjk@_2N9cnzk)@&*n)c;hkRA@ zoQTf@wDBfoFW^!aCBpZ9tf7S0D6zXyop8+CYK$;jQ$gtifo?J&n8t%=;;EFr&WV$K zJaf6+JEf5qr2he9=>*G9_DZFTy@ci;<6u+_kyz4BaLlJ Date: Wed, 27 Dec 2023 10:55:01 +0200 Subject: [PATCH 147/311] Add experimental HAT model --- modules/hat_model.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 modules/hat_model.py diff --git a/modules/hat_model.py b/modules/hat_model.py new file mode 100644 index 000000000..553e19411 --- /dev/null +++ b/modules/hat_model.py @@ -0,0 +1,42 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + + +class UpscalerHAT(Upscaler): + def __init__(self, dirname): + self.name = "HAT" + self.scalers = [] + self.user_path = dirname + super().__init__() + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scale = 4 # TODO: scale might not be 4, but we can't know without loading the model + scaler_data = UpscalerData(name, file, upscaler=self, scale=scale) + self.scalers.append(scaler_data) + + def do_upscale(self, img, selected_model): + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr) + return img + model.to(devices.device_esrgan) # TODO: should probably be device_hat + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile + tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap + ) + + def load_model(self, path: str): + if not os.path.isfile(path): + raise FileNotFoundError(f"Model file {path} not found") + return modelloader.load_spandrel_model( + path, + device=devices.device_esrgan, # TODO: should probably be device_hat + ) From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: [PATCH 148/311] Verify architecture for loaded Spandrel models --- extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 1 + modules/codeformer_model.py | 1 + modules/esrgan_model.py | 1 + modules/gfpgan_model.py | 1 + modules/hat_model.py | 1 + modules/modelloader.py | 13 ++++++++++++- modules/realesrgan_model.py | 7 ++++--- 8 files changed, 22 insertions(+), 5 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a0..5f3dd08b3 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 85c18b9e9..aae159af5 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler): filename, device=self._get_device(), dtype=devices.dtype, + expected_architecture="SwinIR", ) if getattr(opts, 'SWIN_torch_compile', False): try: diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ceda4bab9..44b84618e 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): return modelloader.load_spandrel_model( model_path, device=devices.device_codeformer, + expected_architecture='CodeFormer', ).model raise ValueError("No codeformer model found") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e30..70041ab02 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler): return modelloader.load_spandrel_model( filename, device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', ) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a356b56fe..48f8ad5e2 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): net = modelloader.load_spandrel_model( model_path, device=self.get_device(), + expected_architecture='GFPGAN', ).model net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 return net diff --git a/modules/hat_model.py b/modules/hat_model.py index 553e19411..7f2abb416 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler): return modelloader.load_spandrel_model( path, device=devices.device_esrgan, # TODO: should probably be device_hat + expected_architecture='HAT', ) diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932a..f4182559e 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 332d8f4b1..2a2be5ad7 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,9 @@ import os -from modules.upscaler_utils import upscale_with_model -from modules.upscaler import Upscaler, UpscalerData -from modules.shared import cmd_opts, opts from modules import modelloader, errors +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model class UpscalerRealESRGAN(Upscaler): @@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="RealESRGAN", ) return upscale_with_model( mod, From 05230c02606080527b65ace9eacb6fb835239877 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 18:02:51 +0300 Subject: [PATCH 149/311] fix img2img api that i broke when implementing infotext support --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/api/api.py b/modules/api/api.py index 2918f7857..2e18c6b91 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -507,6 +507,7 @@ class Api: args.pop('script_name', None) args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them args.pop('alwayson_scripts', None) + args.pop('infotext', None) script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args) From f476649c02cf3547d891fa08c50a92f92c4d73bd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:41:19 +0200 Subject: [PATCH 150/311] Correct arg type for restore_face --- modules/face_restoration_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index c65c85ef8..85cb30570 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -36,7 +36,7 @@ def create_face_helper(device) -> FaceRestoreHelper: def restore_with_face_helper( np_image: np.ndarray, face_helper: FaceRestoreHelper, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: """ Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. @@ -126,7 +126,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration): def restore_with_helper( self, np_image: np.ndarray, - restore_face: Callable[[np.ndarray], np.ndarray], + restore_face: Callable[[torch.Tensor], torch.Tensor], ) -> np.ndarray: try: if self.net is None: From 91560e98c47f8271d444556ef4ae6505dece9aba Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sat, 30 Dec 2023 23:42:10 +0800 Subject: [PATCH 151/311] fix format issue --- modules/api/api.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2f718ec26..d202cb8d2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -418,7 +418,6 @@ class Api: task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(False) @@ -489,14 +488,13 @@ class Api: mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: if not script_runner.scripts: script_runner.initialize_scripts(True) ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) if not self.default_script_arg_img2img: self.default_script_arg_img2img = self.init_default_script_args(script_runner) From c9174253fb603e6b2552e4c2721fd767b6ede87d Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 17:45:26 +0200 Subject: [PATCH 152/311] Drop dependency on basicsr --- modules/face_restoration_utils.py | 35 +++++++++++++++++++++++-------- requirements.txt | 1 - requirements_versions.txt | 1 - 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index 85cb30570..1cbac2364 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -17,6 +17,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: + """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" + assert img.shape[2] == 3, "image must be RGB" + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return torch.from_numpy(img.transpose(2, 0, 1)).float() + + +def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: + """ + Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. + """ + tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + assert tensor.dim() == 3, "tensor must be RGB" + img_np = tensor.numpy().transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image, no RGB/BGR required + return np.squeeze(img_np, axis=2) + return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) + + def create_face_helper(device) -> FaceRestoreHelper: from facexlib.detection import retinaface from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -43,7 +65,6 @@ def restore_with_face_helper( `restore_face` should take a cropped face image and return a restored face image. """ - from basicsr.utils import img2tensor, tensor2img from torchvision.transforms.functional import normalize np_image = np_image[:, :, ::-1] original_resolution = np_image.shape[0:2] @@ -56,23 +77,19 @@ def restore_with_face_helper( face_helper.align_warp_face() logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) for cropped_face in face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) try: with torch.no_grad(): - restored_face = tensor2img( - restore_face(cropped_face_t), - rgb2bgr=True, - min_max=(-1, 1), - ) + cropped_face_t = restore_face(cropped_face_t) devices.torch_gc() except Exception: errors.report('Failed face-restoration inference', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - restored_face = restored_face.astype('uint8') + restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) + restored_face = (restored_face * 255.0).astype('uint8') face_helper.add_restored_face(restored_face) logger.debug("Merging restored faces into image") diff --git a/requirements.txt b/requirements.txt index b1329c9e3..731a1be7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ GitPython Pillow accelerate -basicsr blendmodes clean-fid einops diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e..1e0ccafa7 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,7 +1,6 @@ GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 -basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 From b58ed1b2432c3c7643b39e53f7bb567ea8655aae Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 18:02:01 +0200 Subject: [PATCH 153/311] Bump numpy to 1.26.2 This avoids it being downgraded during `launch.py` --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e..7ec7abe2e 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -13,7 +13,7 @@ inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 lark==1.1.2 -numpy==1.23.5 +numpy==1.26.2 omegaconf==2.2.3 open-clip-torch==2.20.0 piexif==1.1.3 From f651405427dfc6d4ef96ecba7f9c2ceb580263fd Mon Sep 17 00:00:00 2001 From: lanyeeee <1210347077@qq.com> Date: Sun, 31 Dec 2023 01:09:13 +0800 Subject: [PATCH 154/311] remove locks, move init code to __init__ --- modules/api/api.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index d202cb8d2..fc3921c23 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -251,8 +251,21 @@ class Api: self.default_script_arg_txt2img = [] self.default_script_arg_img2img = [] - self.txt2img_script_arg_init_lock = Lock() - self.img2img_script_arg_init_lock = Lock() + txt2img_script_runner = scripts.scripts_txt2img + img2img_script_runner = scripts.scripts_img2img + + if not txt2img_script_runner.scripts or not img2img_script_runner.scripts: + ui.create_ui() + + if not txt2img_script_runner.scripts: + txt2img_script_runner.initialize_scripts(False) + if not self.default_script_arg_txt2img: + self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner) + + if not img2img_script_runner.scripts: + img2img_script_runner.initialize_scripts(True) + if not self.default_script_arg_img2img: + self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner) @@ -418,16 +431,10 @@ class Api: task_id = txt2imgreq.force_task_id or create_task_id("txt2img") script_runner = scripts.scripts_txt2img - with self.txt2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(False) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_txt2img: - self.default_script_arg_txt2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner) populate = txt2imgreq.copy(update={ # Override __init__ params @@ -488,16 +495,10 @@ class Api: mask = decode_base64_to_image(mask) script_runner = scripts.scripts_img2img - with self.img2img_script_arg_init_lock: - if not script_runner.scripts: - script_runner.initialize_scripts(True) - ui.create_ui() - infotext_script_args = {} - self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) + infotext_script_args = {} + self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args) - if not self.default_script_arg_img2img: - self.default_script_arg_img2img = self.init_default_script_args(script_runner) selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner) populate = img2imgreq.copy(update={ # Override __init__ params From 1465dab71564bb30091479ceabae6c69e3426bc6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 19:44:05 +0200 Subject: [PATCH 155/311] Make Tensorboard a late import (it was implicitly installed by basicsr) --- modules/textual_inversion/textual_inversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 04dda585c..c6bcab153 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,6 @@ import safetensors.torch import numpy as np from PIL import Image, PngImagePlugin -from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset @@ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) def tensorboard_setup(log_directory): + from torch.utils.tensorboard import SummaryWriter os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) return SummaryWriter( log_dir=os.path.join(log_directory, "tensorboard"), @@ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed + tensorboard_writer = None if shared.opts.training_enable_tensorboard: - tensorboard_writer = tensorboard_setup(log_directory) + try: + tensorboard_writer = tensorboard_setup(log_directory) + except ImportError: + errors.report("Error initializing tensorboard", exc_info=True) pin_memory = shared.opts.pin_memory @@ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" - if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + if tensorboard_writer and shared.opts.training_tensorboard_save_images: tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: From 48a2a1a437a48cc232725cc813242f98483b7697 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 19:44:38 +0200 Subject: [PATCH 156/311] Don't wait for 10 minutes for test server to come up --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index cd5c3f868..f42e4758e 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -57,7 +57,7 @@ jobs: 2>&1 | tee output.txt & - name: Run tests run: | - wait-for-it --service 127.0.0.1:7860 -t 600 + wait-for-it --service 127.0.0.1:7860 -t 20 python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test - name: Kill test server if: always() From 5fbb13e0da8eb2e26bd2c45ec8ffbb2de669ef47 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 20:46:44 +0200 Subject: [PATCH 157/311] Remove `cleanup_models` code --- modules/initialize.py | 3 --- modules/modelloader.py | 50 ------------------------------------------ 2 files changed, 53 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index ac95fc6f0..4a3cd98cf 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -54,9 +54,6 @@ def initialize(): initialize_util.configure_sigint_handler() initialize_util.configure_opts_onchange() - from modules import modelloader - modelloader.cleanup_models() - from modules import sd_models sd_models.setup_model() startup_timer.record("setup SD model") diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559e..5f7aec3e4 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging import os -import shutil import importlib from urllib.parse import urlparse @@ -10,7 +9,6 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone -from modules.paths import script_path, models_path logger = logging.getLogger(__name__) @@ -96,54 +94,6 @@ def friendly_name(file: str): return model_name -def cleanup_models(): - # This code could probably be more efficient if we used a tuple list or something to store the src/destinations - # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler - # somehow auto-register and just do these things... - root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") - move_files(src_path, dest_path, ".ckpt") - move_files(src_path, dest_path, ".safetensors") - src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path, ".pth") - src_path = os.path.join(root_path, "gfpgan") - dest_path = os.path.join(models_path, "GFPGAN") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") - move_files(src_path, dest_path) - - -def move_files(src_path: str, dest_path: str, ext_filter: str = None): - try: - os.makedirs(dest_path, exist_ok=True) - if os.path.exists(src_path): - for file in os.listdir(src_path): - fullpath = os.path.join(src_path, file) - if os.path.isfile(fullpath): - if ext_filter is not None: - if ext_filter not in file: - continue - print(f"Moving {file} from {src_path} to {dest_path}.") - try: - shutil.move(fullpath, dest_path) - except Exception: - pass - if len(os.listdir(src_path)) == 0: - print(f"Removing empty folder: {src_path}") - shutil.rmtree(src_path, True) - except Exception: - pass - - def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ From af050dcaa75ef40b6b1c3da3361f32fe52786aeb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:05:59 +0200 Subject: [PATCH 158/311] Soften Spandrel model-architecture check to just a warning --- modules/modelloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559e..6b7d697f1 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -196,7 +196,9 @@ def load_spandrel_model( import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) if expected_architecture and model.architecture != expected_architecture: - raise TypeError(f"Model {path} is not a {expected_architecture} model") + logger.warning( + f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + ) if half: model = model.model.half() if dtype: From 393a5b82ba6df06d85f2bf7bbbe0456d3d06115f Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:12:32 +0200 Subject: [PATCH 159/311] Correct RealESRGAN expected architecture type to ESRGAN --- modules/realesrgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2a2be5ad7..65f2e8806 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -40,7 +40,7 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), - expected_architecture="RealESRGAN", + expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( mod, From 8100e901ab0c5b04d289eebb722c8a653b8beef1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 22:41:53 +0300 Subject: [PATCH 160/311] fix error with RealESRGAN model failing to upscale fp32 image --- modules/upscaler_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8bdda51c4..39f78a0b7 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + + model_weight = next(iter(model.parameters())) + img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + with torch.no_grad(): output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = 255. * np.moveaxis(output, 0, 2) output = output.astype(np.uint8) From bc5ae74c7d8949bab37e260b16e76889b9968099 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 21:52:27 +0100 Subject: [PATCH 161/311] Added negative prompts to extra networks lora --- .../Lora/ui_edit_user_metadata.py | 14 +++++++-- .../Lora/ui_extra_networks_lora.py | 9 ++++++ javascript/extraNetworks.js | 31 +++++++++++++------ modules/ui_extra_networks.py | 5 ++- 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c70119090..f7859b21f 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,14 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +129,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +166,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +203,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -211,7 +218,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b..09ce2a057 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -45,6 +45,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + preferred_negative_weight = item["user_metadata"].get("negative weight") + item["negative_prompt"] = quote_js("") + if negative_prompt: + neg_prompt = negative_prompt + if (preferred_negative_weight > 0): + neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb74..2bb9795dc 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -185,8 +185,10 @@ onUiLoaded(setupExtraNetworks); var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - var m = text.match(re_extranet); +var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/; +var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g; +function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { + var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; if (m) { @@ -194,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { - m = found.match(re_extranet); + newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) { + m = found.match(isNeg ? re_extranet_neg : re_extranet); if (m[1] == partToSearch) { replaced = true; foundAtPosition = pos; @@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { }); if (foundAtPosition >= 0) { - if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); } if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { @@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return false; } -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); +function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) { - textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { + textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } - updateInput(textarea); + updateInput(textArea); +} + +function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { + if (textToAddNegative.length > 0) { + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + } else { + var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); + updatePromptArea(textToAdd, textarea) + } } function saveCardPreview(event, tabname, filename) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba33..b8c022413 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -223,7 +223,10 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + if "negative_prompt" in item: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + else: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' From a2f23f9d22dde87bf2529dcb2854a6a5d3d44278 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:16:51 +0100 Subject: [PATCH 162/311] Code Style fixes --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- javascript/extraNetworks.js | 6 +++--- modules/upscaler_utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 09ce2a057..9a6624e3f 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -52,8 +52,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): neg_prompt = negative_prompt if (preferred_negative_weight > 0): neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) - + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2bb9795dc..f1ad19a66 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -243,11 +243,11 @@ function updatePromptArea(text, textArea, isNeg) { function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { if (textToAddNegative.length > 0) { - updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) - updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")); + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true); } else { var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); - updatePromptArea(textToAdd, textarea) + updatePromptArea(textToAdd, textarea); } } diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7..1d610dbff 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 3be90740316f8fbb950b31d440458a5e8ed4beb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 00:43:41 +0300 Subject: [PATCH 163/311] fix for the previous fix. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b7..dde5d7ad4 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): From c0ca6348e8489651df861a101142805c213c66a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:04:47 +0200 Subject: [PATCH 164/311] load_spandrel_model: always return a model descriptor --- modules/modelloader.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682c..8bcee08c1 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -142,17 +145,17 @@ def load_spandrel_model( half: bool = False, dtype: str | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: [PATCH 165/311] Be more clear about Spandrel model nomenclature --- extensions-builtin/SwinIR/scripts/swinir_model.py | 6 +++--- modules/gfpgan_model.py | 10 ++++++---- modules/modelloader.py | 2 +- modules/realesrgan_model.py | 4 ++-- modules/upscaler_utils.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index aae159af5..95c7ec648 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -71,7 +71,7 @@ class UpscalerSwinIR(Upscaler): else: filename = path - model = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), dtype=devices.dtype, @@ -79,10 +79,10 @@ class UpscalerSwinIR(Upscaler): ) if getattr(opts, 'SWIN_torch_compile', False): try: - model = torch.compile(model) + model_descriptor.model.compile() except Exception: logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) - return model + return model_descriptor def _get_device(self): return devices.get_device_for('swinir') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e2..445b04092 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging import os +import torch + from modules import ( devices, errors, @@ -25,7 +27,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): def get_device(self): return devices.device_gfpgan - def load_net(self) -> None: + def load_net(self) -> torch.Module: for model_path in modelloader.load_models( model_path=self.model_path, model_url=model_url, @@ -34,13 +36,13 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - net = modelloader.load_spandrel_model( + model = modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return net + model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 + return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c1..a71941375 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e8806..4d35b695c 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - mod = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( - mod, + model_descriptor, img, tile_size=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad4..174c9bc37 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: [PATCH 166/311] Deduplicate tiled inference code from SwinIR/ScuNET --- .../ScuNET/scripts/scunet_model.py | 55 +++----------- .../SwinIR/scripts/swinir_model.py | 57 ++------------- modules/upscaler_utils.py | 72 ++++++++++++++++++- 3 files changed, 87 insertions(+), 97 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b3..f799cb76d 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import sys import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 95c7ec648..8a555c794 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -4,11 +4,11 @@ import sys import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts, state +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -110,14 +110,14 @@ def upscale( w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference( + output = tiled_upscale_2( img, model, - tile=tile, + tile_size=tile, tile_overlap=tile_overlap, - window_size=window_size, scale=scale, device=device, + desc="SwinIR tiles", ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() @@ -129,53 +129,6 @@ def upscale( return Image.fromarray(output, "RGB") -def inference( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size: int, - scale: int, - device, -): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc37..8e4138543 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: [PATCH 167/311] Add utility to inspect a model's parameters (to get dtype/device) --- modules/devices.py | 3 ++- modules/interrogate.py | 3 ++- modules/sd_models_xl.py | 3 ++- modules/torch_utils.py | 17 +++++++++++++++++ modules/upscaler_utils.py | 5 +++-- modules/xlmr.py | 5 ++++- modules/xlmr_m18.py | 5 ++++- test/test_torch_utils.py | 19 +++++++++++++++++++ 8 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 modules/torch_utils.py create mode 100644 test/test_torch_utils.py diff --git a/modules/devices.py b/modules/devices.py index c956207f3..bd6bd579b 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,6 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared +from modules.torch_utils import get_param if sys.platform == "darwin": from modules import mac_specific @@ -131,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype + org_dtype = get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 3045560d0..5be5a10f3 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -11,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.torch_utils import get_param blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -131,7 +132,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = next(self.clip_model.parameters()).dtype + self.dtype = get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1de31b0da..c3602a7e1 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,6 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser +from modules.torch_utils import get_param def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = next(model.model.diffusion_model.parameters()).dtype + dtype = get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/torch_utils.py b/modules/torch_utils.py new file mode 100644 index 000000000..e5b52393e --- /dev/null +++ b/modules/torch_utils.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import torch.nn + + +def get_param(model) -> torch.nn.Parameter: + """ + Find the first parameter in a model or module. + """ + if hasattr(model, "model") and hasattr(model.model, "parameters"): + # Unpeel a model descriptor to get at the actual Torch module. + model = model.model + + for param in model.parameters(): + return param + + raise ValueError(f"No parameters found in model {model!r}") diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e4138543..c60e3beb8 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ import tqdm from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) diff --git a/modules/xlmr.py b/modules/xlmr.py index a407a3cad..6e000a56e 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index a727e8655..e3e819610 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional +from modules.torch_utils import get_param + + class BertSeriesConfig(BertConfig): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): @@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = next(self.parameters()).device + device = get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py new file mode 100644 index 000000000..f1aec8329 --- /dev/null +++ b/test/test_torch_utils.py @@ -0,0 +1,19 @@ +import types + +import pytest +import torch + +from modules.torch_utils import get_param + + +@pytest.mark.parametrize("wrapped", [True, False]) +def test_get_param(wrapped): + mod = torch.nn.Linear(1, 1) + cpu = torch.device("cpu") + mod.to(dtype=torch.float16, device=cpu) + if wrapped: + # more or less how spandrel wraps a thing + mod = types.SimpleNamespace(model=mod) + p = get_param(mod) + assert p.dtype == torch.float16 + assert p.device == cpu From d4945f4422e5a0bf31a6dbe4c1aeedd78c09eacb Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:22:30 +0100 Subject: [PATCH 168/311] Removed weight slider for negative prompts --- extensions-builtin/Lora/ui_edit_user_metadata.py | 7 +------ extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index f7859b21f..3160aecfa 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,14 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["negative text"] = negative_text - user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -130,7 +129,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), user_metadata.get('negative text', ''), - float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -167,7 +165,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -204,7 +201,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -219,7 +215,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, self.edit_notes, ] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9a6624e3f..e714fac46 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -46,13 +46,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): item["prompt"] += " + " + quote_js(" " + activation_text) negative_prompt = item["user_metadata"].get("negative text") - preferred_negative_weight = item["user_metadata"].get("negative weight") item["negative_prompt"] = quote_js("") if negative_prompt: - neg_prompt = negative_prompt - if (preferred_negative_weight > 0): - neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: From b6f74e936e4de3b8d190bffaf3bed67d6d4bd211 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:36:36 +0100 Subject: [PATCH 169/311] Revert change from linting for unrelated file --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 1d610dbff..39f78a0b7 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images +from modules import devices, images logger = logging.getLogger(__name__) From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: [PATCH 170/311] change import statements for #14478 --- modules/devices.py | 4 ++-- modules/interrogate.py | 5 ++--- modules/sd_models_xl.py | 4 ++-- modules/upscaler_utils.py | 5 ++--- modules/xlmr.py | 4 ++-- modules/xlmr_m18.py | 5 ++--- test/test_torch_utils.py | 4 ++-- 7 files changed, 14 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index bd6bd579b..ff279ac50 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -4,7 +4,7 @@ from functools import lru_cache import torch from modules import errors, shared -from modules.torch_utils import get_param +from modules import torch_utils if sys.platform == "darwin": from modules import mac_specific @@ -132,7 +132,7 @@ patch_module_list = [ def manual_cast_forward(self, *args, **kwargs): - org_dtype = get_param(self).dtype + org_dtype = torch_utils.get_param(self).dtype self.to(dtype) args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} diff --git a/modules/interrogate.py b/modules/interrogate.py index 5be5a10f3..35a627ca1 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -10,8 +10,7 @@ import torch.hub from torchvision import transforms from torchvision.transforms.functional import InterpolationMode -from modules import devices, paths, shared, lowvram, modelloader, errors -from modules.torch_utils import get_param +from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -132,7 +131,7 @@ class InterrogateModels: self.clip_model = self.clip_model.to(devices.device_interrogate) - self.dtype = get_param(self.clip_model).dtype + self.dtype = torch_utils.get_param(self.clip_model).dtype def send_clip_to_ram(self): if not shared.opts.interrogate_keep_models_in_memory: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index c3602a7e1..0de17af3d 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -6,7 +6,7 @@ import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer from modules import devices, shared, prompt_parser -from modules.torch_utils import get_param +from modules import torch_utils def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): @@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt def extend_sdxl(model): """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" - dtype = get_param(model.model.diffusion_model).dtype + dtype = torch_utils.get_param(model.model.diffusion_model).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' model.cond_stage_key = 'txt' diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb8..f5cb92d51 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56e..319771b7b 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index e3e819610..f60555049 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -4,8 +4,7 @@ import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional - -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py index f1aec8329..23ccb93a4 100644 --- a/test/test_torch_utils.py +++ b/test/test_torch_utils.py @@ -3,7 +3,7 @@ import types import pytest import torch -from modules.torch_utils import get_param +from modules import torch_utils @pytest.mark.parametrize("wrapped", [True, False]) @@ -14,6 +14,6 @@ def test_get_param(wrapped): if wrapped: # more or less how spandrel wraps a thing mod = types.SimpleNamespace(model=mod) - p = get_param(mod) + p = torch_utils.get_param(mod) assert p.dtype == torch.float16 assert p.device == cpu From 00901bfbe0095303554f4440b4c12fac262e2e89 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 1 Jan 2024 15:47:57 +0900 Subject: [PATCH 171/311] handle selectable script_index is None --- modules/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/scripts.py b/modules/scripts.py index 3a7669118..017aed5a0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -696,6 +696,8 @@ class ScriptRunner: self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): + if script_index is None: + script_index = 0 selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] @@ -739,7 +741,7 @@ class ScriptRunner: def run(self, p, *args): script_index = args[0] - if script_index == 0: + if script_index == 0 or script_index is None: return None script = self.selectable_scripts[script_index-1] From 5692bf1517c3409ad46262c56e65f256389825b1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 11:11:14 +0300 Subject: [PATCH 172/311] add missing field for DDIM sampler that was breaking img2img --- modules/sd_samplers_timesteps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index b17a8f93c..f8afa8bd7 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_default = 0.0 self.model_wrap_cfg = CFGDenoiserTimesteps(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) From 003b91f08361c99ecdd97257624d81a2046d3823 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:45:01 +0300 Subject: [PATCH 173/311] rename generation_parameters_copypaste module to infotext --- modules/{generation_parameters_copypaste.py => infotext.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename modules/{generation_parameters_copypaste.py => infotext.py} (100%) diff --git a/modules/generation_parameters_copypaste.py b/modules/infotext.py similarity index 100% rename from modules/generation_parameters_copypaste.py rename to modules/infotext.py From c5496c76461c90bd186ae8804aa65a33cd136d48 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:52:37 +0300 Subject: [PATCH 174/311] infotext.py: add support for old modules.generation_parameters_copypaste name --- modules/infotext.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/infotext.py b/modules/infotext.py index 86a36c327..bcbeb0fdc 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -4,12 +4,15 @@ import io import json import os import re +import sys import gradio as gr from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image +sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name + re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_imagesize = re.compile(r"^(\d+)x(\d+)$") From d859cec696a953dbfd6f69f7735e68661748d579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:53:12 +0300 Subject: [PATCH 175/311] infotext.py: rename usages in the codebase --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 14 files changed, 25 insertions(+), 25 deletions(-) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index ac2c3de46..8aa901fd4 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste +from modules import scripts, shared, ui_components, ui_settings, infotext from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 843c59b04..0e2807de2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ class Api: if not request.infotext: return {} - possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] + possible_fields = infotext.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + params = infotext.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ class Api: if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params) + overriden_settings = infotext.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ class Api: if geninfo is None: geninfo = "" - params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + params = infotext.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index c583290a0..75b3d346f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr import gradio as gr from modules import images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules.infotext import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 0c59fad48..f776f7b69 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext from modules.shared import opts @@ -86,7 +86,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index 7789f9a42..b30df60db 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -733,7 +733,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index cefad32b7..e9941413f 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index a3e16a12e..602932785 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index 991971ad0..e13924720 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import generation_parameters_copypaste, shared + from modules import infotext, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in generation_parameters_copypaste.paste_fields.values(): + for tab_data in infotext.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index e4e18ceb6..3a481915f 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ from contextlib import closing import modules.scripts from modules import processing -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.infotext import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 9db2407ed..6451e14c1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript from modules.shared import opts, cmd_opts -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text, PasteField +from modules.infotext import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index 032ec4af7..fd32676f9 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import gradio as gr import subprocess as sp from modules import call_queue, shared -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b8c022413..790af1356 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import json import html from fastapi.exceptions import HTTPException -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 36a807fcd..87aeb6f33 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import os.path import gradio as gr -from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks +from modules import infotext, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ class UserMetadataEditor: index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = generation_parameters_copypaste.image_from_url_text(img_info) + image = infotext.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 13d888e48..b74a15323 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste def create_ui(): From d613cd17c72c753bd1e314dff74dc22d9a949374 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:38:29 +0300 Subject: [PATCH 176/311] add automatic backwards version compatibility --- modules/infotext.py | 4 +++- modules/infotext_versions.py | 35 +++++++++++++++++++++++++++++++++++ modules/shared_options.py | 1 + 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 modules/infotext_versions.py diff --git a/modules/infotext.py b/modules/infotext.py index bcbeb0fdc..7f30446b9 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import sys import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name @@ -342,6 +342,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": res["Cache FP16 weight for LoRA"] = False + infotext_versions.backcompat(res) + skip = set(shared.opts.infotext_skip_pasting) res = {k: v for k, v in res.items() if k not in skip} diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py new file mode 100644 index 000000000..01e885a26 --- /dev/null +++ b/modules/infotext_versions.py @@ -0,0 +1,35 @@ +from modules import shared +from packaging import version +import re + + +v160 = version.parse("1.6.0") + + +def parse_version(text): + if text is None: + return None + + m = re.match(r'([^-]+-[^-]+)-.*', text) + if m: + text = m.group(1) + + try: + return version.parse(text) + except Exception as e: + return None + + +def backcompat(d): + """Checks infotext Version field, and enables backwards compatibility options according to it.""" + + if not shared.opts.auto_backcompat: + return + + ver = parse_version(d.get("Version")) + if ver is None: + return + + if ver < v160: + d["Old prompt editing timelines"] = True + diff --git a/modules/shared_options.py b/modules/shared_options.py index 752a4f125..281591da8 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -212,6 +212,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd" })) options_templates.update(options_section(('compatibility', "Compatibility", "sd"), { + "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), From 45b7bba3d06f2d4bc2fffc210cbfcb357b86add6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 14:51:56 +0300 Subject: [PATCH 177/311] add automatic version support for zero terminal SNR noise schedule option from #14145 --- modules/infotext_versions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 01e885a26..9a204d842 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -4,6 +4,7 @@ import re v160 = version.parse("1.6.0") +v170_tsnr = version.parse("v1.7.0-225") def parse_version(text): @@ -33,3 +34,6 @@ def backcompat(d): if ver < v160: d["Old prompt editing timelines"] = True + if ver < v170_tsnr: + d["Downcast alphas_cumprod"] = True + From d8126be578c7d4579c0f2ee4adbe35500bc71ce6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:00:39 +0300 Subject: [PATCH 178/311] linter --- modules/infotext.py | 2 +- modules/infotext_versions.py | 2 +- modules/postprocessing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/infotext.py b/modules/infotext.py index 7f30446b9..26e9b9493 100644 --- a/modules/infotext.py +++ b/modules/infotext.py @@ -8,7 +8,7 @@ import sys import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, errors +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions from PIL import Image sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name diff --git a/modules/infotext_versions.py b/modules/infotext_versions.py index 9a204d842..a5afeebf1 100644 --- a/modules/infotext_versions.py +++ b/modules/infotext_versions.py @@ -17,7 +17,7 @@ def parse_version(text): try: return version.parse(text) - except Exception as e: + except Exception: return None diff --git a/modules/postprocessing.py b/modules/postprocessing.py index f776f7b69..facea899f 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common from modules.shared import opts From 0743ee9b3eda8dd4ceea625d710031577201f4ad Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:50:47 +0300 Subject: [PATCH 179/311] re-layout checkboxes for XYZ grid a bit --- scripts/xyz_grid.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 34267c2c3..2d5509947 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -438,17 +438,16 @@ class Script(scripts.Script): with gr.Column(): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Row(): + vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.") + vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.") + vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.") with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Column(): - vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) - vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) - vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Column(): - csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -531,7 +530,7 @@ class Script(scripts.Script): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode): x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 if not no_fixed_seeds: From ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:28:58 +0300 Subject: [PATCH 180/311] option to convert VAE to bfloat16 (implementation of #9295) --- modules/processing.py | 23 ++++++++++++++++++----- modules/shared_options.py | 1 + 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 846e4796a..f06568821 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -628,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): sample = decode_first_stage(model, batch[i:i + 1])[0] if check_for_nans: + try: devices.test_for_nans(sample, "vae") except devices.NansException as e: - if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + if shared.opts.auto_vae_precision_bfloat16: + autofix_dtype = torch.bfloat16 + autofix_dtype_text = "bfloat16" + autofix_dtype_setting = "Automatically convert VAE to bfloat16" + autofix_dtype_comment = "" + elif shared.opts.auto_vae_precision: + autofix_dtype = torch.float32 + autofix_dtype_text = "32-bit float" + autofix_dtype_setting = "Automatically revert VAE to 32-bit floats" + autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag." + else: + raise e + + if devices.dtype_vae == autofix_dtype: raise e errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" - "Web UI will now convert VAE into 32-bit float and retry.\n" - "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" - "To always start with 32-bit VAE, use --no-half-vae commandline flag." + f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n" + f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}" ) - devices.dtype_vae = torch.float32 + devices.dtype_vae = autofix_dtype model.first_stage_model.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae) diff --git a/modules/shared_options.py b/modules/shared_options.py index ce06f022e..e813546f6 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -177,6 +177,7 @@ For img2img, VAE is used to process user's input image before the sampling, and "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), From 0aa7c53c0b9469849377aff83f43c9f75c19b3fa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:50:59 +0300 Subject: [PATCH 181/311] fix borked merge, rename fields to better match what they do, change setting default to true for #13653 --- modules/call_queue.py | 2 +- modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 2 +- modules/shared_state.py | 12 ++++++------ modules/ui_toprow.py | 8 +++++++- scripts/loopback.py | 4 ++-- scripts/xyz_grid.py | 2 +- 8 files changed, 20 insertions(+), 14 deletions(-) diff --git a/modules/call_queue.py b/modules/call_queue.py index 01c6d17f6..bcd7c5462 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,7 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): shared.state.skipped = False shared.state.interrupted = False - shared.state.interrupted_next = False + shared.state.stopping_generation = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 829faa818..e7e8e2510 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break try: diff --git a/modules/processing.py b/modules/processing.py index 00de2ed2e..f55b85ed3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 7852e0ea3..7581e276e 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -120,7 +120,6 @@ options_templates.update(options_section(('system', "System", "system"), { "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), - "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API", "system"), { @@ -286,6 +285,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives", "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), + "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"), })) options_templates.update(options_section(('ui', "User interface", "ui"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index 532fdcd8d..33996691c 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,7 +12,7 @@ log = logging.getLogger(__name__) class State: skipped = False interrupted = False - interrupted_next = False + stopping_generation = False job = "" job_no = 0 job_count = 0 @@ -80,9 +80,9 @@ class State: self.interrupted = True log.info("Received interrupt request") - def interrupt_next(self): - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") + def stop_generating(self): + self.stopping_generation = True + log.info("Received stop generating request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -96,7 +96,7 @@ class State: obj = { "skipped": self.skipped, "interrupted": self.interrupted, - "interrupted_next": self.interrupted_next, + "stopping_generation": self.stopping_generation, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -120,7 +120,7 @@ class State: self.id_live_preview = 0 self.skipped = False self.interrupted = False - self.interrupted_next = False + self.stopping_generation = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 9caf8faa2..1abc9117b 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -106,8 +106,14 @@ class Toprow: outputs=[], ) + def interrupt_function(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.stop_generating() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + fn=interrupt_function, inputs=[], outputs=[], ) diff --git a/scripts/loopback.py b/scripts/loopback.py index ad921269a..800ee882a 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ class Script(scripts.Script): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if initial_seed is None: @@ -122,7 +122,7 @@ class Script(scripts.Script): p.inpainting_fill = original_inpainting_fill - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if len(history) > 1: diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 2deff365f..2f385ebf2 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -696,7 +696,7 @@ class Script(scripts.Script): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted or state.interrupted_next: + if shared.state.interrupted or state.stopping_generation: return Processed(p, [], p.seed, "") pc = copy(p) From 1ffdedc11d49862cf0d030fb0bcc25eb0449939b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:03:08 +0300 Subject: [PATCH 182/311] restore lines lost from #13789 merge --- modules/cmd_args.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 3775e6702..e58059a1f 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -93,7 +93,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) -parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") @@ -115,8 +115,9 @@ parser.add_argument("--skip-version-check", action='store_true', help="Do not ch parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') -parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') +parser.add_argument('--add-stop-route', action='store_true', help='does not do anything') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) -parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) +parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) From 5d7d1823afab0a051a3fbbdb3213bae8051350b7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:25:30 +0300 Subject: [PATCH 183/311] rename infotext.py again, this time to infotext_utils.py; I didn't realize infotext would be used for variable names in multiple places, which makes it awkward to import the module; also fix the bug I caused by this rename that breaks tests --- .../scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/{infotext.py => infotext_utils.py} | 0 modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py | 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 15 files changed, 25 insertions(+), 25 deletions(-) rename modules/{infotext.py => infotext_utils.py} (100%) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 8aa901fd4..4c10d9c7d 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, infotext +from modules import scripts, shared, ui_components, ui_settings, infotext_utils from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 0e2807de2..9d1292e95 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ class Api: if not request.infotext: return {} - possible_fields = infotext.paste_fields[tabname]["fields"] + possible_fields = infotext_utils.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = infotext.parse_generation_parameters(request.infotext) + params = infotext_utils.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ class Api: if request.override_settings is None: request.override_settings = {} - overriden_settings = infotext.get_override_settings(params) + overriden_settings = infotext_utils.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ class Api: if geninfo is None: geninfo = "" - params = infotext.parse_generation_parameters(geninfo) + params = infotext_utils.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index e7e8e2510..04de8e62c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr import gradio as gr from modules import images as imgutil -from modules.infotext import create_override_settings_dict, parse_generation_parameters +from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/infotext.py b/modules/infotext_utils.py similarity index 100% rename from modules/infotext.py rename to modules/infotext_utils.py diff --git a/modules/postprocessing.py b/modules/postprocessing.py index facea899f..7850328f6 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils from modules.shared import opts @@ -86,7 +86,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index f55b85ed3..213a2879c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -746,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index e9941413f..ba33d8a4b 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 602932785..2d3cbb97f 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index e13924720..13fb2814f 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import infotext, shared + from modules import infotext_utils, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in infotext.paste_fields.values(): + for tab_data in infotext_utils.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index 3a481915f..49660e891 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ from contextlib import closing import modules.scripts from modules import processing -from modules.infotext import create_override_settings_dict +from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 378529c79..52b15646a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript from modules.shared import opts, cmd_opts -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.infotext import image_from_url_text, PasteField +from modules.infotext_utils import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index fd32676f9..f48ad4260 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import gradio as gr import subprocess as sp from modules import call_queue, shared -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 790af1356..beea13160 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import json import html from fastapi.exceptions import HTTPException -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 87aeb6f33..989a649b7 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import os.path import gradio as gr -from modules import infotext, images, sysinfo, errors, ui_extra_networks +from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ class UserMetadataEditor: index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = infotext.image_from_url_text(img_info) + image = infotext_utils.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index b74a15323..1edb68c5c 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste def create_ui(): From 501993ebf210bf3b55173ec1910f0c84c7e75424 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 19:31:06 +0300 Subject: [PATCH 184/311] added a button to run hires fix on selected image in the gallery --- javascript/ui.js | 8 +++ modules/processing.py | 46 ++++++++++++--- modules/txt2img.py | 19 +++++- modules/ui.py | 108 +++++++++++++++++++---------------- modules/ui_common.py | 57 +++++++++++------- modules/ui_postprocessing.py | 8 +-- 6 files changed, 160 insertions(+), 86 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 18c9f891a..78d1caee5 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -150,6 +150,14 @@ function submit() { return res; } +function submit_txt2img_upscale() { + res = submit.apply(null, arguments); + + res[2] = selected_gallery_index(); + + return res; +} + function submit_img2img() { showSubmitButtons('img2img', false); diff --git a/modules/processing.py b/modules/processing.py index 213a2879c..045c7d795 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -179,6 +179,7 @@ class StableDiffusionProcessing: token_merging_ratio = 0 token_merging_ratio_hr = 0 disable_extra_networks: bool = False + firstpass_image: Image = None scripts_value: scripts.ScriptRunner = field(default=None, init=False) script_args_value: list = field(default=None, init=False) @@ -1238,18 +1239,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) - del x + if self.firstpass_image is not None and self.enable_hr: + # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix - if not self.enable_hr: - return samples - devices.torch_gc() + if self.latent_scale_mode is None: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0 + image = np.moveaxis(image, 2, 0) + + samples = None + decoded_samples = torch.asarray(np.expand_dims(image, 0)) + + else: + image = np.array(self.firstpass_image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + image = torch.from_numpy(np.expand_dims(image, axis=0)) + image = image.to(shared.device, dtype=devices.dtype_vae) + + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + decoded_samples = None + devices.torch_gc() - if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: - decoded_samples = None + # here we generate an image normally + + x = self.rng.next() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x + + if not self.enable_hr: + return samples + + devices.torch_gc() + + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) diff --git a/modules/txt2img.py b/modules/txt2img.py index 49660e891..4a6fe72a6 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import processing +from modules import processing, infotext_utils from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared @@ -9,9 +9,23 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): + assert len(gallery) > 0, 'No image to upscale' + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + image = infotext_utils.image_from_url_text(image_info) + + return txt2img(id_task, request, *args, firstpass_image=image) + + +def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): override_settings = create_override_settings_dict(override_settings_texts) + if firstpass_image is not None: + enable_hr = True + batch_size = 1 + n_iter = 1 + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, + firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 52b15646a..3d548430d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -375,50 +375,60 @@ def create_ui(): show_progress=False, ) - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow) + + txt2img_inputs = [ + dummy_component, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, + steps, + sampler_name, + batch_count, + batch_size, + cfg_scale, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + hr_checkpoint_name, + hr_sampler_name, + hr_prompt, + hr_negative_prompt, + override_settings, + ] + custom_inputs + + txt2img_outputs = [ + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, + ] txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", - inputs=[ - dummy_component, - toprow.prompt, - toprow.negative_prompt, - toprow.ui_styles.dropdown, - steps, - sampler_name, - batch_count, - batch_size, - cfg_scale, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - hr_checkpoint_name, - hr_sampler_name, - hr_prompt, - hr_negative_prompt, - override_settings, - - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], + inputs=txt2img_inputs, + outputs=txt2img_outputs, show_progress=False, ) toprow.prompt.submit(**txt2img_args) toprow.submit.click(**txt2img_args) + output_panel.button_upscale.click( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), + _js="submit_txt2img_upscale", + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + outputs=txt2img_outputs, + show_progress=False, + ) + res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) toprow.restore_progress_button.click( @@ -426,10 +436,10 @@ def create_ui(): _js="restoreProgressTxt2img", inputs=[dummy_component], outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -479,7 +489,7 @@ def create_ui(): toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery) extra_tabs.__exit__() @@ -710,7 +720,7 @@ def create_ui(): outputs=[inpaint_controls, mask_alpha], ) - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) + output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow) img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -755,10 +765,10 @@ def create_ui(): img2img_batch_png_info_dir, ] + custom_inputs, outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -796,10 +806,10 @@ def create_ui(): _js="restoreProgressImg2img", inputs=[dummy_component], outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_info, + output_panel.html_log, ], show_progress=False, ) @@ -839,7 +849,7 @@ def create_ui(): )) extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery) extra_tabs.__exit__() diff --git a/modules/ui_common.py b/modules/ui_common.py index f48ad4260..ff84197c1 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import dataclasses import json import html import os @@ -104,7 +105,17 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") +@dataclasses.dataclass +class OutputPanel: + gallery = None + infotext = None + html_info = None + html_log = None + button_upscale = None + + def create_output_panel(tabname, outdir, toprow=None): + res = OutputPanel() def open_folder(f): if not os.path.exists(f): @@ -136,9 +147,8 @@ Requested path was: {f} with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) + res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None) - generation_info = None with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.") @@ -152,6 +162,9 @@ Requested path was: {f} 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") } + if tabname == 'txt2img': + res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.") + open_folder_button.click( fn=lambda: open_folder(shared.opts.outdir_samples or outdir), inputs=[], @@ -162,17 +175,17 @@ Requested path was: {f} download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[generation_info, html_info, html_info], - outputs=[html_info, html_info], + inputs=[res.infotext, res.html_info, res.html_info], + outputs=[res.html_info, res.html_info], show_progress=False, ) @@ -180,14 +193,14 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ], show_progress=False, ) @@ -196,21 +209,21 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - generation_info, - result_gallery, - html_info, - html_info, + res.infotext, + res.gallery, + res.html_info, + res.html_info, ], outputs=[ download_files, - html_log, + res.html_log, ] ) else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") - html_log = gr.HTML(elem_id=f'html_log_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] if tabname == "txt2img": @@ -220,11 +233,11 @@ Requested path was: {f} for paste_tabname, paste_button in buttons.items(): parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery, + paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery, paste_field_names=paste_field_names )) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + return res def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 1edb68c5c..8f09e6588 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -28,7 +28,7 @@ def create_ui(): toprow.create_inline_toprow_image() submit = toprow.submit - result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) + output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples) tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index]) tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index]) @@ -48,9 +48,9 @@ def create_ui(): *script_inputs ], outputs=[ - result_images, - html_info_x, - html_log, + output_panel.gallery, + output_panel.infotext, + output_panel.html_log, ], show_progress=False, ) From c32c51a0fc49ca4fbfe02bfbae45ec49e7bf9876 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 1 Jan 2024 19:20:54 +0200 Subject: [PATCH 185/311] Fix lint issue from 501993eb --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index 78d1caee5..3430b3fef 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -151,7 +151,7 @@ function submit() { } function submit_txt2img_upscale() { - res = submit.apply(null, arguments); + var res = submit(...arguments); res[2] = selected_gallery_index(); From c2ea571005dab29b285e31a0ad4a97258360bf2d Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:57:41 -0600 Subject: [PATCH 186/311] Add inpaint options to paste fields --- modules/ui.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 3d548430d..4cdb0e9c0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -840,6 +840,10 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (inpainting_mask_invert, 'Mask mode'), + (inpainting_fill, 'Masked content'), + (inpaint_full_res, 'Inpaint area'), + (inpaint_full_res_padding, 'Only masked padding, pixels'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From a5b6a5a3adcc845237d872750ded34240cc6a810 Mon Sep 17 00:00:00 2001 From: Jibaku789 <151478027+Jibaku789@users.noreply.github.com> Date: Mon, 1 Jan 2024 14:58:55 -0600 Subject: [PATCH 187/311] Add inpaint options to img2img.py --- modules/img2img.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modules/img2img.py b/modules/img2img.py index 04de8e62c..9e09c0a00 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -225,6 +225,18 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if mask: p.extra_generation_params["Mask blur"] = mask_blur + if inpainting_mask_invert is not None: + p.extra_generation_params["Mask mode"] = inpainting_mask_invert + + if inpainting_fill is not None: + p.extra_generation_params["Masked content"] = inpainting_fill + + if inpaint_full_res is not None: + p.extra_generation_params["Inpaint area"] = inpaint_full_res + + if inpaint_full_res_padding is not None: + p.extra_generation_params["Only masked padding, pixels"] = inpaint_full_res_padding + with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" From 1341b2208185cd89b0019bda2df63b406ec0cb5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 06:47:26 +0300 Subject: [PATCH 188/311] add an option to hide upscaling progressbar --- modules/shared_options.py | 1 + modules/upscaler_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index cca3f7be3..63488f4e7 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -115,6 +115,7 @@ options_templates.update(options_section(('system', "System", "system"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index f5cb92d51..9379f512b 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -47,7 +47,7 @@ def upscale_with_model( grid = images.split_grid(img, tile_size, tile_size, tile_overlap) newtiles = [] - with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p: for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: @@ -103,7 +103,7 @@ def tiled_upscale_2( ).type_as(img) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) - with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: for h_idx in h_idx_list: if shared.state.interrupted or shared.state.skipped: break From 80873b1538e6ca0c7ebe558f8ce4213b06fd8307 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:05:05 +0300 Subject: [PATCH 189/311] fix #14497 --- modules/img2img.py | 15 --------------- modules/infotext_utils.py | 12 ++++++++++++ modules/processing.py | 13 +++++++++++++ 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 9e09c0a00..f81405df5 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -222,21 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if shared.opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) - if mask: - p.extra_generation_params["Mask blur"] = mask_blur - - if inpainting_mask_invert is not None: - p.extra_generation_params["Mask mode"] = inpainting_mask_invert - - if inpainting_fill is not None: - p.extra_generation_params["Masked content"] = inpainting_fill - - if inpaint_full_res is not None: - p.extra_generation_params["Inpaint area"] = inpaint_full_res - - if inpaint_full_res_padding is not None: - p.extra_generation_params["Only masked padding, pixels"] = inpaint_full_res_padding - with closing(p): if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 26e9b9493..e582ee479 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -312,6 +312,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires negative prompt" not in res: res["Hires negative prompt"] = "" + if "Mask mode" not in res: + res["Mask mode"] = "Inpaint masked" + + if "Masked content" not in res: + res["Masked content"] = 'original' + + if "Inpaint area" not in res: + res["Inpaint area"] = "Whole picture" + + if "Masked area padding" not in res: + res["Masked area padding"] = 32 + restore_old_hires_fix_params(res) # Missing RNG means the default was set, which is GPU RNG diff --git a/modules/processing.py b/modules/processing.py index 045c7d795..84e7b1b45 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1530,6 +1530,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) + self.extra_generation_params["Mask mode"] = "Inpaint not masked" if self.mask_blur_x > 0: np_mask = np.array(image_mask) @@ -1543,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) image_mask = Image.fromarray(np_mask) + if self.mask_blur_x > 0 or self.mask_blur_y > 0: + self.extra_generation_params["Mask blur"] = self.mask_blur + if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') @@ -1553,6 +1557,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask = mask.crop(crop_region) image_mask = images.resize_image(2, mask, self.width, self.height) self.paste_to = (x1, y1, x2-x1, y2-y1) + + self.extra_generation_params["Inpaint area"] = "Only masked" + self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) np_mask = np.array(image_mask) @@ -1594,6 +1601,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_fill != 1: image = masking.fill(image, latent_mask) + if self.inpainting_fill == 0: + self.extra_generation_params["Masked content"] = 'fill' + if add_color_corrections: self.color_corrections.append(setup_color_correction(image)) @@ -1643,8 +1653,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.extra_generation_params["Masked content"] = 'latent noise' + elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + self.extra_generation_params["Masked content"] = 'latent nothing' self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round) From 980970d39091e572500434c69660bc6eed22498d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 07:08:32 +0300 Subject: [PATCH 190/311] final touches --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 4cdb0e9c0..7116d71c5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -843,7 +843,7 @@ def create_ui(): (inpainting_mask_invert, 'Mask mode'), (inpainting_fill, 'Masked content'), (inpaint_full_res, 'Inpaint area'), - (inpaint_full_res_padding, 'Only masked padding, pixels'), + (inpaint_full_res_padding, 'Masked area padding'), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: [PATCH 191/311] Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- .../ScuNET/scripts/scunet_model.py | 48 +++------- .../SwinIR/scripts/swinir_model.py | 62 ++----------- modules/upscaler_utils.py | 87 +++++++++++++++---- 3 files changed, 86 insertions(+), 111 deletions(-) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76d..fe5e5a192 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c794..bc427feac 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler): self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... + desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512b..e4c63f097 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ from modules import images, shared, torch_utils logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. + """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) From 2cacbc124c49f45da5b66b79d9b0a3ab943472eb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 19:52:32 +0200 Subject: [PATCH 192/311] load_spandrel_model: make `half` `prefer_half` As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not. --- modules/modelloader.py | 20 ++++++++++++++------ modules/realesrgan_model.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index a71941375..e100bb246 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,23 +139,31 @@ def load_upscalers(): def load_spandrel_model( - path: str, + path: str | os.PathLike, *, device: str | torch.device | None, - half: bool = False, + prefer_half: bool = False, dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel - model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) - if half: - model_descriptor.model.half() + half = False + if prefer_half: + if model_descriptor.supports_half: + model_descriptor.model.half() + half = True + else: + logger.info("Model %s does not support half precision, ignoring --half", path) if dtype: model_descriptor.model.to(dtype=dtype) model_descriptor.model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + logger.debug( + "Loaded %s from %s (device=%s, half=%s, dtype=%s)", + model_descriptor, path, device, half, dtype, + ) return model_descriptor diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 4d35b695c..ff9d8ac0d 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler): model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, - half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( From 62bd7624d2adb123e84d4625804e9ad94d1db018 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 10:40:26 +0200 Subject: [PATCH 193/311] Remove licenses for code that's no longer copy-pasted; adjust README --- README.md | 11 +- html/licenses.html | 310 +-------------------------------------------- 2 files changed, 7 insertions(+), 314 deletions(-) diff --git a/README.md b/README.md index 9f9f33b12..72908fa77 100644 --- a/README.md +++ b/README.md @@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers - k-diffusion - https://github.com/crowsonkb/k-diffusion.git -- GFPGAN - https://github.com/TencentARC/GFPGAN.git -- CodeFormer - https://github.com/sczhou/CodeFormer -- ESRGAN - https://github.com/xinntao/ESRGAN -- SwinIR - https://github.com/JingyunLiang/SwinIR -- Swin2SR - https://github.com/mv-lab/swin2sr +- Spandrel - https://github.com/chaiNNer-org/spandrel implementing + - GFPGAN - https://github.com/TencentARC/GFPGAN.git + - CodeFormer - https://github.com/sczhou/CodeFormer + - ESRGAN - https://github.com/xinntao/ESRGAN + - SwinIR - https://github.com/JingyunLiang/SwinIR + - Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - MiDaS - https://github.com/isl-org/MiDaS - Ideas for optimizations - https://github.com/basujindal/stable-diffusion diff --git a/html/licenses.html b/html/licenses.html index ef6f2c0a4..9f5d1e9dc 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -4,107 +4,6 @@ #licenses pre { margin: 1em 0 2em 0;} -

CodeFormer

-Parts of CodeFormer code had to be copied to be compatible with GFPGAN. -
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in
-   the documentation and/or other materials provided with the
-   distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived
-   from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
-
- - -

ESRGAN

-Code for architecture and reading models copied. -
-MIT License
-
-Copyright (c) 2021 victorca25
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
- -

Real-ESRGAN

-Some code is copied to support ESRGAN models. -
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
-   list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived from
-   this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-

InvokeAI

Some code for compatibility with OSX is taken from lstein's repository.
@@ -183,213 +82,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 
-

SwinIR

-Code added by contributors, most likely copied from this repository. - -
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright [2021] [SwinIR Authors]
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
-
-

Memory Efficient Attention

The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.
@@ -687,4 +379,4 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
-
\ No newline at end of file + From 7ad6899bf987a8ee615efbcfc99562457f89cd8b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 17:14:05 +0200 Subject: [PATCH 194/311] torch_bgr_to_pil_image: round, don't truncate This matches what `realesrgan` does. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f097..4f1417cf0 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -30,7 +30,7 @@ def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale - arr = arr.astype(np.uint8) + arr = arr.round().astype(np.uint8) arr = arr[:, :, ::-1] # flip BGR to RGB return Image.fromarray(arr, "RGB") From fccd0b00c2ca17360b7b956cd2e9bd1fb42c017d Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 3 Jan 2024 18:55:43 +0900 Subject: [PATCH 195/311] reduce unnecessary re-indexing extra networks dir --- modules/ui_extra_networks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index beea13160..e1c679eca 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -417,21 +417,21 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def create_html(): + ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] + def pages_html(): if not ui.pages_contents: - return refresh() - + create_html() return ui.pages_contents def refresh(): for pg in ui.stored_extra_pages: pg.refresh() - - ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] - + create_html() return ui.pages_contents - interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) + interface.load(fn=pages_html, inputs=[], outputs=ui.pages) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui From bfc48fbc244130770991fab284f6fedcef2054e7 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 03:46:05 +0900 Subject: [PATCH 196/311] paste infotext cast int as float --- modules/infotext_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479..a21329e62 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -477,6 +477,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, if valtype == bool and v == "False": val = False + elif valtype == int: + val = float(v) else: val = valtype(v) From dfdc51246c678b585e1bdfdb7d2f202b0ca0e362 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:13 +0200 Subject: [PATCH 197/311] SwinIR: use prefer_half --- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index bc427feac..6a8e21b02 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,6 +1,7 @@ import logging import sys +import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils @@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler): model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), - dtype=devices.dtype, + prefer_half=(devices.dtype == torch.float16), expected_architecture="SwinIR", ) if getattr(shared.opts, 'SWIN_torch_compile', False): From 3d31d5c27beb433fa37b30f135ec06a278a87630 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:49 +0200 Subject: [PATCH 198/311] SwinIR: pass model.scale --- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 6a8e21b02..16bf9b792 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -51,7 +51,7 @@ class UpscalerSwinIR(Upscaler): model, tile_size=shared.opts.SWIN_tile, tile_overlap=shared.opts.SWIN_tile_overlap, - scale=4, # TODO: This was hard-coded before too... + scale=model.scale, desc="SwinIR", ) devices.torch_gc() From 62470ee23443cb2ad3943a152ccae26a689c86e1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:39:12 +0200 Subject: [PATCH 199/311] upscale_2: cast image to model's dtype --- modules/upscaler_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f097..5db748776 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. """ - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) From 50158a1fc9b4dd47a7bef70d34fbb0b30d5e8b47 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 4 Jan 2024 06:21:53 +0900 Subject: [PATCH 200/311] handle config.json failed to load --- modules/launch_utils.py | 4 +++- modules/options.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce7..e2ad412a6 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -249,7 +249,9 @@ def list_extensions(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) except Exception: - errors.report("Could not load settings", exc_info=True) + errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(settings_file, os.path.join(script_path, "tmp", "config.json")) + settings = {} disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/options.py b/modules/options.py index 09ff9403d..503b40e98 100644 --- a/modules/options.py +++ b/modules/options.py @@ -1,3 +1,4 @@ +import os import json import sys from dataclasses import dataclass @@ -6,6 +7,7 @@ import gradio as gr from modules import errors from modules.shared_cmd_options import cmd_opts +from modules.paths_internal import script_path class OptionInfo: @@ -193,9 +195,13 @@ class Options: return type_x == type_y def load(self, filename): - with open(filename, "r", encoding="utf8") as file: - self.data = json.load(file) - + try: + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + except Exception: + errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True) + os.replace(filename, os.path.join(script_path, "tmp", "config.json")) + self.data = {} # 1.6.0 VAE defaults if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') From d9034b48a526f0a0c3e8f0dbf7c171bf4f0597fd Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 00:16:58 +0200 Subject: [PATCH 201/311] Avoid unnecessary `isfile`/`exists` calls --- modules/cache.py | 17 ++++++++--------- modules/extensions.py | 11 ++++++----- modules/extra_networks.py | 7 ++++--- modules/infotext_utils.py | 4 +++- modules/launch_utils.py | 7 ++++--- modules/postprocessing.py | 7 ++++--- modules/shared_init.py | 4 +++- modules/ui_gradio_extensions.py | 8 +++----- modules/ui_loadsave.py | 5 +++-- modules/util.py | 6 +++--- 10 files changed, 41 insertions(+), 35 deletions(-) diff --git a/modules/cache.py b/modules/cache.py index 2d37e7b99..a9822a0eb 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -62,16 +62,15 @@ def cache(subsection): if cache_data is None: with cache_lock: if cache_data is None: - if not os.path.isfile(cache_filename): + try: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + except FileNotFoundError: + cache_data = {} + except Exception: + os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) + print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') cache_data = {} - else: - try: - with open(cache_filename, "r", encoding="utf8") as file: - cache_data = json.load(file) - except Exception: - os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json")) - print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache') - cache_data = {} s = cache_data.get(subsection, {}) cache_data[subsection] = s diff --git a/modules/extensions.py b/modules/extensions.py index 1899cd529..99e7ee60f 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -32,11 +32,12 @@ class ExtensionMetadata: self.config = configparser.ConfigParser() filepath = os.path.join(path, self.filename) - if os.path.isfile(filepath): - try: - self.config.read(filepath) - except Exception: - errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) + # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is), + # so no need to check whether the file exists beforehand. + try: + self.config.read(filepath) + except Exception: + errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True) self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name) self.canonical_name = canonical_name.lower().strip() diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b95336778..cd030fa31 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -215,9 +215,10 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except FileNotFoundError: + pass except Exception as e: errors.display(e, f"reading extra network user metadata from {metadata_filename}") diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index e582ee479..6978a0bf0 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -453,9 +453,11 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: filename = os.path.join(data_path, "params.txt") - if os.path.exists(filename): + try: with open(filename, "r", encoding="utf8") as file: prompt = file.read() + except OSError: + pass params = parse_generation_parameters(prompt) script_callbacks.infotext_pasted_callback(prompt, params) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2cbd8ce7..febd8c24a 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -245,9 +245,10 @@ def list_extensions(settings_file): settings = {} try: - if os.path.isfile(settings_file): - with open(settings_file, "r", encoding="utf8") as file: - settings = json.load(file) + with open(settings_file, "r", encoding="utf8") as file: + settings = json.load(file) + except FileNotFoundError: + pass except Exception: errors.report("Could not load settings", exc_info=True) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7850328f6..7449b0dc5 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -97,11 +97,12 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, if pp.caption: caption_filename = os.path.splitext(fullfn)[0] + ".txt" - if os.path.isfile(caption_filename): + existing_caption = "" + try: with open(caption_filename, encoding="utf8") as file: existing_caption = file.read().strip() - else: - existing_caption = "" + except FileNotFoundError: + pass action = shared.opts.postprocessing_existing_caption_action if action == 'Prepend' and existing_caption: diff --git a/modules/shared_init.py b/modules/shared_init.py index d3fb687e0..586be3423 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -18,8 +18,10 @@ def initialize(): shared.options_templates = shared_options.options_templates shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) shared.restricted_opts = shared_options.restricted_opts - if os.path.exists(shared.config_filename): + try: shared.opts.load(shared.config_filename) + except FileNotFoundError: + pass from modules import devices devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index a86c368ef..f5278d22f 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -35,13 +35,11 @@ def css_html(): return f'' for cssfile in scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - head += stylesheet(cssfile) - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) + user_css = os.path.join(data_path, "user.css") + if os.path.exists(user_css): + head += stylesheet(user_css) return head diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 693ff75c5..2555cdb6c 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -26,8 +26,9 @@ class UiLoadsave: self.ui_defaults_review = None try: - if os.path.exists(self.filename): - self.ui_settings = self.read_from_file() + self.ui_settings = self.read_from_file() + except FileNotFoundError: + pass except Exception as e: self.error_loading = True errors.display(e, "loading settings") diff --git a/modules/util.py b/modules/util.py index 4861bcb08..d503f2672 100644 --- a/modules/util.py +++ b/modules/util.py @@ -21,11 +21,11 @@ def html_path(filename): def html(filename): path = html_path(filename) - if os.path.exists(path): + try: with open(path, encoding="utf8") as file: return file.read() - - return "" + except OSError: + return "" def walk_files(path, allowed_extensions=None): From 420f56c2e85ebbd3f530cf2c7b22022fda13ae13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:28:05 +0300 Subject: [PATCH 202/311] mass file lister as an attempt to tackle #14507 --- modules/extra_networks.py | 5 +-- modules/ui_extra_networks.py | 18 ++++++---- modules/util.py | 70 ++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index b95336778..04249dffd 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -206,7 +206,7 @@ def parse_prompts(prompts): return res, extra_data -def get_user_metadata(filename): +def get_user_metadata(filename, lister=None): if filename is None: return {} @@ -215,7 +215,8 @@ def get_user_metadata(filename): metadata = {} try: - if os.path.isfile(metadata_filename): + exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename) + if exists: with open(metadata_filename, "r", encoding="utf8") as file: metadata = json.load(file) except Exception as e: diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e1c679eca..c06c86649 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,7 +3,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util from modules.images import read_info_from_image, save_image_with_geninfo import gradio as gr import json @@ -107,6 +107,7 @@ class ExtraNetworksPage: self.allow_negative_prompt = False self.metadata = {} self.items = {} + self.lister = util.MassFileLister() def refresh(self): pass @@ -123,7 +124,7 @@ class ExtraNetworksPage: def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace('\\', '/')) - mtime = os.path.getmtime(filename) + mtime, _ = self.lister.mctime(filename) return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): @@ -137,6 +138,8 @@ class ExtraNetworksPage: return "" def create_html(self, tabname): + self.lister.reset() + items_html = '' self.metadata = {} @@ -282,10 +285,10 @@ class ExtraNetworksPage: List of default keys used for sorting in the UI. """ pth = Path(path) - stat = pth.stat() + mtime, ctime = self.lister.mctime(path) return { - "date_created": int(stat.st_ctime or 0), - "date_modified": int(stat.st_mtime or 0), + "date_created": int(mtime), + "date_modified": int(ctime), "name": pth.name.lower(), "path": str(pth.parent).lower(), } @@ -298,7 +301,7 @@ class ExtraNetworksPage: potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) for file in potential_files: - if os.path.isfile(file): + if self.lister.exists(file): return self.link_preview(file) return None @@ -308,6 +311,9 @@ class ExtraNetworksPage: Find and read a description file for a given path (without extension). """ for file in [f"{path}.txt", f"{path}.description.txt"]: + if not self.lister.exists(file): + continue + try: with open(file, "r", encoding="utf-8", errors="replace") as f: return f.read() diff --git a/modules/util.py b/modules/util.py index 4861bcb08..c2a27590d 100644 --- a/modules/util.py +++ b/modules/util.py @@ -66,3 +66,73 @@ def truncate_path(target_path, base_path=cwd): except ValueError: pass return abs_target + + +class MassFileListerCachedDir: + """A class that caches file metadata for a specific directory.""" + + def __init__(self, dirname): + self.files = None + self.files_cased = None + self.dirname = dirname + + stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname)) + files = [(n, s.st_mtime, s.st_ctime) for n, s in stats] + self.files = {x[0].lower(): x for x in files} + self.files_cased = {x[0]: x for x in files} + + +class MassFileLister: + """A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.""" + + def __init__(self): + self.cached_dirs = {} + + def find(self, path): + """ + Find the metadata for a file at the given path. + + Returns: + tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not. + """ + + dirname, filename = os.path.split(path) + + cached_dir = self.cached_dirs.get(dirname) + if cached_dir is None: + cached_dir = MassFileListerCachedDir(dirname) + self.cached_dirs[dirname] = cached_dir + + stats = cached_dir.files_cased.get(filename) + if stats is not None: + return stats + + stats = cached_dir.files.get(filename.lower()) + if stats is None: + return None + + try: + os_stats = os.stat(path, follow_symlinks=False) + return filename, os_stats.st_mtime, os_stats.st_ctime + except Exception: + return None + + def exists(self, path): + """Check if a file exists at the given path.""" + + return self.find(path) is not None + + def mctime(self, path): + """ + Get the modification and creation times for a file at the given path. + + Returns: + tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not. + """ + + stats = self.find(path) + return (0, 0) if stats is None else stats[1:3] + + def reset(self): + """Clear the cache of all directories.""" + self.cached_dirs.clear() From 320a217b78047f30e1aa5e735742669a7f4c6bd8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 02:39:02 +0300 Subject: [PATCH 203/311] forgot something --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c06c86649..62db36f53 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -114,7 +114,7 @@ class ExtraNetworksPage: def read_user_metadata(self, item): filename = item.get("filename", None) - metadata = extra_networks.get_user_metadata(filename) + metadata = extra_networks.get_user_metadata(filename, lister=self.lister) desc = metadata.get("description", None) if desc is not None: From 15ec54dd969d6dc3fea7790ca5cce5badcfda426 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 4 Jan 2024 19:47:00 +0300 Subject: [PATCH 204/311] Have upscale button use the same seed as hires fix. --- modules/scripts.py | 21 +++++++++++++++++++++ modules/txt2img.py | 14 +++++++++++++- modules/ui.py | 10 +++++----- modules/ui_common.py | 26 +++++++++++++------------- modules/ui_postprocessing.py | 2 +- 5 files changed, 53 insertions(+), 20 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 017aed5a0..cf938ebb9 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -91,6 +91,9 @@ class Script: setup_for_ui_only = False """If true, the script setup will only be run in Gradio UI, not in API""" + controls = None + """A list of controls retured by the ui().""" + def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -624,6 +627,7 @@ class ScriptRunner: import modules.api.models as api_models controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + script.controls = controls if controls is None: return @@ -918,6 +922,23 @@ class ScriptRunner: except Exception: errors.report(f"Error running setup: {script.filename}", exc_info=True) + def set_named_arg(self, args, script_type, arg_elem_id, value): + script = next((x for x in self.scripts if type(x).__name__ == script_type), None) + if script is None: + return + + for i, control in enumerate(script.controls): + if arg_elem_id in control.elem_id: + index = script.args_from + i + + if isinstance(args, list): + args[index] = value + return args + elif isinstance(args, tuple): + return args[:index] + (value,) + args[index+1:] + else: + return None + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 4a6fe72a6..41bb9da33 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,3 +1,4 @@ +import json from contextlib import closing import modules.scripts @@ -9,12 +10,19 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args): +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] image = infotext_utils.image_from_url_text(image_info) + gallery_index_from_end = len(gallery) - gallery_index + image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + return txt2img(id_task, request, *args, firstpass_image=image) @@ -22,6 +30,10 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str override_settings = create_override_settings_dict(override_settings_texts) if firstpass_image is not None: + seed = getattr(firstpass_image, 'seed', None) + if seed: + args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) + enable_hr = True batch_size = 1 n_iter = 1 diff --git a/modules/ui.py b/modules/ui.py index 7116d71c5..2d2e333b2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -405,8 +405,8 @@ def create_ui(): txt2img_outputs = [ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ] @@ -424,7 +424,7 @@ def create_ui(): output_panel.button_upscale.click( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']), _js="submit_txt2img_upscale", - inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:], + inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:], outputs=txt2img_outputs, show_progress=False, ) @@ -437,8 +437,8 @@ def create_ui(): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -766,8 +766,8 @@ def create_ui(): ] + custom_inputs, outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, @@ -807,8 +807,8 @@ def create_ui(): inputs=[dummy_component], outputs=[ output_panel.gallery, + output_panel.generation_info, output_panel.infotext, - output_panel.html_info, output_panel.html_log, ], show_progress=False, diff --git a/modules/ui_common.py b/modules/ui_common.py index ff84197c1..f17259c29 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -108,8 +108,8 @@ def save_files(js_data, images, do_make_zip, index): @dataclasses.dataclass class OutputPanel: gallery = None + generation_info = None infotext = None - html_info = None html_log = None button_upscale = None @@ -175,17 +175,17 @@ Requested path was: {f} download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log") - res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", - inputs=[res.infotext, res.html_info, res.html_info], - outputs=[res.html_info, res.html_info], + inputs=[res.generation_info, res.infotext, res.infotext], + outputs=[res.infotext, res.infotext], show_progress=False, ) @@ -193,10 +193,10 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -209,10 +209,10 @@ Requested path was: {f} fn=call_queue.wrap_gradio_call(save_files), _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", inputs=[ - res.infotext, + res.generation_info, res.gallery, - res.html_info, - res.html_info, + res.infotext, + res.infotext, ], outputs=[ download_files, @@ -221,8 +221,8 @@ Requested path was: {f} ) else: - res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}') - res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") + res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}') + res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext") res.html_log = gr.HTML(elem_id=f'html_log_{tabname}') paste_field_names = [] diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 8f09e6588..7a132ac22 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -49,7 +49,7 @@ def create_ui(): ], outputs=[ output_panel.gallery, - output_panel.infotext, + output_panel.generation_info, output_panel.html_log, ], show_progress=False, From 9805f35c6f3ef0b0fc4e3648aa3d6eddf0a907af Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 19:13:36 +0200 Subject: [PATCH 205/311] Ensure GRADIO_ANALYTICS_ENABLED is set early enough --- modules/initialize.py | 2 ++ modules/launch_utils.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/initialize.py b/modules/initialize.py index 4a3cd98cf..7c1ac99ef 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -1,5 +1,6 @@ import importlib import logging +import os import sys import warnings from threading import Thread @@ -18,6 +19,7 @@ def imports(): warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') import gradio # noqa: F401 startup_timer.record("import gradio") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 7ebdf0b40..c2a7ae932 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -27,8 +27,7 @@ dir_repos = "repositories" # Whether to default to printing command output default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1") -if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: - os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' +os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False') def check_python_version(): From 6fa42e919f27d7316a2c7b61674fb3eb17f3a1bb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 4 Jan 2024 16:13:08 +0200 Subject: [PATCH 206/311] Fix logging configuration again * Only use `tqdm.write()` if `tqdm` is active, defer to stderr * Correct log formatter for TqdmLoggingHandler * If `rich` is installed and `SD_WEBUI_RICH_LOG` is set, use `rich`'s formatter --- modules/logging_config.py | 62 ++++++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 792698756..11eee9a63 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -1,41 +1,57 @@ -import os import logging +import os try: - from tqdm.auto import tqdm + from tqdm import tqdm + class TqdmLoggingHandler(logging.Handler): - def __init__(self, level=logging.INFO): - super().__init__(level) + def __init__(self, fallback_handler: logging.Handler): + super().__init__() + self.fallback_handler = fallback_handler def emit(self, record): try: - msg = self.format(record) - tqdm.write(msg) - self.flush() + # If there are active tqdm progress bars, + # attempt to not interfere with them. + if tqdm._instances: + tqdm.write(self.format(record)) + else: + self.fallback_handler.emit(record) except Exception: - self.handleError(record) + self.fallback_handler.emit(record) - TQDM_IMPORTED = True except ImportError: - # tqdm does not exist before first launch - # I will import once the UI finishes seting up the enviroment and reloads. - TQDM_IMPORTED = False + TqdmLoggingHandler = None + def setup_logging(loglevel): if loglevel is None: loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") - loghandlers = [] + if not loglevel: + return - if TQDM_IMPORTED: - loghandlers.append(TqdmLoggingHandler()) + if logging.root.handlers: + # Already configured, do not interfere + return - if loglevel: - log_level = getattr(logging, loglevel.upper(), None) or logging.INFO - logging.basicConfig( - level=log_level, - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - handlers=loghandlers - ) + if os.environ.get("SD_WEBUI_RICH_LOG"): + from rich.logging import RichHandler + handler = RichHandler() + else: + handler = logging.StreamHandler() + + if TqdmLoggingHandler: + handler = TqdmLoggingHandler(handler) + + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) + + handler.setFormatter(formatter) + + log_level = getattr(logging, loglevel.upper(), None) or logging.INFO + logging.root.setLevel(log_level) + logging.root.addHandler(handler) From f8f38c7c28e48f9f79225c969e3e82b1adcfb910 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:31:48 +0800 Subject: [PATCH 207/311] Fix dtype casting for OFT module --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020f..342fcd0dc 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) From 18ca987c92f52690daec43a6c67363c341bb6008 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:32:19 +0800 Subject: [PATCH 208/311] Add general forward method for all modules. --- extensions-builtin/Lora/network.py | 34 ++++++++++++++++++++++++++++- extensions-builtin/Lora/networks.py | 12 +++++----- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a62e5eff9..f9b571b5f 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,10 @@ import os from collections import namedtuple import enum +import torch +import torch.nn as nn +import torch.nn.functional as F + from modules import sd_models, cache, errors, hashes, shared NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) @@ -115,6 +119,29 @@ class NetworkModule: if hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + self.ops = None + self.extra_kwargs = {} + if isinstance(self.sd_module, nn.Conv2d): + self.ops = F.conv2d + self.extra_kwargs = { + 'stride': self.sd_module.stride, + 'padding': self.sd_module.padding + } + elif isinstance(self.sd_module, nn.Linear): + self.ops = F.linear + elif isinstance(self.sd_module, nn.LayerNorm): + self.ops = F.layer_norm + self.extra_kwargs = { + 'normalized_shape': self.sd_module.normalized_shape, + 'eps': self.sd_module.eps + } + elif isinstance(self.sd_module, nn.GroupNorm): + self.ops = F.group_norm + self.extra_kwargs = { + 'num_groups': self.sd_module.num_groups, + 'eps': self.sd_module.eps + } + self.dim = None self.bias = weights.w.get("bias") self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None @@ -155,5 +182,10 @@ class NetworkModule: raise NotImplementedError() def forward(self, x, y): - raise NotImplementedError() + """A general forward implementation for all modules""" + if self.ops is None: + raise NotImplementedError() + else: + updown, ex_bias = self.calc_updown(self.sd_module.weight) + return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 72ebd6241..32e10b625 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_current_names = wanted_names -def network_forward(module, input, original_forward): +def network_forward(org_module, input, original_forward): """ Old way of applying Lora by executing operations during layer's forward. Stacking many loras this way results in big performance degradation. """ if len(loaded_networks) == 0: - return original_forward(module, input) + return original_forward(org_module, input) input = devices.cond_cast_unet(input) - network_restore_weights_from_backup(module) - network_reset_cached_weight(module) + network_restore_weights_from_backup(org_module) + network_reset_cached_weight(org_module) - y = original_forward(module, input) + y = original_forward(org_module, input) - network_layer_name = getattr(module, 'network_layer_name', None) + network_layer_name = getattr(org_module, 'network_layer_name', None) for lora in loaded_networks: module = lora.modules.get(network_layer_name, None) if module is None: From 44744d6005da5e424267698ee3279caa597dfebc Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:38:05 +0800 Subject: [PATCH 209/311] linting --- extensions-builtin/Lora/network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index f9b571b5f..b8fd91941 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,7 +3,6 @@ import os from collections import namedtuple import enum -import torch import torch.nn as nn import torch.nn.functional as F @@ -124,7 +123,7 @@ class NetworkModule: if isinstance(self.sd_module, nn.Conv2d): self.ops = F.conv2d self.extra_kwargs = { - 'stride': self.sd_module.stride, + 'stride': self.sd_module.stride, 'padding': self.sd_module.padding } elif isinstance(self.sd_module, nn.Linear): From 88ba095fd022fbc1acbbbc845e23493a5b2c5d8b Mon Sep 17 00:00:00 2001 From: Keshav Nischal <91429557+keshav-nischal@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:15:58 +0530 Subject: [PATCH 210/311] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9f9f33b12..7f1ea4575 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Stable Diffusion web UI -A browser interface based on Gradio library for Stable Diffusion. +A web interface for Stable Diffusion, implemented using Gradio library. ![](screenshot.png) From 233c66b36eba07b905f9743c2ad807aec33d9ccb Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 5 Jan 2024 12:28:32 +0300 Subject: [PATCH 211/311] Make the upscale button update the gallery with the new image rather than replace it. --- modules/processing.py | 6 ++- modules/txt2img.py | 85 +++++++++++++++++++++++++++++-------------- 2 files changed, 63 insertions(+), 28 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 84e7b1b45..dcc807fe3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,7 +732,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), - "Denoising strength": getattr(p, 'denoising_strength', None), + "Denoising strength": p.extra_generation_params.get("Denoising strength"), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": opts.eta_noise_seed_delta if uses_ensd else None, @@ -1198,6 +1198,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + self.extra_generation_params["Denoising strength"] = self.denoising_strength + if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint': self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) @@ -1516,6 +1518,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.mask_blur_y = value def init(self, all_prompts, all_seeds, all_subseeds): + self.extra_generation_params["Denoising strength"] = self.denoising_strength + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) diff --git a/modules/txt2img.py b/modules/txt2img.py index 41bb9da33..c4cc12d2f 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,36 +7,15 @@ from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html +from PIL import Image import gradio as gr -def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): - assert len(gallery) > 0, 'No image to upscale' - assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' - - geninfo = json.loads(generation_info) - all_seeds = geninfo["all_seeds"] - - image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] - image = infotext_utils.image_from_url_text(image_info) - - gallery_index_from_end = len(gallery) - gallery_index - image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - - return txt2img(id_task, request, *args, firstpass_image=image) - - -def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None): +def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False): override_settings = create_override_settings_dict(override_settings_texts) - if firstpass_image is not None: - seed = getattr(firstpass_image, 'seed', None) - if seed: - args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed) - + if force_enable_hr: enable_hr = True - batch_size = 1 - n_iter = 1 p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -53,7 +32,7 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str width=width, height=height, enable_hr=enable_hr, - denoising_strength=denoising_strength if enable_hr else None, + denoising_strength=denoising_strength, hr_scale=hr_scale, hr_upscaler=hr_upscaler, hr_second_pass_steps=hr_second_pass_steps, @@ -64,7 +43,6 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, - firstpass_image=firstpass_image, ) p.scripts = modules.scripts.scripts_txt2img @@ -75,8 +53,61 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str if shared.opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + return p + + +def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args): + assert len(gallery) > 0, 'No image to upscale' + assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}' + + p = txt2img_create_processing(id_task, request, *args) + p.enable_hr = True + p.batch_size = 1 + p.n_iter = 1 + + geninfo = json.loads(generation_info) + all_seeds = geninfo["all_seeds"] + + image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] + p.firstpass_image = infotext_utils.image_from_url_text(image_info) + + gallery_index_from_end = len(gallery) - gallery_index + seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + with closing(p): - processed = modules.scripts.scripts_txt2img.run(p, *args) + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) + + if processed is None: + processed = processing.process_images(p) + + shared.total_tqdm.clear() + + new_gallery = [] + for i, image in enumerate(gallery): + fake_image = Image.new(mode="RGB", size=(1, 1)) + + if i == gallery_index: + already_saved_as = getattr(processed.images[0], 'already_saved_as', None) + if already_saved_as is not None: + fake_image.already_saved_as = already_saved_as + else: + fake_image = processed.images[0] + else: + fake_image.already_saved_as = image["name"] + + new_gallery.append(fake_image) + + geninfo["infotexts"][gallery_index] = processed.info + + return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments") + + +def txt2img(id_task: str, request: gr.Request, *args): + p = txt2img_create_processing(id_task, request, *args) + + with closing(p): + processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) if processed is None: processed = processing.process_images(p) From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 16:32:18 +0800 Subject: [PATCH 212/311] [IPEX] Fix SDPA attn_mask dtype --- modules/xpu_specific.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66c..4e11125b2 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length From ec9acb31450536a9258192351d6a26421efd7eb4 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 17:17:53 +0800 Subject: [PATCH 213/311] Handle CondFunc exception when resolving attributes --- modules/sd_hijack_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f8684475e..79bf6e468 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -11,10 +11,14 @@ class CondFunc: break except ImportError: pass - for attr_name in func_path[i:-1]: - resolved_obj = getattr(resolved_obj, attr_name) - orig_func = getattr(resolved_obj, func_path[-1]) - setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + try: + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + except AttributeError: + print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") + pass self.__init__(orig_func, sub_func, cond_func) return lambda *args, **kwargs: self(*args, **kwargs) def __init__(self, orig_func, sub_func, cond_func): From 73786c047f14d6ae658b2c12f493f05486ba1789 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:09:56 +0800 Subject: [PATCH 214/311] [IPEX] Fix torch.Generator hijack --- modules/xpu_specific.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b2..1137891a6 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', From 818d6a11e709bf07d48606bdccab944c46a5f4b0 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 19:14:06 +0800 Subject: [PATCH 215/311] Fix format --- modules/xpu_specific.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 1137891a6..2971dbc3c 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -106,8 +106,8 @@ if has_xpu: try: # torch.Generator supports "xpu" device since 2.1 torch.Generator("xpu") - except: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + except RuntimeError: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1) CondFunc('torch.Generator', lambda orig_func, device=None: torch.xpu.Generator(device), lambda orig_func, device=None: is_xpu_device(device)) From a183de04e3f965083e7f3462201327d30c36b958 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 20:03:33 +0800 Subject: [PATCH 216/311] Execute model_loaded_callback after moving to target device --- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 50bc209e4..2c0457715 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): sd_hijack.model_hijack.hijack(sd_model) timer.record("hijack") - script_callbacks.model_loaded_callback(sd_model) - timer.record("script callbacks") - if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 31306d8ba..43687e48d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) if not sd_model.lowvram: sd_model.to(devices.device) + script_callbacks.model_loaded_callback(sd_model) + print("VAE weights loaded.") return sd_model From 2f98a35fc4508494355c01ec45f5bec725f570a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 7 Jan 2024 09:21:21 +0300 Subject: [PATCH 217/311] add assets repo; serve fonts locally rather than from google's servers --- modules/launch_utils.py | 3 +++ modules/sysinfo.py | 2 ++ modules/ui.py | 4 +++- style.css | 2 +- 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c2a7ae932..8e58d7145 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -344,11 +344,13 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") + assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') + assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") @@ -405,6 +407,7 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) + git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) diff --git a/modules/sysinfo.py b/modules/sysinfo.py index 5abf616b7..f336251e4 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -24,9 +24,11 @@ environment_whitelist = { "XFORMERS_PACKAGE", "CLIP_PACKAGE", "OPENCLIP_PACKAGE", + "ASSETS_REPO", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", "BLIP_REPO", + "ASSETS_COMMIT_HASH", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", "BLIP_COMMIT_HASH", diff --git a/modules/ui.py b/modules/ui.py index 2d2e333b2..a716a0405 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -1223,3 +1223,5 @@ def setup_ui_api(app): app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"]) + import fastapi.staticfiles + app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets") diff --git a/style.css b/style.css index 6d4c8a0d5..4957c523d 100644 --- a/style.css +++ b/style.css @@ -1,6 +1,6 @@ /* temporary fix to load default gradio font in frontend instead of backend */ -@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +@import url('webui-assets/css/sourcesanspro.css'); /* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */ From 425507bd10c55f1f804eb5015db74520668f46f9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 10:25:01 -0600 Subject: [PATCH 218/311] add p to cfgdenoiserparams --- modules/script_callbacks.py | 5 ++++- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 9ed7ad21d..bb47c18d3 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ class ExtraNoiseParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +63,9 @@ class CFGDenoiserParams: self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" + self.p = p + """StableDiffusionProcessing object with processing parameters""" + class CFGDenoisedParams: def __init__(self, x, sampling_step, total_sampling_steps, inner_model): diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index eb9d5dafa..f4ded6bdb 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From f56cebf5ba24313447b2204c3f804379767201c9 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 7 Jan 2024 12:35:35 -0600 Subject: [PATCH 219/311] add self instead --- modules/script_callbacks.py | 6 +++--- modules/sd_samplers_cfg_denoiser.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index bb47c18d3..053dfc963 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ class ExtraNoiseParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, p): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): self.x = x """Latent image representation in the process of being denoised""" @@ -63,8 +63,8 @@ class CFGDenoiserParams: self.text_uncond = text_uncond """ Encoder hidden states of text conditioning from negative prompt""" - self.p = p - """StableDiffusionProcessing object with processing parameters""" + self.denoiser = denoiser + """Current CFGDenoiser object with processing parameters""" class CFGDenoisedParams: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index f4ded6bdb..6d76aa965 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -146,7 +146,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self.p) + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self) cfg_denoiser_callback(denoiser_params) x_in = denoiser_params.x image_cond_in = denoiser_params.image_cond From 37906e429ae5a92f9ea96ab6dc2157b1d7c4d8b6 Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Sun, 7 Jan 2024 20:17:42 -0600 Subject: [PATCH 220/311] make denoiser None by default --- modules/script_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 053dfc963..a54cb3ebb 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -41,7 +41,7 @@ class ExtraNoiseParams: class CFGDenoiserParams: - def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser): + def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None): self.x = x """Latent image representation in the process of being denoised""" From 8e292373ec5c54493ce48af7c76f5eaa79dc8abd Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Mon, 8 Jan 2024 06:43:39 -0600 Subject: [PATCH 221/311] lcm sampler --- modules/sd_samplers.py | 3 +- modules/sd_samplers_lcm.py | 104 +++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 modules/sd_samplers_lcm.py diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 45faae628..a58528a0b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_timesteps.samplers_data_timesteps, + *sd_samplers_lcm.samplers_data_lcm, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_lcm.py b/modules/sd_samplers_lcm.py new file mode 100644 index 000000000..59839b720 --- /dev/null +++ b/modules/sd_samplers_lcm.py @@ -0,0 +1,104 @@ +import torch + +from k_diffusion import utils, sampling +from k_diffusion.external import DiscreteEpsDDPMDenoiser +from k_diffusion.sampling import default_noise_sampler, trange + +from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common + + +class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): + def __init__(self, model): + timesteps = 1000 + original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM) + self.skip_steps = timesteps // original_timesteps + + alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device) + for x in range(original_timesteps): + alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps] + + super().__init__(model, alphas_cumprod_valid, quantize=None) + + + def get_sigmas(self, n=None,): + if n is None: + return sampling.append_zero(self.sigmas.flip(0)) + + start = self.sigma_to_t(self.sigma_max) + end = self.sigma_to_t(self.sigma_min) + + t = torch.linspace(start, end, n, device=shared.sd_model.device) + + return sampling.append_zero(self.t_to_sigma(t)) + + + def sigma_to_t(self, sigma, quantize=None): + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + + + def t_to_sigma(self, timestep): + t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + return super().t_to_sigma(t) + + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + + def get_scaled_out(self, sigma, output, input): + sigma_data = 0.5 + scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0 + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * output + c_skip * input + + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return self.get_scaled_out(sigma, input + eps * c_out, input) + + +def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) + return x + + +class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = LCMCompVisDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + + +class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler): + def __init__(self, funcname, sd_model, options=None): + super().__init__(funcname, sd_model, options) + self.model_wrap_cfg = CFGDenoiserLCM(self) + self.model_wrap = self.model_wrap_cfg.inner_model + + +samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})] +samplers_data_lcm = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_lcm +] From df8aa69a99e38ae59a4e599b9dff11eccf3490f4 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:10:03 -0500 Subject: [PATCH 222/311] Add tree-view display for extra networks. --- html/extra-networks-card-minimal.html | 3 + html/extra-networks-card.html | 1 + javascript/extraNetworks.js | 5 + modules/shared_options.py | 1 + modules/ui_extra_networks.py | 350 ++++++++++++++++++-------- style.css | 92 +++++++ 6 files changed, 354 insertions(+), 98 deletions(-) create mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html new file mode 100644 index 000000000..a6a54d9f4 --- /dev/null +++ b/html/extra-networks-card-minimal.html @@ -0,0 +1,3 @@ +
+ {name}{copy_path_button}{metadata_button}{edit_button} +
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 39674666f..d76011d73 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,6 +1,7 @@
{background_image}
+ {copy_path_button} {metadata_button} {edit_button}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb74..40309d557 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -337,6 +337,11 @@ function requestGet(url, data, handler, errorHandler) { xhr.send(js); } +function extraNetworksCopyCardPath(event, path) { + navigator.clipboard.writeText(path); + event.stopPropagation(); +} + function extraNetworksRequestMetadata(event, extraPage, cardName) { var showError = function() { extraNetworksShowMetadata("there was an error getting metadata"); diff --git a/modules/shared_options.py b/modules/shared_options.py index d2e86ff10..e698c2649 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,6 +238,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), + "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba33..8667617b9 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,6 +2,8 @@ import functools import os.path import urllib.parse from pathlib import Path +from typing import Optional, Union +from dataclasses import dataclass from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo @@ -15,10 +17,8 @@ from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() - default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -28,6 +28,58 @@ def allowed_preview_extensions(): return allowed_preview_extensions_with_extra((shared.opts.samples_format, )) +@dataclass +class ExtraNetworksItem: + """Wrapper for dictionaries representing ExtraNetworks items.""" + item: dict + + +def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict: + """Recursively builds a directory tree. + + Args: + paths: Path or list of paths to directories. These paths are treated as roots from which + the tree will be built. + items: A dictionary associating filepaths to an ExtraNetworksItem instance. + + Returns: + The result directory tree. + """ + if isinstance(paths, (str,)): + paths = [paths] + + def _get_tree(_paths: list[str]): + _res = {} + for path in _paths: + if os.path.isdir(path): + dir_items = os.listdir(path) + # Ignore empty directories. + if not dir_items: + continue + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + # We only want to store non-empty folders in the tree. + if dir_tree: + _res[os.path.basename(path)] = dir_tree + else: + if path not in items: + continue + # Add the ExtraNetworksItem to the result. + _res[os.path.basename(path)] = items[path] + return _res + + res = {} + # Handle each root directory separately. + # Each root WILL have a key/value at the root of the result dict though + # the value can be an empty dict if the directory is empty. We want these + # placeholders for empty dirs so we can inform the user later. + for path in paths: + # Wrap the path in a list since that is what the `_get_tree` expects. + res[path] = _get_tree([path]) + if res[path]: + res[path] = res[path][os.path.basename(path)] + + return res + def register_page(page): """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" @@ -80,7 +132,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): item = page.items.get(name) page.read_user_metadata(item) - item_html = page.create_html_for_item(item, tabname) + item_html = page.create_item_html(tabname, item) return JSONResponse({"html": item_html}) @@ -96,13 +148,15 @@ def quote_js(s): s = s.replace('"', '\\"') return f'"{s}"' - class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.card_page = shared.html("extra-networks-card.html") + if shared.opts.extra_networks_tree_view: + self.card_page = shared.html("extra-networks-card-minimal.html") + else: + self.card_page = shared.html("extra-networks-card.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -136,12 +190,141 @@ class ExtraNetworksPage: return "" - def create_html(self, tabname): - items_html = '' + def create_item_html(self, tabname: str, item: dict) -> str: + """Generates HTML for a single ExtraNetworks Item + + Args: + tabname: The name of the active tab. + item: Dictionary containing item information. + + Returns: + HTML string generated for this item. + Can be empty if the item is not meant to be shown. + """ + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) + + preview = item.get("preview", None) + height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + background_image = f'' if preview else '' + + onclick = item.get("onclick", None) + if onclick is None: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + + copy_path_button = f"
" + + metadata_button = "" + metadata = item.get("metadata") + if metadata: + metadata_button = f"" + + edit_button = f"
" + + local_path = "" + filename = item.get("filename", "") + for reldir in self.allowed_directories_for_previews(): + absdir = os.path.abspath(reldir) + + if filename.startswith(absdir): + local_path = filename[len(absdir):] + + # if this is true, the item must not be shown in the default view, and must instead only be + # shown when searching for it + if shared.opts.extra_networks_hidden_models == "Always": + search_only = False + else: + search_only = "/." in local_path or "\\." in local_path + + if search_only and shared.opts.extra_networks_hidden_models == "Never": + return "" + + sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + + # Some items here might not be used depending on HTML template used. + args = { + "background_image": background_image, + "card_clicked": onclick, + "copy_path_button": copy_path_button, + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "edit_button": edit_button, + "local_preview": quote_js(item["local_preview"]), + "metadata_button": metadata_button, + "name": html.escape(item["name"]), + "prompt": item.get("prompt", None), + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "search_only": " search_only" if search_only else "", + "search_term": item.get("search_term", ""), + "sort_keys": sort_keys, + "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "tabname": quote_js(tabname), + } + + return self.card_page.format(**args) + + def create_tree_view_html(self, tabname: str) -> str: + """Generates HTML for displaying folders in a tree view. + + Args: + tabname: The name of the active tab. + + Returns: + HTML string generated for this tree view. + """ + self_name_id = self.name.replace(" ", "_") + res = f"
" self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + roots = self.allowed_directories_for_previews() + tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} + tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) + + if not tree: + return res + "
" + + file_template = "
  • {}
  • " + dir_template = ( + "
    " + "{}" + "{}" + "
    " + ) + + def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: + """Recursively builds HTML for a tree.""" + _res = "
      " + if not data: + return "
      • DIRECTORY IS EMPTY
      " + + for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): + if isinstance(v, (ExtraNetworksItem,)): + _res += file_template.format(self.create_item_html(tabname, v.item)) + else: + _res += dir_template.format("", k, _build_tree(v)) + return _res + + res += "
        " + # Add each root directory to the tree. + for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): + # If root is empty, append the "disabled" attribute to the template details tag. + res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
      " + res += "
    " + + return res + + def create_card_view_html(self, tabname): + items_html = "" + self.metadata = {} subdirs = {} + for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): for dirname in sorted(dirs, key=shared.natural_sort_key): @@ -171,40 +354,45 @@ class ExtraNetworksPage: if subdirs: subdirs = {"": 1, **subdirs} - subdirs_html = "".join([f""" - -""" for subdir in subdirs]) + subdirs_html_template = ( + "" + ) + subdirs_html = "".join( + [ + subdirs_html_template.format( + " search-all" if subdir == "" else "", + tabname, + html.escape(subdir if subdir != "" else "all"), + ) for subdir in subdirs + ] + ) self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata + items_html += self.create_item_html(tabname, item) - if "user_metadata" not in item: - self.read_user_metadata(item) - - items_html += self.create_html_for_item(item, tabname) - - if items_html == '': + if items_html == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) self_name_id = self.name.replace(" ", "_") - res = f""" -
    -{subdirs_html} -
    -
    -{items_html} -
    -""" + res = ( + f"
    {subdirs_html}
    " + f"
    {items_html}
    " + ) return res + def create_html(self, tabname): + if shared.opts.extra_networks_tree_view: + return self.create_tree_view_html(tabname) + else: + return self.create_card_view_html(tabname) + def create_item(self, name, index=None): raise NotImplementedError() @@ -214,66 +402,6 @@ class ExtraNetworksPage: def allowed_directories_for_previews(self): return [] - def create_html_for_item(self, item, tabname): - """ - Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown. - """ - - preview = item.get("preview", None) - - onclick = item.get("onclick", None) - if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' - background_image = f'' if preview else '' - metadata_button = "" - metadata = item.get("metadata") - if metadata: - metadata_button = f"" - - edit_button = f"
    " - - local_path = "" - filename = item.get("filename", "") - for reldir in self.allowed_directories_for_previews(): - absdir = os.path.abspath(reldir) - - if filename.startswith(absdir): - local_path = filename[len(absdir):] - - # if this is true, the item must not be shown in the default view, and must instead only be - # shown when searching for it - if shared.opts.extra_networks_hidden_models == "Always": - search_only = False - else: - search_only = "/." in local_path or "\\." in local_path - - if search_only and shared.opts.extra_networks_hidden_models == "Never": - return "" - - sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() - - args = { - "background_image": background_image, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "prompt": item.get("prompt", None), - "tabname": quote_js(tabname), - "local_preview": quote_js(item["local_preview"]), - "name": html.escape(item["name"]), - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "card_clicked": onclick, - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', - "search_term": item.get("search_term", ""), - "metadata_button": metadata_button, - "edit_button": edit_button, - "search_only": " search_only" if search_only else "", - "sort_keys": sort_keys, - } - - return self.card_page.format(**args) - def get_sort_keys(self, path): """ List of default keys used for sorting in the UI. @@ -360,7 +488,6 @@ def pages_in_preferred_order(pages): return sorted(pages, key=lambda x: tab_scores[x.name]) - def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): from modules.ui import switch_values_symbol @@ -381,7 +508,6 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) editor = page.create_user_metadata_editor(ui, tabname) @@ -390,30 +516,60 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) + tab_controls = {} + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) + + tab_controls["edit_search"] = edit_search + tab_controls["dropdown_sort"] = dropdown_sort + tab_controls["button_sortorder"] = button_sortorder + tab_controls["button_refresh"] = button_refresh + tab_controls["checkbox_show_dirs"] = checkbox_show_dirs ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs] - for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False) + tab.select( + fn=lambda: [gr.update(visible=False) for _ in tab_controls], + _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) + + visible_controls = list(tab_controls.keys()) + if shared.opts.extra_networks_tree_view: + visible_controls = ["button_refresh"] for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');' + jscode = ( + "extraNetworksTabSelected(" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" + ");" + ) - tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False) + tab.select( + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + _js="function(){ " + jscode + " }", + inputs=[], + outputs=list(tab_controls.values()), + show_progress=False, + ) dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") + def pages_html(): if not ui.pages_contents: return refresh() @@ -478,5 +634,3 @@ def setup_ui(ui, gallery): for editor in ui.user_metadata_editors: editor.setup_ui(gallery) - - diff --git a/style.css b/style.css index ee39a57b7..680a5f837 100644 --- a/style.css +++ b/style.css @@ -1157,3 +1157,95 @@ body.resizing .resize-handle { left: 7.5px; border-left: 1px dashed var(--border-color-primary); } + +.extra-network-cards .card .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .button-column { + display: inline-flex; + visibility: hidden; + color: white; + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; +} + +.extra-network-cards .card-minimal:hover .button-column { + visibility: visible; +} + +.extra-network-cards .card-minimal .copy-path-button:before { + content: "⎘"; +} + +.extra-network-cards .card-minimal .metadata-button:before{ + content: "🛈"; +} + +.extra-network-cards .card-minimal .edit-button:before{ + content: "🛠"; +} + +.extra-network-cards .card-minimal .card-button { + color: white; + text-shadow: 2px 2px 3px black; + font-size: 1rem; + width: 1.5rem; +} + +.extra-network-cards .card-minimal .card-button:hover { + color: red; +} + +.extra-network-cards .card-minimal { + display: inline-flex; + position: relative; + overflow: hidden; + cursor: pointer; + font-size: 1rem; + font-weight: bold; + line-break: anywhere; +} + +.file-item { + list-style-type: '📄'; +} + +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +details[disabled] summary { + pointer-events: none; + user-select: none; +} + +details.folder-item > summary { + list-style-type: '📁'; +} + +details.folder-item[open] > summary { + list-style-type: '📂'; +} + +.file-item, +.folder-item, +.folder-item-summary { + display: block; + font-size: 1rem; + padding: 0.05rem; + cursor: pointer; + user-select: none; +} + +.folder-item-summary:hover, +.file-item:hover { + -webkit-transition: all 0.1s ease-in-out; + transition: all 0.1s ease-in-out; + background-color: var(--neutral-200); +} + +.dark .folder-item-summary:hover, +.dark .file-item:hover { + -webkit-transition: all 0.05s ease-in-out; + transition: all 0.05s ease-in-out; + background-color: var(--neutral-800); +} From 67a70ad112e1b0fb262852ab830896302dbd306a Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:11:24 -0500 Subject: [PATCH 223/311] fix indentation --- html/extra-networks-card.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index d76011d73..7770094da 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,7 +1,7 @@
    {background_image}
    - {copy_path_button} + {copy_path_button} {metadata_button} {edit_button}
    From 34fc215249e2bc0acc66cda47319e40b6e46a05f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 14:23:01 -0500 Subject: [PATCH 224/311] fix linting --- modules/ui_extra_networks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 8667617b9..ab484a5dc 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,7 +216,7 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' - + copy_path_button = f"
    " metadata_button = "" @@ -523,7 +523,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - + tab_controls["edit_search"] = edit_search tab_controls["dropdown_sort"] = dropdown_sort tab_controls["button_sortorder"] = button_sortorder @@ -560,7 +560,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], _js="function(){ " + jscode + " }", inputs=[], outputs=list(tab_controls.values()), From 46413b20a4d88355b5fac5d27cc9fcb634cdb49e Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 8 Jan 2024 16:23:35 -0500 Subject: [PATCH 225/311] Increase limits for upscalers. --- modules/api/models.py | 2 +- scripts/postprocessing_upscale.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/models.py b/modules/api/models.py index 33894b3e6..11ba4f0b9 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index ed709688d..5946678e5 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,13 +21,13 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): with gr.Column(elem_id="upscaling_column_size", scale=4): - upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") + upscaling_resize_w = gr.Slider(minimum=64, maximum=8192, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Slider(minimum=64, maximum=8192, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h") with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"): upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height") upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") From 8d986727b39ee6616190559efa1c41c1942b99b0 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 9 Jan 2024 03:01:20 -0600 Subject: [PATCH 226/311] include tls arguments in api uvicorn init --- modules/api/api.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9d1292e95..59e463352 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -879,7 +879,15 @@ class Api: def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) + uvicorn.run( + self.app, + host=server_name, + port=port, + timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, + root_path=root_path, + ssl_keyfile=shared.cmd_opts.tls_keyfile, + ssl_certfile=shared.cmd_opts.tls_certfile + ) def kill_webui(self): restart.stop_program() From 209c26a1cb9e4be357ab3c5e7613caf3cbc26183 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:11:44 +0800 Subject: [PATCH 227/311] improve efficiency and support more device --- modules/devices.py | 60 ++++++++++++++++++++++++++++++------------ modules/shared_init.py | 1 + 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ff279ac50..6edfb1278 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -110,6 +110,7 @@ device_codeformer: torch.device = None dtype: torch.dtype = torch.float16 dtype_vae: torch.dtype = torch.float16 dtype_unet: torch.dtype = torch.float16 +dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False @@ -131,21 +132,49 @@ patch_module_list = [ ] -def manual_cast_forward(self, *args, **kwargs): - org_dtype = torch_utils.get_param(self).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result +def manual_cast_forward(target_dtype): + def forward_wrapper(self, *args, **kwargs): + org_dtype = torch_utils.get_param(self).dtype + if not target_dtype == org_dtype == dtype_inference: + self.to(target_dtype) + args = [ + arg.to(target_dtype) + if isinstance(arg, torch.Tensor) + else arg + for arg in args + ] + kwargs = { + k: v.to(target_dtype) + if isinstance(v, torch.Tensor) + else v + for k, v in kwargs.items() + } + + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + + if target_dtype != dtype_inference: + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + return forward_wrapper @contextlib.contextmanager -def manual_cast(): +def manual_cast(target_dtype): for module_type in patch_module_list: org_forward = module_type.forward - module_type.forward = manual_cast_forward + if module_type == torch.nn.MultiheadAttention and has_xpu(): + module_type.forward = manual_cast_forward(torch.float32) + else: + module_type.forward = manual_cast_forward(target_dtype) module_type.org_forward = org_forward try: yield None @@ -161,15 +190,12 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): - return manual_cast() - - if has_mps() and shared.cmd_opts.precision != "full": - return manual_cast() - - if dtype == torch.float32 or shared.cmd_opts.precision == "full": + if dtype == torch.float32 and shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype_inference) + return torch.autocast("cuda") diff --git a/modules/shared_init.py b/modules/shared_init.py index 586be3423..935e3a21c 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -29,6 +29,7 @@ def initialize(): devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype shared.device = devices.device shared.weight_load_location = None if cmd_opts.lowram else "cpu" From 42e6df723c68af775b73c9fa4f43f99345348689 Mon Sep 17 00:00:00 2001 From: KohakuBlueleaf Date: Tue, 9 Jan 2024 22:39:39 +0800 Subject: [PATCH 228/311] Fix bugs when arg dtype doesn't match --- modules/devices.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 6edfb1278..e05740524 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -134,24 +134,19 @@ patch_module_list = [ def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): - org_dtype = torch_utils.get_param(self).dtype - if not target_dtype == org_dtype == dtype_inference: - self.to(target_dtype) - args = [ - arg.to(target_dtype) - if isinstance(arg, torch.Tensor) - else arg - for arg in args - ] - kwargs = { - k: v.to(target_dtype) - if isinstance(v, torch.Tensor) - else v - for k, v in kwargs.items() - } + if any( + isinstance(arg, torch.Tensor) and arg.dtype != target_dtype + for arg in args + ): + args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + org_dtype = torch_utils.get_param(self).dtype + if org_dtype != target_dtype: + self.to(target_dtype) result = self.org_forward(*args, **kwargs) - self.to(org_dtype) + if org_dtype != target_dtype: + self.to(org_dtype) if target_dtype != dtype_inference: if isinstance(result, tuple): From c2c05fcca8f3547783c5440c04ec10cc63c65db5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 22:53:58 +0800 Subject: [PATCH 229/311] linting and debugs --- modules/devices.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index e05740524..ad36f6562 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -140,20 +140,20 @@ def manual_cast_forward(target_dtype): ): args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - + org_dtype = torch_utils.get_param(self).dtype if org_dtype != target_dtype: self.to(target_dtype) result = self.org_forward(*args, **kwargs) if org_dtype != target_dtype: self.to(org_dtype) - + if target_dtype != dtype_inference: if isinstance(result, tuple): result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i for i in result ) elif isinstance(result, torch.Tensor): @@ -185,7 +185,7 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32 and shared.cmd_opts.precision == "full": + if dtype == torch.float32: return contextlib.nullcontext() if has_xpu() or has_mps() or cuda_no_autocast(): From e00365962b17550a42235d1fbe2ad2c7cc4b8961 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:13:34 +0800 Subject: [PATCH 230/311] Apply correct inference precision implementation --- modules/devices.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562..9e1f207c3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,6 +132,21 @@ patch_module_list = [ ] +def cast_output(result): + if isinstance(result, tuple): + result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + + +def autocast_with_cast_output(self, *args, **kwargs): + result = self.org_forward(*args, **kwargs) + if dtype_inference != dtype: + result = cast_output(result) + return result + + def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -149,15 +164,7 @@ def manual_cast_forward(target_dtype): self.to(org_dtype) if target_dtype != dtype_inference: - if isinstance(result, tuple): - result = tuple( - i.to(dtype_inference) - if isinstance(i, torch.Tensor) - else i - for i in result - ) - elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) + result = cast_output(result) return result return forward_wrapper @@ -178,6 +185,20 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward +@contextlib.contextmanager +def precision_full_with_autocast(autocast_ctx): + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = autocast_with_cast_output + module_type.org_forward = org_forward + try: + with autocast_ctx: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward + + def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -191,6 +212,9 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) + if dtype_inference == torch.float32 and dtype != torch.float32: + return precision_full_with_autocast(torch.autocast("cuda")) + return torch.autocast("cuda") From 1fd69655fe340325863cbd7bf5297e034a6a3a0a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:15:05 +0800 Subject: [PATCH 231/311] Revert "Apply correct inference precision implementation" This reverts commit e00365962b17550a42235d1fbe2ad2c7cc4b8961. --- modules/devices.py | 42 +++++++++--------------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 9e1f207c3..ad36f6562 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -132,21 +132,6 @@ patch_module_list = [ ] -def cast_output(result): - if isinstance(result, tuple): - result = tuple(i.to(dtype_inference) if isinstance(i, torch.Tensor) else i for i in result) - elif isinstance(result, torch.Tensor): - result = result.to(dtype_inference) - return result - - -def autocast_with_cast_output(self, *args, **kwargs): - result = self.org_forward(*args, **kwargs) - if dtype_inference != dtype: - result = cast_output(result) - return result - - def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): if any( @@ -164,7 +149,15 @@ def manual_cast_forward(target_dtype): self.to(org_dtype) if target_dtype != dtype_inference: - result = cast_output(result) + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) return result return forward_wrapper @@ -185,20 +178,6 @@ def manual_cast(target_dtype): module_type.forward = module_type.org_forward -@contextlib.contextmanager -def precision_full_with_autocast(autocast_ctx): - for module_type in patch_module_list: - org_forward = module_type.forward - module_type.forward = autocast_with_cast_output - module_type.org_forward = org_forward - try: - with autocast_ctx: - yield None - finally: - for module_type in patch_module_list: - module_type.forward = module_type.org_forward - - def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -212,9 +191,6 @@ def autocast(disable=False): if has_xpu() or has_mps() or cuda_no_autocast(): return manual_cast(dtype_inference) - if dtype_inference == torch.float32 and dtype != torch.float32: - return precision_full_with_autocast(torch.autocast("cuda")) - return torch.autocast("cuda") From 58d5b042cd02f287faabef399134b97d323691f2 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:23:40 +0800 Subject: [PATCH 232/311] Apply the correct behavior of precision='full' --- modules/devices.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index ad36f6562..29a270d11 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,11 +185,14 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if dtype == torch.float32: - return contextlib.nullcontext() - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype_inference) + return manual_cast(dtype) + + if fp8 and dtype_inference == torch.float32: + return manual_cast(dtype) + + if dtype == torch.float32 or dtype_inference == torch.float32: + return contextlib.nullcontext() return torch.autocast("cuda") From ca671e5d7b9d03227f01e6bcb350032b6d14e722 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 9 Jan 2024 23:30:55 +0800 Subject: [PATCH 233/311] rearrange if-statements for cpu --- modules/devices.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 29a270d11..0321d12c6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -185,15 +185,15 @@ def autocast(disable=False): if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) - if has_xpu() or has_mps() or cuda_no_autocast(): - return manual_cast(dtype) - if fp8 and dtype_inference == torch.float32: return manual_cast(dtype) if dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype) + return torch.autocast("cuda") From 3cc7572f5dec429d265ce937249eb4b5cc18d0ba Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:46:10 -0500 Subject: [PATCH 234/311] Restore scale factor limit changes to master branch. --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index 11ba4f0b9..33894b3e6 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -143,7 +143,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?") From 9dd25348248c06770897f4cde64e2663dfa9f2de Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 9 Jan 2024 11:47:39 -0500 Subject: [PATCH 235/311] Restore scale factor limit changes to master branch. --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 5946678e5..a57f9d4a4 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -21,7 +21,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): with FormRow(): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by: - upscaling_resize = gr.Slider(minimum=1.0, maximum=100.0, step=0.05, label="Resize", value=2, elem_id="extras_upscaling_resize") + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to: with FormRow(): From 4d9f2c3ec8e791b9c354292c4333fc85b8f8d740 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 6 Jan 2024 21:41:29 +0900 Subject: [PATCH 236/311] update p.seed and p.subseed --- modules/txt2img.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index c4cc12d2f..d22a1f319 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -67,13 +67,16 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g geninfo = json.loads(generation_info) all_seeds = geninfo["all_seeds"] + all_subseeds = geninfo["all_subseeds"] image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) gallery_index_from_end = len(gallery) - gallery_index seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] - p.script_args = modules.scripts.scripts_txt2img.set_named_arg(p.script_args, 'ScriptSeed', 'seed', seed) + subseed = all_subseeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0] + p.seed = seed + p.subseed = subseed with closing(p): processed = modules.scripts.scripts_txt2img.run(p, *p.script_args) From 3db6938caa719aaa38b52edecf42740ef62b0c3c Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 10 Jan 2024 18:11:48 -0500 Subject: [PATCH 237/311] begin redesign of tree module. --- html/extra-networks-card-minimal.html | 3 +- html/extra-networks-pane.html | 11 ++ html/extra-networks-tree-directory.html | 4 + html/extra-networks-tree-file.html | 1 + javascript/extraNetworks.js | 22 +++ modules/shared_options.py | 1 - modules/ui_extra_networks.py | 164 ++++++++++-------- style.css | 218 +++++++++++++----------- 8 files changed, 255 insertions(+), 169 deletions(-) create mode 100644 html/extra-networks-pane.html create mode 100644 html/extra-networks-tree-directory.html create mode 100644 html/extra-networks-tree-file.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html index a6a54d9f4..d66df7dfb 100644 --- a/html/extra-networks-card-minimal.html +++ b/html/extra-networks-card-minimal.html @@ -1,3 +1,4 @@
    - {name}{copy_path_button}{metadata_button}{edit_button} + {name} + {copy_path_button}{metadata_button}{edit_button}
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html new file mode 100644 index 000000000..93bad698d --- /dev/null +++ b/html/extra-networks-pane.html @@ -0,0 +1,11 @@ +
    + {subdirs_html} +
    +
    +
    + {tree_html} +
    +
    + {items_html} +
    +
    \ No newline at end of file diff --git a/html/extra-networks-tree-directory.html b/html/extra-networks-tree-directory.html new file mode 100644 index 000000000..cec155886 --- /dev/null +++ b/html/extra-networks-tree-directory.html @@ -0,0 +1,4 @@ +
    +{folder_name} +{content} +
    \ No newline at end of file diff --git a/html/extra-networks-tree-file.html b/html/extra-networks-tree-file.html new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/html/extra-networks-tree-file.html @@ -0,0 +1 @@ + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 40309d557..33f45c8e4 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -253,6 +253,28 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksFolderClick(event, tabs_id) { + var els = document.querySelectorAll(".folder-item-summary.selected"); + [...els].forEach(el => { + el.classList.remove("selected"); + }); + event.target.classList.add("selected"); + + var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); + var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + searchTextArea.value = text; + updateInput(searchTextArea); + + if (event.target.parentElement.open) { + // before close + console.log("closed"); + } else { + // before open + console.log("opened"); + //console.log("Opened:", event.target.parentElement); + } +} + function extraNetworksSearchButton(tabs_id, event) { var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; diff --git a/modules/shared_options.py b/modules/shared_options.py index e698c2649..d2e86ff10 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -238,7 +238,6 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s "extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."), "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), - "extra_networks_tree_view": OptionInfo(False, "Show extra networks using a directory tree view.").needs_reload_ui(), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ab484a5dc..6318594fa 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -73,10 +73,11 @@ def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: + short_path = os.path.basename(path) # Wrap the path in a list since that is what the `_get_tree` expects. - res[path] = _get_tree([path]) - if res[path]: - res[path] = res[path][os.path.basename(path)] + res[short_path] = _get_tree([path]) + if res[short_path]: + res[short_path] = res[short_path][os.path.basename(path)] return res @@ -153,10 +154,9 @@ class ExtraNetworksPage: self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - if shared.opts.extra_networks_tree_view: - self.card_page = shared.html("extra-networks-card-minimal.html") - else: - self.card_page = shared.html("extra-networks-card.html") + self.extra_networks_pane_template = shared.html("extra-networks-pane.html") + self.card_page_template = shared.html("extra-networks-card.html") + self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -182,15 +182,14 @@ class ExtraNetworksPage: def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) - for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()): - parentdir = os.path.abspath(parentdir) + parentdir = os.path.dirname(os.path.abspath(parentdir)) if abspath.startswith(parentdir): - return abspath[len(parentdir):].replace('\\', '/') + return os.path.relpath(abspath, parentdir) return "" - def create_item_html(self, tabname: str, item: dict) -> str: + def create_item_html(self, tabname: str, item: dict, template: Optional[str] = None) -> str: """Generates HTML for a single ExtraNetworks Item Args: @@ -265,7 +264,10 @@ class ExtraNetworksPage: "tabname": quote_js(tabname), } - return self.card_page.format(**args) + if template: + return template.format(**args) + else: + return self.card_page.format(**args) def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -276,53 +278,67 @@ class ExtraNetworksPage: Returns: HTML string generated for this tree view. """ - self_name_id = self.name.replace(" ", "_") - res = f"
    " - - self.metadata = {} - self.items = {x["name"]: x for x in self.list_items()} + res = f"" + # Generate HTML for the tree. roots = self.allowed_directories_for_previews() tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) if not tree: - return res + "
    " + return res - file_template = "
  • {}
  • " + file_template = "
  • {card}
  • " dir_template = ( - "
    " - "{}" - "{}" + "
    " + "" + "{folder_name}" + "" + "
      {content}
    " "
    " ) def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" - _res = "
      " + _res = "" if not data: - return "
      • DIRECTORY IS EMPTY
      " + return "
    • DIRECTORY IS EMPTY
    • " for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _res += file_template.format(self.create_item_html(tabname, v.item)) + item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + _res += file_template.format(**{"card": item_html}) else: - _res += dir_template.format("", k, _build_tree(v)) + tmp = dir_template.format( + **{ + "attributes": "", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + _res += tmp + return _res - res += "
        " # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += dir_template.format("open" if v else "open disabled", k, _build_tree(v)) + res += "
          " + res += dir_template.format( + **{ + "attributes": "open" if v else "open", + "tabname": tabname, + "folder_name": k, + "content": _build_tree(v), + } + ) + res += "
        " res += "
      " - res += "
    " return res - def create_card_view_html(self, tabname): - items_html = "" - self.metadata = {} + def create_subdirs_html(self, tabname): subdirs = {} for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: @@ -355,43 +371,53 @@ class ExtraNetworksPage: subdirs = {"": 1, **subdirs} subdirs_html_template = ( - "" ) - subdirs_html = "".join( + return "".join( [ subdirs_html_template.format( - " search-all" if subdir == "" else "", - tabname, - html.escape(subdir if subdir != "" else "all"), + **{ + "classes": "search-all" if not subdir else "", + "tabname": tabname, + "content": html.escape(subdir if subdir else "all"), + } ) for subdir in subdirs ] ) + def create_card_view_html(self, tabname): + res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - items_html += self.create_item_html(tabname, item) + res += self.create_item_html(tabname, item, self.card_page_template) - if items_html == "": + if res == "": dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) - items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) - - self_name_id = self.name.replace(" ", "_") - - res = ( - f"
    {subdirs_html}
    " - f"
    {items_html}
    " - ) + res = shared.html("extra-networks-no-cards.html").format(dirs=dirs) return res def create_html(self, tabname): - if shared.opts.extra_networks_tree_view: - return self.create_tree_view_html(tabname) - else: - return self.create_card_view_html(tabname) + self.metadata = {} + self.items = {x["name"]: x for x in self.list_items()} + + tree_view_html = self.create_tree_view_html(tabname) + subdirs_html = self.create_subdirs_html(tabname) + card_view_html = self.create_card_view_html(tabname) + network_type_id = self.name.replace(" ", "_") + + return self.extra_networks_pane_template.format( + **{ + "tabname": tabname, + "network_type_id": network_type_id, + "tree_html": tree_view_html, + "subdirs_html": subdirs_html, + "items_html": card_view_html, + } + ) def create_item(self, name, index=None): raise NotImplementedError() @@ -516,19 +542,19 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - tab_controls = {} - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) - tab_controls["edit_search"] = edit_search - tab_controls["dropdown_sort"] = dropdown_sort - tab_controls["button_sortorder"] = button_sortorder - tab_controls["button_refresh"] = button_refresh - tab_controls["checkbox_show_dirs"] = checkbox_show_dirs + tab_controls = [ + edit_search, + dropdown_sort, + button_sortorder, + button_refresh, + checkbox_show_dirs, + ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -538,32 +564,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) - visible_controls = list(tab_controls.keys()) - if shared.opts.extra_networks_tree_view: - visible_controls = ["button_refresh"] - for page, tab in zip(ui.stored_extra_pages, related_tabs): allow_prompt = "true" if page.allow_prompt else "false" allow_negative_prompt = "true" if page.allow_negative_prompt else "false" jscode = ( "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" + f"'{tabname}', " + f"'{tabname}_{page.id_page}_prompts', " + f"'{allow_prompt}', " + f"'{allow_negative_prompt}'" ");" ) tab.select( - fn=lambda: [gr.update(visible=k in visible_controls) for k in tab_controls], + fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js="function(){ " + jscode + " }", inputs=[], - outputs=list(tab_controls.values()), + outputs=tab_controls, show_progress=False, ) diff --git a/style.css b/style.css index 680a5f837..4f285c68b 100644 --- a/style.css +++ b/style.css @@ -863,7 +863,7 @@ footer { margin-bottom: 1em; } -.extra-network-cards{ +.extra-network-pane{ height: calc(100vh - 24rem); overflow: clip scroll; resize: vertical; @@ -908,53 +908,75 @@ footer { width: auto; } -.extra-network-cards .nocards{ +.extra-network-pane .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-pane .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-pane .nocards li{ margin-left: 0.5em; } +.extra-network-pane :is(.card, .card-minimal) .button-row{ + display: inline-flex; + visibility: hidden; + color: white; +} -.extra-network-cards .card .button-row{ - display: none; +.extra-network-pane .card .button-row { position: absolute; - color: white; right: 0; - z-index: 1 -} -.extra-network-cards .card:hover .button-row{ - display: flex; + z-index: 1; } -.extra-network-cards .card .card-button{ +.extra-network-pane .card-minimal .button-row { + padding-left: 0.5rem; + padding-right: 0.5rem; + align-items: center; +} + +.extra-network-pane :is(.card:hover, .card-minimal:hover) .button-row{ + visibility: visible; +} + +.extra-network-pane .card-button{ color: white; } -.extra-network-cards .card .metadata-button:before{ +.extra-network-pane .copy-path-button:before { + content: "⎘"; +} + +.extra-network-pane .metadata-button:before{ content: "🛈"; } -.extra-network-cards .card .edit-button:before{ +.extra-network-pane .edit-button:before{ content: "🛠"; } -.extra-network-cards .card .card-button { - text-shadow: 2px 2px 3px black; - padding: 0.25em 0.1em; - font-size: 200%; +.extra-network-pane .card-button { width: 1.5em; + text-shadow: 2px 2px 3px black; + color: white; + padding: 0.25em 0.1em; } -.extra-network-cards .card .card-button:hover{ + +.extra-network-pane .card-button:hover{ color: red; } +.extra-network-pane .card .card-button { + font-size: 2rem; +} + +.extra-network-pane .card-minimal .card-button { + font-size: 1rem; +} .standalone-card-preview.card .preview{ position: absolute; @@ -963,7 +985,7 @@ footer { height:100%; } -.extra-network-cards .card, .standalone-card-preview.card{ +.extra-network-pane .card, .standalone-card-preview.card{ display: inline-block; margin: 0.5rem; width: 16rem; @@ -980,15 +1002,15 @@ footer { background-image: url('./file=html/card-no-preview.png') } -.extra-network-cards .card:hover{ +.extra-network-pane .card:hover{ box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); } -.extra-network-cards .card .actions .additional{ +.extra-network-pane .card .actions .additional{ display: none; } -.extra-network-cards .card .actions{ +.extra-network-pane .card .actions{ position: absolute; bottom: 0; left: 0; @@ -999,45 +1021,45 @@ footer { text-shadow: 0 0 0.2em black; } -.extra-network-cards .card .actions *{ +.extra-network-pane .card .actions *{ color: white; } -.extra-network-cards .card .actions .name{ +.extra-network-pane .card .actions .name{ font-size: 1.7em; font-weight: bold; line-break: anywhere; } -.extra-network-cards .card .actions .description { +.extra-network-pane .card .actions .description { display: block; max-height: 3em; white-space: pre-wrap; line-height: 1.1; } -.extra-network-cards .card .actions .description:hover { +.extra-network-pane .card .actions .description:hover { max-height: none; } -.extra-network-cards .card .actions:hover .additional{ +.extra-network-pane .card .actions:hover .additional{ display: block; } -.extra-network-cards .card ul{ +.extra-network-pane .card ul{ margin: 0.25em 0 0.75em 0.25em; cursor: unset; } -.extra-network-cards .card ul a{ +.extra-network-pane .card ul a{ cursor: pointer; } -.extra-network-cards .card ul a:hover{ +.extra-network-pane .card ul a:hover{ color: red; } -.extra-network-cards .card .preview{ +.extra-network-pane .card .preview{ position: absolute; object-fit: cover; width: 100%; @@ -1158,48 +1180,9 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-cards .card .copy-path-button:before { - content: "⎘"; -} - -.extra-network-cards .card-minimal .button-column { - display: inline-flex; - visibility: hidden; - color: white; - padding-left: 0.5rem; - padding-right: 0.5rem; - align-items: center; -} - -.extra-network-cards .card-minimal:hover .button-column { - visibility: visible; -} - -.extra-network-cards .card-minimal .copy-path-button:before { - content: "⎘"; -} - -.extra-network-cards .card-minimal .metadata-button:before{ - content: "🛈"; -} - -.extra-network-cards .card-minimal .edit-button:before{ - content: "🛠"; -} - -.extra-network-cards .card-minimal .card-button { - color: white; - text-shadow: 2px 2px 3px black; - font-size: 1rem; - width: 1.5rem; -} - -.extra-network-cards .card-minimal .card-button:hover { - color: red; -} - -.extra-network-cards .card-minimal { +.extra-network-pane .card-minimal { display: inline-flex; + flex-grow: 1; position: relative; overflow: hidden; cursor: pointer; @@ -1208,44 +1191,87 @@ body.resizing .resize-handle { line-break: anywhere; } -.file-item { - list-style-type: '📄'; -} - -/* prevents clicking/collapsing of details tags when disabled attribute is used*/ -details[disabled] summary { - pointer-events: none; - user-select: none; -} - -details.folder-item > summary { - list-style-type: '📁'; -} - -details.folder-item[open] > summary { - list-style-type: '📂'; +/* Pushes buttons to right */ +.extra-network-pane .card-minimal .name { + flex-grow: 1; } .file-item, .folder-item, .folder-item-summary { - display: block; - font-size: 1rem; - padding: 0.05rem; + padding-left: 0.05rem; cursor: pointer; user-select: none; + font-size: 1rem; } -.folder-item-summary:hover, -.file-item:hover { +.extra-network-pane .extra-network-tree .folder-item-summary:hover, +.extra-network-pane .extra-network-tree .file-item:hover { -webkit-transition: all 0.1s ease-in-out; transition: all 0.1s ease-in-out; background-color: var(--neutral-200); } -.dark .folder-item-summary:hover, -.dark .file-item:hover { +.dark .extra-network-pane .extra-network-tree .folder-item-summary:hover, +.dark .extra-network-pane .extra-network-tree .file-item:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-800); } + +/* prevents clicking/collapsing of details tags when disabled attribute is used*/ +.extra-network-pane .extra-network-tree details[disabled] summary { + pointer-events: none; + user-select: none; +} + +.extra-network-pane .extra-network-tree details.folder-item > summary { + list-style-type: '📁'; + text-overflow: ellipsis; +} + +.extra-network-pane .extra-network-tree details.folder-item[open] > summary { + list-style-type: '📂'; + text-overflow: ellipsis; +} + +.extra-network-pane .extra-network tree ul.folder-container { + list-style: none; + font-size: 1rem; + text-overflow: ellipsis; +} + +.extra-network-pane .extra-network-tree li.file-item { + display: flex; + position: relative; + align-items: center; +} + +.extra-network-pane .extra-network-tree li.file-item::before { + content: '📄'; + font-size: 0.85rem; + vertical-align: middle; +} + +.extra-network-pane { + display: flex; +} + +.extra-network-pane .extra-network-subdirs { + display: block; +} +.extra-network-pane .extra-network-tree { + font-size: 1rem; + width: 25%; +} +.extra-network-pane .extra-network-cards { + flex-grow: 1; +} + +.dark .extra-network-tree .folder-item-summary.selected{ + background-color: var(--neutral-800); +} + +.extra-network-tree .folder-item-summary.selected { + background-color: var(--neutral-200); +} From 0011640ab187e5a59098bb80d165a23f9d08568c Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 11 Jan 2024 08:29:42 +0200 Subject: [PATCH 238/311] Logging: set formatter correctly for fallback logger too --- modules/logging_config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/logging_config.py b/modules/logging_config.py index 11eee9a63..8e31d8c9f 100644 --- a/modules/logging_config.py +++ b/modules/logging_config.py @@ -36,20 +36,21 @@ def setup_logging(loglevel): # Already configured, do not interfere return + formatter = logging.Formatter( + '%(asctime)s %(levelname)s [%(name)s] %(message)s', + '%Y-%m-%d %H:%M:%S', + ) + if os.environ.get("SD_WEBUI_RICH_LOG"): from rich.logging import RichHandler handler = RichHandler() else: handler = logging.StreamHandler() + handler.setFormatter(formatter) if TqdmLoggingHandler: handler = TqdmLoggingHandler(handler) - formatter = logging.Formatter( - '%(asctime)s %(levelname)s [%(name)s] %(message)s', - '%Y-%m-%d %H:%M:%S', - ) - handler.setFormatter(formatter) log_level = getattr(logging, loglevel.upper(), None) or logging.INFO From 0726a6e12e85a37d1e514f5603acf9f058c11783 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 11 Jan 2024 15:06:57 -0500 Subject: [PATCH 239/311] Finish base layout. Fix bugs. Need to test for stability and clean up. --- .../Lora/ui_extra_networks_lora.py | 5 +- html/extra-networks-card.html | 4 +- html/extra-networks-pane.html | 3 - javascript/extraNetworks.js | 48 ++++---- modules/ui_extra_networks.py | 113 ++++++------------ modules/ui_extra_networks_checkpoints.py | 5 +- modules/ui_extra_networks_hypernets.py | 6 +- .../ui_extra_networks_textual_inversion.py | 5 +- style.css | 24 ++-- 9 files changed, 89 insertions(+), 124 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663b..db612fa2b 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -24,13 +24,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): alias = lora_on_disk.get_alias() + search_terms = [self.search_terms_from_path(lora_on_disk.filename)] + if lora_on_disk.hash: + search_terms.append(lora_on_disk.hash) item = { "name": name, "filename": lora_on_disk.filename, "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), + "search_terms": search_terms, "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 7770094da..d163fe370 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -6,9 +6,7 @@ {edit_button}
    -
    - -
    +
    {search_terms}
    {name} {description}
    diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 93bad698d..20cf66867 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,6 +1,3 @@ -
    - {subdirs_html} -
    {tree_html} diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 33f45c8e4..4cc67fd18 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -24,8 +24,6 @@ function setupExtraNetworksForTab(tabname) { var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); - var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); @@ -33,14 +31,14 @@ function setupExtraNetworksForTab(tabname) { tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); - tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); var visible = text.indexOf(searchTerm) != -1; @@ -100,15 +98,6 @@ function setupExtraNetworksForTab(tabname) { extraNetworksApplySort[tabname] = applySort; extraNetworksApplyFilter[tabname] = applyFilter; - - var showDirsUpdate = function() { - var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; - toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); - localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); - }; - showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; - showDirs.addEventListener("change", showDirsUpdate); - showDirsUpdate(); } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,14 +125,23 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } +function clearSearch(tabname) { + // Clear search box. + var tab_id = tabname + "_extra_search"; + var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); + searchTextarea.value = ""; + updateInput(searchTextarea); +} -function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) + +function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); + clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - + clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -254,6 +252,15 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksFolderClick(event, tabs_id) { + // If folder is open but not selected, we don't want to collapse it. Instead + // we override the removal of the "open" attribute so that the folder is + // only selected but remains open. Since this is a toggle event, removing + // the "open" attribute instead forces the event to add it back which keeps it open. + if (event.target.parentElement.open && !event.target.classList.contains("selected")) { + // before event handler removes "open" + event.target.parentElement.removeAttribute("open"); + } + var els = document.querySelectorAll(".folder-item-summary.selected"); [...els].forEach(el => { el.classList.remove("selected"); @@ -261,18 +268,9 @@ function extraNetworksFolderClick(event, tabs_id) { event.target.classList.add("selected"); var searchTextArea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); - var text = event.target.classList.contains("search-all") ? "" : event.target.firstChild.textContent.trim(); + var text = event.target.classList.contains("search-all") ? "" : event.target.getAttribute("data-path"); searchTextArea.value = text; updateInput(searchTextArea); - - if (event.target.parentElement.open) { - // before close - console.log("closed"); - } else { - // before open - console.log("opened"); - //console.log("Opened:", event.target.parentElement); - } } function extraNetworksSearchButton(tabs_id, event) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6318594fa..2e226ba00 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -48,23 +48,24 @@ def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) if isinstance(paths, (str,)): paths = [paths] - def _get_tree(_paths: list[str]): + def _get_tree(_paths: list[str], _root: str): _res = {} for path in _paths: + relpath = os.path.relpath(path, _root) if os.path.isdir(path): dir_items = os.listdir(path) # Ignore empty directories. if not dir_items: continue - dir_tree = _get_tree([os.path.join(path, x) for x in dir_items]) + dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root) # We only want to store non-empty folders in the tree. if dir_tree: - _res[os.path.basename(path)] = dir_tree + _res[relpath] = dir_tree else: if path not in items: continue # Add the ExtraNetworksItem to the result. - _res[os.path.basename(path)] = items[path] + _res[relpath] = items[path] return _res res = {} @@ -73,11 +74,13 @@ def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) # the value can be an empty dict if the directory is empty. We want these # placeholders for empty dirs so we can inform the user later. for path in paths: - short_path = os.path.basename(path) + root = os.path.dirname(path) + relpath = os.path.relpath(path, root) # Wrap the path in a list since that is what the `_get_tree` expects. - res[short_path] = _get_tree([path]) - if res[short_path]: - res[short_path] = res[short_path][os.path.basename(path)] + res[relpath] = _get_tree([path], root) + if res[relpath]: + # We need to pull the inner path out one for these root dirs. + res[relpath] = res[relpath][relpath] return res @@ -245,6 +248,17 @@ class ExtraNetworksPage: sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip() + search_terms_html = "" + search_term_template = "{search_term}" + for search_term in item.get("search_terms", []): + search_terms_html += search_term_template.format( + **{ + "style": "display: none;", + "class": "search_terms" + (" search_only" if search_only else ""), + "search_term": search_term, + } + ) + # Some items here might not be used depending on HTML template used. args = { "background_image": background_image, @@ -258,7 +272,7 @@ class ExtraNetworksPage: "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_only": " search_only" if search_only else "", - "search_term": item.get("search_term", ""), + "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "tabname": quote_js(tabname), @@ -278,7 +292,7 @@ class ExtraNetworksPage: Returns: HTML string generated for this tree view. """ - res = f"" + res = "" # Generate HTML for the tree. roots = self.allowed_directories_for_previews() @@ -291,7 +305,8 @@ class ExtraNetworksPage: file_template = "
  • {card}
  • " dir_template = ( "
    " - "" + "" "{folder_name}" "" "
      {content}
    " @@ -309,16 +324,15 @@ class ExtraNetworksPage: item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) _res += file_template.format(**{"card": item_html}) else: - tmp = dir_template.format( + _res += dir_template.format( **{ "attributes": "", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) - _res += tmp - return _res # Add each root directory to the tree. @@ -329,65 +343,15 @@ class ExtraNetworksPage: **{ "attributes": "open" if v else "open", "tabname": tabname, - "folder_name": k, + "folder_name": os.path.basename(k), + "data_path": k, "content": _build_tree(v), } ) res += "" res += "" - return res - def create_subdirs_html(self, tabname): - subdirs = {} - - for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: - for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])): - for dirname in sorted(dirs, key=shared.natural_sort_key): - x = os.path.join(root, dirname) - - if not os.path.isdir(x): - continue - - subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/") - - if shared.opts.extra_networks_dir_button_function: - if not subdir.startswith("/"): - subdir = "/" + subdir - else: - while subdir.startswith("/"): - subdir = subdir[1:] - - is_empty = len(os.listdir(x)) == 0 - if not is_empty and not subdir.endswith("/"): - subdir = subdir + "/" - - if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories: - continue - - subdirs[subdir] = 1 - - if subdirs: - subdirs = {"": 1, **subdirs} - - subdirs_html_template = ( - "" - ) - return "".join( - [ - subdirs_html_template.format( - **{ - "classes": "search-all" if not subdir else "", - "tabname": tabname, - "content": html.escape(subdir if subdir else "all"), - } - ) for subdir in subdirs - ] - ) - def create_card_view_html(self, tabname): res = "" self.items = {x["name"]: x for x in self.list_items()} @@ -405,7 +369,6 @@ class ExtraNetworksPage: self.items = {x["name"]: x for x in self.list_items()} tree_view_html = self.create_tree_view_html(tabname) - subdirs_html = self.create_subdirs_html(tabname) card_view_html = self.create_card_view_html(tabname) network_type_id = self.name.replace(" ", "_") @@ -414,7 +377,6 @@ class ExtraNetworksPage: "tabname": tabname, "network_type_id": network_type_id, "tree_html": tree_view_html, - "subdirs_html": subdirs_html, "items_html": card_view_html, } ) @@ -534,7 +496,12 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.id_page}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + page_elem.change( + fn=lambda: None, + _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + inputs=[], + outputs=[], + ) editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() @@ -542,18 +509,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + edit_search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) tab_controls = [ edit_search, dropdown_sort, button_sortorder, button_refresh, - checkbox_show_dirs, ] ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) @@ -562,7 +527,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for tab in unrelated_tabs: tab.select( fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js="function(){ extraNetworksUrelatedTabSelected('" + tabname + "'); }", + _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", inputs=[], outputs=tab_controls, show_progress=False, diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 1693e71f1..e7976ba12 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -21,13 +21,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): return path, ext = os.path.splitext(checkpoint.filename) + search_terms = [self.search_terms_from_path(checkpoint.filename)] + if checkpoint.sha256: + search_terms.append(checkpoint.sha256) return { "name": checkpoint.name_for_extra, "filename": checkpoint.filename, "shorthash": checkpoint.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), + "search_terms": search_terms, "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index c96c4fa3b..2fb4bd190 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -20,14 +20,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): path, ext = os.path.splitext(full_path) sha256 = sha256_from_cache(full_path, f'hypernet/{name}') shorthash = sha256[0:10] if sha256 else None - + search_terms = [self.search_terms_from_path(path)] + if sha256: + search_terms.append(sha256) return { "name": name, "filename": full_path, "shorthash": shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""), + "search_terms": search_terms, "prompt": quote_js(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 1b334fda1..deb7cb873 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -18,13 +18,16 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): return path, ext = os.path.splitext(embedding.filename) + search_terms = [self.search_terms_from_path(embedding.filename)] + if embedding.hash: + search_terms.append(embedding.hash) return { "name": name, "filename": embedding.filename, "shorthash": embedding.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""), + "search_terms": search_terms, "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, diff --git a/style.css b/style.css index 4f285c68b..70d80d6a1 100644 --- a/style.css +++ b/style.css @@ -878,16 +878,8 @@ footer { margin: 0.3em; } -.extra-network-subdirs{ - padding: 0.2em 0.35em; -} - -.extra-network-subdirs button{ - margin: 0 0.15em; -} .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort, -.extra-networks .tab-nav .show-dirs +.extra-networks .tab-nav .sort { margin: 0.3em; align-self: center; @@ -1196,6 +1188,10 @@ body.resizing .resize-handle { flex-grow: 1; } +.folder-container { + margin-left: 1.5em !important; +} + .file-item, .folder-item, .folder-item-summary { @@ -1235,7 +1231,7 @@ body.resizing .resize-handle { text-overflow: ellipsis; } -.extra-network-pane .extra-network tree ul.folder-container { +.extra-network-pane .extra-network-tree ul.folder-container { list-style: none; font-size: 1rem; text-overflow: ellipsis; @@ -1257,15 +1253,15 @@ body.resizing .resize-handle { display: flex; } -.extra-network-pane .extra-network-subdirs { - display: block; -} .extra-network-pane .extra-network-tree { font-size: 1rem; - width: 25%; + min-width: 25%; + max-width: 25%; + border: 1px solid var(--block-border-color); } .extra-network-pane .extra-network-cards { flex-grow: 1; + border: 1px solid var(--block-border-color); } .dark .extra-network-tree .folder-item-summary.selected{ From 541881e3188b4f59d82043c5fdfac02607ec3119 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:11:06 -0800 Subject: [PATCH 240/311] Adjust brush size with hotkeys. --- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 45c7600ac..c695debc6 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,6 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", + canvas_hotkey_shrink_brush: "BracketLeft", + canvas_hotkey_grow_brush: "BracketRight", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, @@ -227,6 +229,8 @@ onUiLoaded(async() => { const functionMap = { "Zoom": "canvas_hotkey_zoom", "Adjust brush size": "canvas_hotkey_adjust", + "Shrink brush size": "canvas_hotkey_shrink_brush", + "Grow brush size": "canvas_hotkey_grow_brush", "Moving canvas": "canvas_hotkey_move", "Fullscreen": "canvas_hotkey_fullscreen", "Reset Zoom": "canvas_hotkey_reset", @@ -686,7 +690,9 @@ onUiLoaded(async() => { const hotkeyActions = { [hotkeysConfig.canvas_hotkey_reset]: resetZoom, [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, - [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen, + [defaultHotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10), + [defaultHotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10) }; const action = hotkeyActions[event.code]; From 47b52d9b28aaabb603358fae5d9b824c49aa627b Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Sat, 13 Jan 2024 02:31:26 -0800 Subject: [PATCH 241/311] Add # to the invalid_filename_chars list --- modules/images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/images.py b/modules/images.py index 87a7bf221..b6f2358c3 100644 --- a/modules/images.py +++ b/modules/images.py @@ -321,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): return res -invalid_filename_chars = '<>:"/\\|?*\n\r\t' +invalid_filename_chars = '#<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') From b6dc307c99a25bc3f3f3e96ce640a4710c14e133 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 13 Jan 2024 14:45:15 +0400 Subject: [PATCH 242/311] fix_extension_check_for_requirements --- modules/extensions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/extensions.py b/modules/extensions.py index 99e7ee60f..04bda297e 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -224,13 +224,16 @@ def list_extensions(): # check for requirements for extension in extensions: + if not extension.enabled: + continue + for req in extension.metadata.requires: required_extension = loaded_extensions.get(req) if required_extension is None: errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False) continue - if not extension.enabled: + if not required_extension.enabled: errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False) continue From 02e6963325e5221e0efb96a63f3dc849550489b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 13 Jan 2024 13:16:39 -0500 Subject: [PATCH 243/311] continue cleanup and redesign. --- html/extra-networks-tree-button.html | 11 + javascript/extraNetworks.js | 296 ++++++++++++++++-------- modules/ui_extra_networks.py | 181 ++++++++++++--- style.css | 325 ++++++++++++++++++++------- 4 files changed, 599 insertions(+), 214 deletions(-) create mode 100644 html/extra-networks-tree-button.html diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html new file mode 100644 index 000000000..920330f7a --- /dev/null +++ b/html/extra-networks-tree-button.html @@ -0,0 +1,11 @@ + + \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3e3b03f32..cce224689 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -16,88 +16,110 @@ function toggleCss(key, css, enable) { } function setupExtraNetworksForTab(tabname) { - gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); + this_tab.classList.add('extra-networks'); - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); - var search = searchDiv.querySelector('textarea'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); + function registerPrompt(tabname, id) { + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - tabs.appendChild(searchDiv); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); - - var applyFilter = function() { - var searchTerm = search.value.toLowerCase(); - - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { - var searchOnly = elem.querySelector('.search_only'); - - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); - - var visible = text.indexOf(searchTerm) != -1; - - if (searchOnly && searchTerm.length < 4) { - visible = false; - } - - elem.style.display = visible ? "" : "none"; - }); - - applySort(); - }; - - var applySort = function() { - var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; - sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); - var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; - - if (sortKeyStore == sort.dataset.sortkey) { - return; + if (!activePromptTextarea[tabname]) { + activePromptTextarea[tabname] = textarea; } - sort.dataset.sortkey = sortKeyStore; - cards.forEach(function(card) { - card.originalParentElement = card.parentElement; + textarea.addEventListener("focus", function() { + activePromptTextarea[tabname] = textarea; }); - var sortedCards = Array.from(cards); - sortedCards.sort(function(cardA, cardB) { - var a = cardA.dataset[sortKey]; - var b = cardB.dataset[sortKey]; - if (!isNaN(a) && !isNaN(b)) { - return parseInt(a) - parseInt(b); + } + + this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { + var tab_id = elem.getAttribute("id"); + + var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); + var searchDiv = gradioApp().QuerySelector("#" + tab_id + "_extra_search"); + console.log("HERE:", tab_id + "_extra_search", searchDiv); + var search = searchDiv.value; + var sort = gradioApp().getElementById(tabname + '_extra_sort'); + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); + var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); + var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); + tabs.appendChild(searchDiv); + tabs.appendChild(sort); + tabs.appendChild(sortOrder); + tabs.appendChild(refresh); + var applyFilter = function() { + var searchTerm = search.value.toLowerCase(); + + gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { + var searchOnly = elem.querySelector('.search_only'); + + var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase() }).join(" "); + + var visible = text.indexOf(searchTerm) != -1; + + if (searchOnly && searchTerm.length < 4) { + visible = false; + } + + elem.style.display = visible ? "" : "none"; + }); + + applySort(); + }; + + var applySort = function() { + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + + var reverse = sortOrder.classList.contains("sortReverse"); + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); + var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; + + if (sortKeyStore == sort.dataset.sortkey) { + return; } + sort.dataset.sortkey = sortKeyStore; + + cards.forEach(function(card) { + card.originalParentElement = card.parentElement; + }); + var sortedCards = Array.from(cards); + sortedCards.sort(function(cardA, cardB) { + var a = cardA.dataset[sortKey]; + var b = cardB.dataset[sortKey]; + if (!isNaN(a) && !isNaN(b)) { + return parseInt(a) - parseInt(b); + } + + return (a < b ? -1 : (a > b ? 1 : 0)); + }); + if (reverse) { + sortedCards.reverse(); + } + cards.forEach(function(card) { + card.remove(); + }); + sortedCards.forEach(function(card) { + card.originalParentElement.appendChild(card); + }); + }; + + search.addEventListener("input", applyFilter); + sortOrder.addEventListener("click", function() { + sortOrder.classList.toggle("sortReverse"); + applySort(); + }); + applyFilter(); + + extraNetworksApplySort[tab_id] = applySort; + extraNetworksApplyFilter[tab_id] = applyFilter; - return (a < b ? -1 : (a > b ? 1 : 0)); - }); - if (reverse) { - sortedCards.reverse(); - } - cards.forEach(function(card) { - card.remove(); - }); - sortedCards.forEach(function(card) { - card.originalParentElement.appendChild(card); - }); - }; - - search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); + registerPrompt(tab_id, tab_id + "_prompt"); + registerPrompt(tab_id, tab_id + "_neg_prompt"); }); - applyFilter(); - extraNetworksApplySort[tabname] = applySort; - extraNetworksApplyFilter[tabname] = applyFilter; + + } function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) { @@ -136,12 +158,12 @@ function clearSearch(tabname) { function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); - clearSearch(tabname); + //clearSearch(tabname); } function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); - clearSearch(tabname); + //clearSearch(tabname); } function applyExtraNetworkFilter(tabname) { @@ -159,23 +181,6 @@ var activePromptTextarea = {}; function setupExtraNetworks() { setupExtraNetworksForTab('txt2img'); setupExtraNetworksForTab('img2img'); - - function registerPrompt(tabname, id) { - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (!activePromptTextarea[tabname]) { - activePromptTextarea[tabname] = textarea; - } - - textarea.addEventListener("focus", function() { - activePromptTextarea[tabname] = textarea; - }); - } - - registerPrompt('txt2img', 'txt2img_prompt'); - registerPrompt('txt2img', 'txt2img_neg_prompt'); - registerPrompt('img2img', 'img2img_prompt'); - registerPrompt('img2img', 'img2img_neg_prompt'); } onUiLoaded(setupExtraNetworks); @@ -262,6 +267,106 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } +function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { + /** + * Processes `onclick` events when user clicks on files in tree. + * + * @param event The generated event. + * @param btn The clicked `action-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + var par = btn.parentElement; + var search_id = tabname + "_" + tab_id + "_extra_search"; + var type = par.getAttribute("data-tree-entry-type"); + var path = par.getAttribute("data-path"); +} + +function extraNetworksTreeProcessDirectoryClick(event, btn) { + /** + * Processes `onclick` events when user clicks on directories in tree. + * + * Here is how the tree reacts to clicks for various states: + * unselected unopened directory: Diretory is selected and expanded. + * unselected opened directory: Directory is selected. + * selected opened directory: Directory is collapsed and deselected. + * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. + * + * @param event The generated event. + * @param btn The clicked `action-list-item` button. + */ + var ul = btn.nextElementSibling; + // This is the actual target that the user clicked on within the target button. + // We use this to detect if the chevron was clicked. + var true_targ = event.target; + + function _expand_or_collapse(_ul, _btn) { + // Expands
      if it is collapsed, collapses otherwise. Updates button attributes. + if (_ul.hasAttribute("data-hidden")) { + _ul.removeAttribute("data-hidden"); + _btn.setAttribute("expanded", "true"); + } else { + _ul.setAttribute("data-hidden", ""); + _btn.setAttribute("expanded", "false"); + } + } + + function _remove_selected_from_all() { + // Removes the `selected` attribute from all buttons. + var sels = document.querySelectorAll("button.action-list-content"); + [...sels].forEach(el => { + el.removeAttribute("selected"); + }) + } + + function _select_button(_btn) { + // Removes `selected` attribute from all buttons then adds to passed button. + _remove_selected_from_all(); + _btn.setAttribute("selected", ""); + } + + // If user clicks on the chevron, then we do not select the folder. + if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + _expand_or_collapse(ul, btn); + } else { + // User clicked anywhere else on the button. + if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + // If folder is select and open, collapse and deselect button. + _expand_or_collapse(ul, btn); + btn.removeAttribute("selected"); + } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + // If folder is open and not selected, then we don't collapse; just select. + // NOTE: Double inversion sucks but it is the clearest way to show the branching here. + _expand_or_collapse(ul, btn); + _select_button(btn); + } else { + // All other cases, just select the button. + _select_button(btn); + } + + } +} + +function extraNetworksTreeOnClick(event, tabname, tab_id) { + /** + * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * + * Determines whether the clicked button in the tree is for a file entry or a directory + * then calls the appropriate function. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + var btn = event.currentTarget; + var par = btn.parentElement; + if (par.getAttribute("data-tree-entry-type") === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + } else { + extraNetworksTreeProcessDirectoryClick(event, btn); + } +} + function extraNetworksFolderClick(event, tabs_id) { // If folder is open but not selected, we don't want to collapse it. Instead // we override the removal of the "open" attribute so that the folder is @@ -434,3 +539,10 @@ window.addEventListener("keydown", function(event) { closePopup(); } }); + +function testprint(e) { + console.log(e); +} + +const testinput = gradioApp().querySelector("#txt2img_lora_extra_search"); +testinput.addEventListener("input", testprint); \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 093ac7b4f..9cf5b57fb 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,6 +19,90 @@ extra_pages = [] allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] +tree_tpl = ( + "" + "
        " + "{content}" + "
      " +) + +tree_ul_tpl = ( + "
        " + "{content}" + "
      " +) + +tree_li_dir_tpl = ( + "
    • " + "{content}" + "
    • " +) +tree_li_file_tpl = ( + "
    • " + "{content}" + "
    • " +) + +tree_btn_dir_tpl = ( + "" +) + +tree_btn_file_action_buttons_tpl = ( + "
      " + "
      " + "
      " + "
      " + "
      " + "
      " +) + +tree_btn_file_tpl = ( + "" + "" +) + + @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -160,6 +244,7 @@ class ExtraNetworksPage: self.extra_networks_pane_template = shared.html("extra-networks-pane.html") self.card_page_template = shared.html("extra-networks-card.html") self.card_page_minimal_template = shared.html("extra-networks-card-minimal.html") + self.tree_button_template = shared.html("extra-networks-tree-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -279,7 +364,9 @@ class ExtraNetworksPage: "search_terms": search_terms_html, "sort_keys": sort_keys, "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", - "tabname": quote_js(tabname), + "tabname": tabname, + "tab_id": self.id_page, + } if template: @@ -306,55 +393,81 @@ class ExtraNetworksPage: if not tree: return res - file_template = "
    • {card}
    • " - dir_template = ( - "
      " - "" - "{folder_name}" - "" - "
        {content}
      " - "
      " - ) - def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> str: """Recursively builds HTML for a tree.""" _res = "" if not data: - return "
    • DIRECTORY IS EMPTY
    • " + return ( + "
      " + "Directory is empty" + "
      " + ) for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) - _res += file_template.format(**{"card": item_html}) - else: - _res += dir_template.format( + _action_buttons = tree_btn_file_action_buttons_tpl.format( **{ - "attributes": "", - "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "path": quote_js(k), + "filename": quote_js(v.item["name"]), + "tabname": quote_js(tabname), + "tab_id": quote_js(self.id_page), } ) + _btn = tree_btn_file_tpl.format( + **{ + "label": v.item["name"], + "filter": v.item["search_terms"], + "tabname": tabname, + "tab_id": self.id_page, + "buttons": _action_buttons, + } + ) + _li = tree_li_file_tpl.format( + **{ + "hash": v.item["shorthash"], + "path": k, + "type": "file", + #"content": _btn, + "content": self.create_item_html(tabname, v.item, self.tree_button_template), + } + ) + _res += _li + #item_html = self.create_item_html(tabname, v.item, self.card_page_minimal_template) + #_res += file_template.format(**{"card": item_html}) + else: + _btn = tree_btn_dir_tpl.format( + **{ + "label": os.path.basename(k), + "tabname": tabname, + "tab_id": self.id_page, + } + ) + _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) + _res += _li return _res # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - res += "
        " - res += dir_template.format( + btn = tree_btn_dir_tpl.format( **{ - "attributes": "open" if v else "open", + "label": os.path.basename(k), "tabname": tabname, - "folder_name": os.path.basename(k), - "data_path": k, - "content": _build_tree(v), + "tab_id": self.id_page, } ) - res += "
      " - res += "
    " - return res + ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) + li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) + res += li + + return tree_tpl.format( + **{ + "content": res, + "tabname": tabname, + "tab_id": self.id_page, + } + ) def create_card_view_html(self, tabname): res = "" @@ -375,7 +488,7 @@ class ExtraNetworksPage: tree_view_html = self.create_tree_view_html(tabname) card_view_html = self.create_card_view_html(tabname) - network_type_id = self.name.replace(" ", "_") + network_type_id = self.id_page return self.extra_networks_pane_template.format( **{ @@ -506,7 +619,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", inputs=[], outputs=[], ) @@ -517,13 +630,11 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - edit_search = gr.Textbox('', show_label=False, elem_id=f"{tabname}_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) tab_controls = [ - edit_search, dropdown_sort, button_sortorder, button_refresh, diff --git a/style.css b/style.css index aaafaa9d7..8aa41088d 100644 --- a/style.css +++ b/style.css @@ -955,15 +955,15 @@ footer { color: white; } -.extra-network-pane .copy-path-button:before { +.extra-network-pane .copy-path-button::before { content: "⎘"; } -.extra-network-pane .metadata-button:before{ +.extra-network-pane .metadata-button::before{ content: "🛈"; } -.extra-network-pane .edit-button:before{ +.extra-network-pane .edit-button::before{ content: "🛠"; } @@ -1188,102 +1188,253 @@ body.resizing .resize-handle { border-left: 1px dashed var(--border-color-primary); } -.extra-network-pane .card-minimal { - display: inline-flex; - flex-grow: 1; - position: relative; +/* ========================= */ +.extra-network-pane { + display: flex; +} + +.extra-network-pane .extra-network-cards { + display: block; +} + +.extra-network-pane .extra-network-tree { + display: block; + font-size: 1rem; + min-width: 25%; + border: 1px solid var(--block-border-color); overflow: hidden; +} + +.extra-network-tree .action-list--tree { cursor: pointer; - font-size: 1rem; - font-weight: bold; - line-break: anywhere; -} - -/* Pushes buttons to right */ -.extra-network-pane .card-minimal .name { - flex-grow: 1; -} - -.folder-container { - margin-left: 1.5em !important; -} - -.file-item, -.folder-item, -.folder-item-summary { - padding-left: 0.05rem; - cursor: pointer; + -webkit-user-select: none; + -moz-user-select: none; + -ms-user-select: none; user-select: none; + margin: 0; + padding: 0; +} + +/* Remove auto indentation from tree. Will be overridden later. */ +.extra-network-tree .action-list--subgroup { + margin: 0 !important; + padding: 0 !important; + box-shadow: 0.6rem 0 0 var(--body-background-fill) inset, + 0.8rem 0 0 var(--neutral-800) inset; +} + +/* Set indentation for each depth of tree. */ +.extra-network-tree .action-list--subgroup > .action-list-item { + margin-left: 0.4rem !important; + padding-left: 0.4rem !important; +} + +/* Styles for tree
      elements. */ +.extra-network-tree .action-list { + +} + +/* Styles for tree
    • elements. */ +.extra-network-tree .action-list-item { + list-style: none; + position: relative; + background-color: transparent; +} + +/* Directory
        */ +.extra-network-tree .action-list-content[expanded=false]+.action-list--subgroup { + height: 0; + overflow: hidden; + visibility: hidden; + opacity: 0; +} + +.extra-network-tree .action-list-content[expanded=true]+.action-list--subgroup { + height: auto; + overflow: visible; + visibility: visible; + opacity: 1; +} + +/* File
      • */ +.extra-network-tree .action-list-item--subitem { +} + +/*
      • containing
          */ +.extra-network-tree .action-list-item--has-subitem { +} + +/* BUTTON ELEMENTS */ +/* \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 97a97d61f..56c9dc4c3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -267,7 +267,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { * Processes `onclick` events when user clicks on files in tree. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -288,7 +288,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * * @param event The generated event. - * @param btn The clicked `action-list-item` button. + * @param btn The clicked `tree-list-item` button. * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ @@ -310,7 +310,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.action-list-content"); + var sels = document.querySelectorAll("button.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -331,7 +331,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // If user clicks on the chevron, then we do not select the folder. - if (true_targ.matches(".action-list-item-action--leading, .action-list-item-action-chevron")) { + if (true_targ.matches(".tree-list-item-action--leading, .tree-list-item-action-chevron")) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. @@ -356,7 +356,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function extraNetworksTreeOnClick(event, tabname, tab_id) { /** - * Handles `onclick` events for buttons within an `extra-network-tree .action-list--tree`. + * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 9cf5b57fb..a49c6c1c5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -20,28 +20,28 @@ allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] tree_tpl = ( - " \ No newline at end of file diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html new file mode 100644 index 000000000..4d29b1be2 --- /dev/null +++ b/html/extra-networks-tree.html @@ -0,0 +1,42 @@ +
          +
          + +
          + +
          +
          + +
          +
          + +
          +
          +
          + {tree} +
          +
          \ No newline at end of file diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 56c9dc4c3..cf98452a2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -37,15 +37,20 @@ function setupExtraNetworksForTab(tabname) { return; // `continue` doesn't work in `forEach` loops. This is equivalent. } - var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var sort = gradioApp().getElementById(tabname + '_extra_sort'); - var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); - var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container'); - var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt'); - tabs.appendChild(sort); - tabs.appendChild(sortOrder); - tabs.appendChild(refresh); + var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); + if (!sort) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); + if (!sort_dir) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } + + var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); + if (!refresh) { + return; // `continue` doesn't work in `forEach` loops. This is equivalent. + } var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -72,8 +77,8 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sortOrder.classList.contains("sortReverse"); - var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var reverse = sort_dir.dataset.sortdir == "Descending"; + var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; @@ -107,10 +112,7 @@ function setupExtraNetworksForTab(tabname) { }; search.addEventListener("input", applyFilter); - sortOrder.addEventListener("click", function() { - sortOrder.classList.toggle("sortReverse"); - applySort(); - }); + applySort(); applyFilter(); extraNetworksApplySort[tab_id] = applySort; @@ -274,7 +276,7 @@ function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { var par = btn.parentElement; var search_id = tabname + "_" + tab_id + "_extra_search"; var type = par.getAttribute("data-tree-entry-type"); - var path = par.getAttribute("data-path"); + var path = btn.getAttribute("data-path"); } function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { @@ -310,7 +312,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _remove_selected_from_all() { // Removes the `selected` attribute from all buttons. - var sels = document.querySelectorAll("button.tree-list-content"); + var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { el.removeAttribute("selected"); }); @@ -345,11 +347,11 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // NOTE: Double inversion sucks but it is the clearest way to show the branching here. _expand_or_collapse(ul, btn); _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } else { // All other cases, just select the button. _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.parentElement.getAttribute("data-path")); + _update_search(tabname, tab_id, btn.getAttribute("data-path")); } } } @@ -374,6 +376,48 @@ function extraNetworksTreeOnClick(event, tabname, tab_id) { } } +function extraNetworksTreeSortOnClick(event, tabname, tab_id) { + var curr_mode = event.currentTarget.dataset.sortmode; + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var sort_dir = el_sort_dir.dataset.sortdir; + if (curr_mode == "path") { + event.currentTarget.dataset.sortmode = "name"; + event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by filename"); + } else if (curr_mode == "name") { + event.currentTarget.dataset.sortmode = "date_created"; + event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date created"); + } else if (curr_mode == "date_created") { + event.currentTarget.dataset.sortmode = "date_modified"; + event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by date modified"); + } else { + event.currentTarget.dataset.sortmode = "path"; + event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; + event.currentTarget.setAttribute("title", "Sort by path"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { + var curr_dir = event.currentTarget.getAttribute("data-sortdir"); + if (curr_dir == "Ascending") { + event.currentTarget.dataset.sortdir = "Descending"; + event.currentTarget.setAttribute("title", "Sort descending"); + } else { + event.currentTarget.dataset.sortdir = "Ascending"; + event.currentTarget.setAttribute("title", "Sort ascending"); + } + applyExtraNetworkSort(tabname + "_" + tab_id); +} + +function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { + console.log("refresh clicked"); + var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); + btn_refresh_internal.dispatchEvent(new Event("click")); +} + var globalPopup = null; var globalPopupInner = null; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a49c6c1c5..4ba2bea1a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -19,50 +19,6 @@ extra_pages = [] allowed_dirs = set() default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"] -tree_tpl = ( - "" - "
            " - "{content}" - "
          " -) - -tree_ul_tpl = ( - "
            " - "{content}" - "
          " -) - -tree_li_dir_tpl = ( - "
        • " - "{content}" - "
        • " -) -tree_li_file_tpl = ( - "
        • " - "{content}" - "
        • " -) - -action_list_item_action_leading = ( - "" - "" - "" -) - @functools.cache def allowed_preview_extensions_with_extra(extra_extensions=None): return set(default_allowed_preview_extensions) | set(extra_extensions or []) @@ -201,9 +157,13 @@ class ExtraNetworksPage: self.title = title self.name = title.lower() self.id_page = self.name.replace(" ", "_") - self.extra_networks_pane_template = shared.html("extra-networks-pane.html") - self.card_page_template = shared.html("extra-networks-card.html") - self.tree_button_template = shared.html("extra-networks-tree-button.html") + self.pane_tpl = shared.html("extra-networks-pane.html") + self.tree_tpl = shared.html("extra-networks-tree.html") + self.card_tpl = shared.html("extra-networks-card.html") + self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") + self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") + self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") + self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} @@ -268,12 +228,8 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - print("HERE") - print("TABNAME:", tabname) - print("PROMPT:", item["prompt"]) - print("NEG_PROMPT:", item.get("negative_prompt", "")) - print("ALLOW_NEG:", self.allow_negative_prompt) - onclick_js_tpl = "cardClicked('{tabname}', '{prompt}', '{neg_prompt}', '{allow_neg}');" + # Don't quote prompt/neg_prompt since they are stored as js strings already. + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -284,15 +240,23 @@ class ExtraNetworksPage: ) onclick = html.escape(onclick) - - copy_path_button = f"
          " - - metadata_button = "" + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) + btn_metadata = "" metadata = item.get("metadata") if metadata: - metadata_button = f"" - - edit_button = f"
          " + btn_metadata = self.btn_metadata_tpl.format( + **{ + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) + btn_edit_item = self.btn_edit_item_tpl.format( + **{ + "tabname": tabname, + "page_id": self.id_page, + "name": html.escape(item["name"]), + } + ) local_path = "" filename = item.get("filename", "") @@ -334,11 +298,11 @@ class ExtraNetworksPage: args = { "background_image": background_image, "card_clicked": onclick, - "copy_path_button": copy_path_button, + "copy_path_button": btn_copy_path, "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), - "edit_button": edit_button, + "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), - "metadata_button": metadata_button, + "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', @@ -355,6 +319,57 @@ class ExtraNetworksPage: else: return args + def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + if not content: + return None + + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-dir", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": "", + "data_path": dir_path, + "data_hash": "", + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗀", + "action_list_item_label": os.path.basename(dir_path), + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": "", + } + ) + ul = f"
            {content}
          " + return f"
        • {btn + ul}
        • " + + def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str: + item_html_args = self.create_item_html(tabname, item) + action_buttons = "".join( + [ + item_html_args["copy_path_button"], + item_html_args["metadata_button"], + item_html_args["edit_button"], + ] + ) + action_buttons = f"
          {action_buttons}
          " + btn = self.btn_tree_tpl.format( + **{ + "search_terms": "", + "subclass": "tree-list-content-file", + "tabname": tabname, + "tab_id": self.id_page, + "onclick_extra": item_html_args["card_clicked"], + "data_path": item_name, + "data_hash": item["shorthash"], + "action_list_item_action_leading": "", + "action_list_item_visual_leading": "🗎", + "action_list_item_label": item["name"], + "action_list_item_visual_trailing": "", + "action_list_item_action_trailing": action_buttons, + } + ) + return f"
        • {btn}
        • " + def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. @@ -385,57 +400,9 @@ class ExtraNetworksPage: for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])): if isinstance(v, (ExtraNetworksItem,)): - _item_html_args = self.create_item_html(tabname, v.item) - _action_buttons = "".join( - [ - _item_html_args["copy_path_button"], - _item_html_args["metadata_button"], - _item_html_args["edit_button"], - ] - ) - _action_buttons = f"
          {_action_buttons}
          " - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-file", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": _item_html_args["card_clicked"], - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗎", - "action_list_item_label": v.item["name"], - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": _action_buttons, - } - ) - - _li = tree_li_file_tpl.format( - **{ - "hash": v.item["shorthash"], - "path": k, - "type": "file", - "content": _btn, - } - ) - _file_li.append(_li) + _file_li.append(self.create_tree_file_item_html(tabname, k, v.item)) else: - _btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - _ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - _li = tree_li_dir_tpl.format(**{"content": _btn + _ul, "path": k}) - _dir_li.append(_li) + _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v))) # Directories should always be displayed before files. return "".join(_dir_li) + "".join(_file_li) @@ -443,31 +410,15 @@ class ExtraNetworksPage: # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): # If root is empty, append the "disabled" attribute to the template details tag. - btn = self.tree_button_template.format( - **{ - "search_terms": "", - "subclass": "tree-list-content-dir", - "tabname": tabname, - "tab_id": self.id_page, - "onclick_extra": "", - "action_list_item_action_leading": action_list_item_action_leading, - "action_list_item_visual_leading": "🗀", - "action_list_item_label": os.path.basename(k), - "action_list_item_visual_trailing": "", - "action_list_item_action_trailing": "", - } - ) - subtree = _build_tree(v) - if subtree: - ul = tree_ul_tpl.format(**{"content": _build_tree(v)}) - li = tree_li_dir_tpl.format(**{"content": btn + ul, "path": k}) - res += li + item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v)) + if item_html: + res += item_html - return tree_tpl.format( + return self.tree_tpl.format( **{ - "content": res, "tabname": tabname, "tab_id": self.id_page, + "tree": f"
            {res}
          " } ) @@ -475,8 +426,7 @@ class ExtraNetworksPage: res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): - print("HEEEERRE:", item) - res += self.create_item_html(tabname, item, self.card_page_template) + res += self.create_item_html(tabname, item, self.card_tpl) if res == "": dirs = "".join([f"
        • {x}
        • " for x in self.allowed_directories_for_previews()]) @@ -493,7 +443,7 @@ class ExtraNetworksPage: card_view_html = self.create_card_view_html(tabname) network_type_id = self.id_page - return self.extra_networks_pane_template.format( + return self.pane_tpl.format( **{ "tabname": tabname, "network_type_id": network_type_id, @@ -612,6 +562,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] + button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) + for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): @@ -633,51 +585,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) - dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - - tab_controls = [ - dropdown_sort, - button_sortorder, - button_refresh, - ] - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - for tab in unrelated_tabs: - tab.select( - fn=lambda: [gr.update(visible=False) for _ in tab_controls], - _js=f"function(){{ extraNetworksUnrelatedTabSelected('{tabname}'); }}", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - for page, tab in zip(ui.stored_extra_pages, related_tabs): - allow_prompt = "true" if page.allow_prompt else "false" - allow_negative_prompt = "true" if page.allow_negative_prompt else "false" - - jscode = ( - "extraNetworksTabSelected(" - f"'{tabname}', " - f"'{tabname}_{page.id_page}_prompts', " - f"'{allow_prompt}', " - f"'{allow_negative_prompt}'" - ");" - ) - - tab.select( - fn=lambda: [gr.update(visible=True) for _ in tab_controls], - _js="function(){ " + jscode + " }", - inputs=[], - outputs=tab_controls, - show_progress=False, - ) - - dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }") - def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] @@ -693,6 +603,8 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): return ui.pages_contents interface.load(fn=pages_html, inputs=[], outputs=ui.pages) + # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick() + # button is unused and hidden at all times. Only used in order to fire this event. button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/style.css b/style.css index 2dafe97fd..085732486 100644 --- a/style.css +++ b/style.css @@ -1196,17 +1196,33 @@ body.resizing .resize-handle { overflow: hidden; } -.extra-network-tree .tree-list--tree { - cursor: pointer; - -webkit-user-select: none; - -moz-user-select: none; - -ms-user-select: none; - user-select: none; - margin: 0; +.extra-network-tree .tree-list { + margin: 0 0.25rem; padding: 0; - margin-left: 0.25rem; } +.extra-network-tree .tree-list .tree-list-controls { + position: relative; + display: grid; + width: 100%; + padding: 0 !important; + margin-top: 0 !important; + margin-bottom: 0 !important; + font-size: 1rem; + text-align: left; + user-select: none; + background-color: transparent; + border: none; + transition: background 33.333ms linear; + grid-template-rows: min-content; + grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; + grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-gap: 0.1rem; + align-items: start; +} + +.extra-network-tree .tree-list--tree {} + /* Remove auto indentation from tree. Will be overridden later. */ .extra-network-tree .tree-list--subgroup { margin: 0 !important; @@ -1221,9 +1237,6 @@ body.resizing .resize-handle { padding-left: 0.4rem !important; } -/* Styles for tree
            elements. */ -.extra-network-tree .tree-list {} - /* Styles for tree
          • elements. */ .extra-network-tree .tree-list-item { list-style: none; @@ -1288,26 +1301,182 @@ body.resizing .resize-handle { padding-top: 0.5rem !important; } -.dark .extra-network-tree button.tree-list-content:hover { +.dark .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-800); } -.dark .extra-network-tree button.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-700); } -.extra-network-tree button.tree-list-content:hover { +.extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree button.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[selected] { background-color: var(--neutral-300); } +/* ==== CHEVRON ICON ACTIONS ==== */ +/* Define the animation for the arrow when it is clicked. */ +.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { + -ms-transform: rotate(135deg); + -webkit-transform: rotate(135deg); + transform: rotate(135deg); + transition: transform 0.2s; +} + +.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { + -ms-transform: rotate(225deg); + -webkit-transform: rotate(225deg); + transform: rotate(225deg); + transition: transform 0.2s; +} + +.tree-list-item-action-chevron { + display: inline-flex; + /* Uses box shadow to generate a pseudo chevron `>` icon. */ + padding: 0.3rem; + box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; + transform: rotate(135deg); +} + +/* ==== SEARCH INPUT ACTIONS ==== */ +/* Add icon to left side of */ +.extra-network-tree .tree-list-controls .tree-list-search::before { + content: "🔎︎"; + position: absolute; + margin: 0.5rem; + font-size: 1rem; + color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-controls .tree-list-search { + display: inline-flex; + grid-area: tree-list-controls-col-0; + position: relative; + margin: 0.5rem; +} + +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { + border: 1px solid var(--button-secondary-border-color); + border-radius: 0.5rem; + color: var(--button-secondary-text-color); + background-color: transparent; + width: 100%; + padding-left: 2rem; + line-height: 1rem; +} + +/* clear button (x on right side) styling */ +.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { + -webkit-appearance: none; + appearance: none; + cursor: pointer; + height: 1rem; + width: 1rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +/* ==== SORT ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort { + grid-area: tree-list-controls-col-1; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== SORT DIRECTION ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-sort-dir { + grid-area: tree-list-controls-col-2; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { + height: 1.5rem; + width: 1.5rem; + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { + mask-image: url('data:image/svg+xml,'); +} + +/* ==== REFRESH ICON ACTIONS ==== */ +.extra-network-tree .tree-list-controls .tree-list-refresh { + grid-area: tree-list-controls-col-3; + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.extra-network-tree .tree-list-refresh-icon:active { + -ms-transform: rotate(180deg); + -webkit-transform: rotate(180deg); + transform: rotate(180deg); + transition: transform 0.2s; +} + +/* ==== TREE GRID CONFIG ==== */ + /* Text for button. */ .extra-network-tree .tree-list-item-label { position: relative; @@ -1332,6 +1501,7 @@ body.resizing .resize-handle { align-items: right; } + /* Icon for button when it is before label. */ .extra-network-tree .tree-list-item-visual--leading { grid-area: leading-visual; @@ -1348,7 +1518,7 @@ body.resizing .resize-handle { /* Dropdown arrow for button. */ .extra-network-tree .tree-list-item-action--leading { - margin-right: 0.2rem; + margin-right: 0.5rem; margin-left: 0.2rem; } @@ -1356,30 +1526,6 @@ body.resizing .resize-handle { visibility: hidden; } -/* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { - -ms-transform: rotate(135deg); - -webkit-transform: rotate(135deg); - transform: rotate(135deg); - transition: transform 0.2s; -} - -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { - -ms-transform: rotate(225deg); - -webkit-transform: rotate(225deg); - transform: rotate(225deg); - transition: transform 0.2s; -} - -.tree-list-item-action-chevron { - display: inline-flex; - /* Uses box shadow to generate a pseudo chevron `>` icon. */ - padding: 0.3rem; - box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset; - transform: rotate(135deg); -} - - .extra-network-tree .tree-list-item-action--leading { grid-area: leading-action; } @@ -1399,41 +1545,3 @@ body.resizing .resize-handle { .extra-network-tree .tree-list-content:hover .button-row { visibility: visible; } - -/* Add icon to left side of */ -.extra-network-tree .tree-list-search::before { - content: "🔎︎"; - position: absolute; - margin: 0.5rem; - font-size: 1rem; - color: var(--input-placeholder-color); -} - -.extra-network-tree .tree-list-search { - position: relative; - margin: 0.5rem; -} - -.extra-network-tree .tree-list-search .tree-list-search-text { - border: 1px solid var(--button-secondary-border-color); - border-radius: 0.5rem; - color: var(--button-secondary-text-color); - background-color: transparent; - width: 100%; - padding-left: 2rem; - line-height: 1rem; -} - -/* clear button (x on right side) styling */ -.extra-network-tree .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { - -webkit-appearance: none; - appearance: none; - cursor: pointer; - height: 1rem; - width: 1rem; - mask-image: url('data:image/svg+xml,'); - mask-repeat: no-repeat; - mask-position: center center; - mask-size: 100%; - background-color: var(--input-placeholder-color); -} \ No newline at end of file From 1fdc18e6a01eb40889d46fab40f21aa138d64b01 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Mon, 15 Jan 2024 18:01:13 -0500 Subject: [PATCH 255/311] Run linting --- modules/ui_extra_networks.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4ba2bea1a..06fa22afa 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -13,7 +13,6 @@ import html from fastapi.exceptions import HTTPException from modules.infotext_utils import image_from_url_text -from modules.ui_components import ToolButton extra_pages = [] allowed_dirs = set() @@ -225,7 +224,6 @@ class ExtraNetworksPage: width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' background_image = f'' if preview else '' - onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. @@ -239,7 +237,7 @@ class ExtraNetworksPage: } ) onclick = html.escape(onclick) - + btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]}) btn_metadata = "" metadata = item.get("metadata") @@ -551,8 +549,6 @@ def pages_in_preferred_order(pages): return sorted(pages, key=lambda x: tab_scores[x.name]) def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): - from modules.ui import switch_values_symbol - ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -588,6 +584,9 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + for tab in unrelated_tabs: + tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False) + def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] From c1e04c63b3d2a5102b4cf6deddea337c8a964c53 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:18:20 +0900 Subject: [PATCH 256/311] callback postprocess_image_after_composite --- modules/processing.py | 5 +++++ modules/scripts.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe3..2c3365a02 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1029,6 +1029,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = apply_overlay(image, p.paste_to, overlay_image) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image_after_composite(p, pp) + image = pp.image + if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) diff --git a/modules/scripts.py b/modules/scripts.py index cf938ebb9..060069cf3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -262,6 +262,15 @@ class Script: pass + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been generated. + Same as postprocess_image but after inpaint_full_res composite + So that it operates on the full image instead of the inpaint_full_res crop region. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -856,6 +865,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) + def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image_after_composite(p, pp, *script_args) + except Exception: + errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True) + def before_component(self, component, **kwargs): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): try: From 14b9762bcab41fe0f7411498d9d3ffdc69ea1dda Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 17:08:07 +0900 Subject: [PATCH 257/311] immediately stop on second interrupt Revert "immediately stop on second interrupt" This reverts commit ab409072a1f2a9c911a63aee98a6b42081803cdc. immediately stop on second interrupt --- modules/ui_toprow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 1abc9117b..fbe705be1 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -107,8 +107,9 @@ class Toprow: ) def interrupt_function(): - if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current: shared.state.stop_generating() + gr.Info("Generation will stop after finishing this image, click again to stop immediately.") else: shared.state.interrupt() From a75dfe1c0de1564f4607a6f4ed7f4a7c00ef18a0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:03:48 +0100 Subject: [PATCH 258/311] - expand fields to include model name and hash - write these in the CSV log file - ensure old log files are updated w.r.t delimiter count --- modules/ui_common.py | 47 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c29..cbb2495d6 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -36,6 +36,29 @@ def plaintext_to_html(text, classname=None): return f"

            {content}

            " if classname else f"

            {content}

            " +def update_logfile(logfile_path, fields): + import csv + + with open(logfile_path, "r", encoding="utf8", newline="") as file: + reader = csv.reader(file) + rows = list(reader) + + # blank file: leave it as is + if not rows: + return + + rows[0] = fields + + # append new fields to each row as empty values + for row in rows[1:]: + while len(row) < len(fields): + row.append("") + + with open(logfile_path, "w", encoding="utf8", newline="") as file: + writer = csv.writer(file) + writer.writerows(rows) + + def save_files(js_data, images, do_make_zip, index): import csv filenames = [] @@ -64,11 +87,31 @@ def save_files(js_data, images, do_make_zip, index): os.makedirs(shared.opts.outdir_save, exist_ok=True) + fields = [ + "prompt", + "seed", + "width", + "height", + "sampler", + "cfgs", + "steps", + "filename", + "negative_prompt", + "sd_model_name", + "sd_model_hash", + ] + logfile_path = os.path.join(shared.opts.outdir_save, "log.csv") + + # NOTE: ensure csv integrity when fields are added by + # updating headers and padding with delimeters where needed + if os.path.exists(logfile_path): + update_logfile(logfile_path, fields) + with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + writer.writerow(fields) for image_index, filedata in enumerate(images, start_index): image = image_from_url_text(filedata) @@ -86,7 +129,7 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 315e40a49c32438551ed6b66138acdf664ecdbc8 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Tue, 16 Jan 2024 19:11:28 +0100 Subject: [PATCH 259/311] reuse variable for log file path --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index cbb2495d6..1db72092d 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -107,7 +107,7 @@ def save_files(js_data, images, do_make_zip, index): if os.path.exists(logfile_path): update_logfile(logfile_path, fields) - with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + with open(logfile_path, "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 writer = csv.writer(file) if at_start: From 4f9626703345ced77935e6bbb06de0b4522d53b7 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 13:35:01 -0500 Subject: [PATCH 260/311] Finish cleanup. --- html/extra-networks-card-minimal.html | 4 - html/extra-networks-card.html | 10 +- html/extra-networks-edit-item-button.html | 2 +- html/extra-networks-metadata-button.html | 2 +- html/extra-networks-pane.html | 6 +- html/extra-networks-tree-button.html | 3 +- html/extra-networks-tree.html | 14 +- javascript/extraNetworks.js | 168 +++++++++++---------- modules/ui_extra_networks.py | 162 +++++++++++++++----- modules/ui_extra_networks_user_metadata.py | 2 +- style.css | 81 +++++++--- 11 files changed, 287 insertions(+), 167 deletions(-) delete mode 100644 html/extra-networks-card-minimal.html diff --git a/html/extra-networks-card-minimal.html b/html/extra-networks-card-minimal.html deleted file mode 100644 index d66df7dfb..000000000 --- a/html/extra-networks-card-minimal.html +++ /dev/null @@ -1,4 +0,0 @@ -
            - {name} - {copy_path_button}{metadata_button}{edit_button} -
            diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index ca683dc47..f1d959a67 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,9 +1,9 @@ -
            +
            {background_image}
            {copy_path_button}{metadata_button}{edit_button}
            -
            -
            {search_terms}
            - {name} - {description} +
            +
            {search_terms}
            + {name} + {description}
            diff --git a/html/extra-networks-edit-item-button.html b/html/extra-networks-edit-item-button.html index 7d2677d9f..0fe43082a 100644 --- a/html/extra-networks-edit-item-button.html +++ b/html/extra-networks-edit-item-button.html @@ -1,4 +1,4 @@
            + onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
            \ No newline at end of file diff --git a/html/extra-networks-metadata-button.html b/html/extra-networks-metadata-button.html index ad6d6f416..285b5b3b6 100644 --- a/html/extra-networks-metadata-button.html +++ b/html/extra-networks-metadata-button.html @@ -1,4 +1,4 @@ \ No newline at end of file diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 20cf66867..bf46ca163 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,8 @@ -
            -
            +
            +
            {tree_html}
            -
            +
            {items_html}
            \ No newline at end of file diff --git a/html/extra-networks-tree-button.html b/html/extra-networks-tree-button.html index 20a9b0b86..9dc2e2a40 100644 --- a/html/extra-networks-tree-button.html +++ b/html/extra-networks-tree-button.html @@ -1,8 +1,7 @@
            diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 4d29b1be2..23f6af105 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -2,36 +2,36 @@
            diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index cf98452a2..a3f003bf4 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,25 +31,15 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var tab_id = elem.getAttribute("id"); - var search = gradioApp().querySelector("#" + tab_id + "_extra_search"); - if (!search) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } + var extra_networks_tabname = elem.id; + var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); - var sort = gradioApp().querySelector("#" + tab_id + "_extra_sort"); - if (!sort) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var sort_dir = gradioApp().querySelector("#" + tab_id + "_extra_sort_dir"); - if (!sort_dir) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. - } - - var refresh = gradioApp().querySelector("#" + tab_id + "_extra_refresh"); - if (!refresh) { - return; // `continue` doesn't work in `forEach` loops. This is equivalent. + // If any of the buttons above don't exist, we want to skip this iteration of the loop. + if (!search || !sort_mode || !sort_dir || !refresh) { + return; // `return` is equivalent of `continue` but for forEach loops. } var applyFilter = function() { @@ -78,14 +68,14 @@ function setupExtraNetworksForTab(tabname) { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); var reverse = sort_dir.dataset.sortdir == "Descending"; - var sortKey = sort.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; + var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length; - if (sortKeyStore == sort.dataset.sortkey) { + if (sortKeyStore == sort_mode.dataset.sortkey) { return; } - sort.dataset.sortkey = sortKeyStore; + sort_mode.dataset.sortkey = sortKeyStore; cards.forEach(function(card) { card.originalParentElement = card.parentElement; @@ -115,8 +105,8 @@ function setupExtraNetworksForTab(tabname) { applySort(); applyFilter(); - extraNetworksApplySort[tab_id] = applySort; - extraNetworksApplyFilter[tab_id] = applyFilter; + extraNetworksApplySort[extra_networks_tabname] = applySort; + extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -148,14 +138,6 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp } } -function clearSearch(tabname) { - // Clear search box. - var tab_id = tabname + "_extra_search"; - var searchTextarea = gradioApp().querySelector("#" + tab_id + ' > label > textarea'); - searchTextarea.value = ""; - updateInput(searchTextarea); -} - function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate) extraNetworksMovePromptToTab(tabname, '', false, false); @@ -264,22 +246,20 @@ function saveCardPreview(event, tabname, filename) { event.preventDefault(); } -function extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on files in tree. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ - var par = btn.parentElement; - var search_id = tabname + "_" + tab_id + "_extra_search"; - var type = par.getAttribute("data-tree-entry-type"); - var path = btn.getAttribute("data-path"); + // NOTE: Currently unused. + return; } -function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { +function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) { /** * Processes `onclick` events when user clicks on directories in tree. * @@ -289,10 +269,10 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { * selected opened directory: Directory is collapsed and deselected. * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged. * - * @param event The generated event. - * @param btn The clicked `tree-list-item` button. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param btn The clicked `tree-list-item` button. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var ul = btn.nextElementSibling; // This is the actual target that the user clicked on within the target button. @@ -301,12 +281,12 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { function _expand_or_collapse(_ul, _btn) { // Expands
              if it is collapsed, collapses otherwise. Updates button attributes. - if (_ul.hasAttribute("data-hidden")) { - _ul.removeAttribute("data-hidden"); - _btn.setAttribute("expanded", "true"); + if (_ul.hasAttribute("hidden")) { + _ul.removeAttribute("hidden"); + _btn.dataset.expanded = ""; } else { - _ul.setAttribute("data-hidden", ""); - _btn.setAttribute("expanded", "false"); + _ul.setAttribute("hidden", ""); + delete _btn.dataset.expanded; } } @@ -314,19 +294,19 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { // Removes the `selected` attribute from all buttons. var sels = document.querySelectorAll("div.tree-list-content"); [...sels].forEach(el => { - el.removeAttribute("selected"); + delete el.dataset.selected; }); } function _select_button(_btn) { - // Removes `selected` attribute from all buttons then adds to passed button. + // Removes `data-selected` attribute from all buttons then adds to passed button. _remove_selected_from_all(); - _btn.setAttribute("selected", ""); + _btn.dataset.selected = ""; } - function _update_search(_tabname, _tab_id, _search_text) { + function _update_search(_tabname, _extra_networks_tabname, _search_text) { // Update search input with select button's path. - var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_search"); + var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search"); search_input_elem.value = _search_text; updateInput(search_input_elem); } @@ -337,48 +317,58 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id) { _expand_or_collapse(ul, btn); } else { // User clicked anywhere else on the button. - if (btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden")) { + if ("selected" in btn.dataset && !(ul.hasAttribute("hidden"))) { // If folder is select and open, collapse and deselect button. _expand_or_collapse(ul, btn); - btn.removeAttribute("selected"); - _update_search(tabname, tab_id, ""); - } else if (!(!btn.hasAttribute("selected") && !ul.hasAttribute("data-hidden"))) { + delete btn.dataset.selected; + _update_search(tabname, extra_networks_tabname, ""); + } else if (!(!("selected" in btn.dataset) && !(ul.hasAttribute("hidden")))) { // If folder is open and not selected, then we don't collapse; just select. // NOTE: Double inversion sucks but it is the clearest way to show the branching here. _expand_or_collapse(ul, btn); - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } else { // All other cases, just select the button. - _select_button(btn, tabname, tab_id); - _update_search(tabname, tab_id, btn.getAttribute("data-path")); + _select_button(btn, tabname, extra_networks_tabname); + _update_search(tabname, extra_networks_tabname, btn.dataset.path); } } } -function extraNetworksTreeOnClick(event, tabname, tab_id) { +function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`. * * Determines whether the clicked button in the tree is for a file entry or a directory * then calls the appropriate function. * - * @param event The generated event. - * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. - * @param tab_id The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. */ var btn = event.currentTarget; var par = btn.parentElement; - if (par.getAttribute("data-tree-entry-type") === "file") { - extraNetworksTreeProcessFileClick(event, btn, tabname, tab_id); + if (par.dataset.treeEntryType === "file") { + extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname); } else { - extraNetworksTreeProcessDirectoryClick(event, btn, tabname, tab_id); + extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname); } } -function extraNetworksTreeSortOnClick(event, tabname, tab_id) { +function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Mode button. + * + * Modifies the data attributes of the Sort Mode button to cycle between + * various sorting modes. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var curr_mode = event.currentTarget.dataset.sortmode; - var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + tab_id + "_extra_sort_dir"); + var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir"); var sort_dir = el_sort_dir.dataset.sortdir; if (curr_mode == "path") { event.currentTarget.dataset.sortmode = "name"; @@ -397,23 +387,43 @@ function extraNetworksTreeSortOnClick(event, tabname, tab_id) { event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640"; event.currentTarget.setAttribute("title", "Sort by path"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, tab_id) { - var curr_dir = event.currentTarget.getAttribute("data-sortdir"); - if (curr_dir == "Ascending") { +function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Sort Direction button. + * + * Modifies the data attributes of the Sort Direction button to cycle between + * ascending and descending sort directions. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + if (event.currentTarget.dataset.sortdir == "Ascending") { event.currentTarget.dataset.sortdir = "Descending"; event.currentTarget.setAttribute("title", "Sort descending"); } else { event.currentTarget.dataset.sortdir = "Ascending"; event.currentTarget.setAttribute("title", "Sort ascending"); } - applyExtraNetworkSort(tabname + "_" + tab_id); + applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, tab_id) { - console.log("refresh clicked"); +function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Refresh Page button. + * + * In order to actually call the python functions in `ui_extra_networks.py` + * to refresh the page, we created an empty gradio button in that file with an + * event handler that refreshes the page. So what this function here does + * is it manually raises a `click` event on that button. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal"); btn_refresh_internal.dispatchEvent(new Event("click")); } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 06fa22afa..a03207b2b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -155,7 +155,14 @@ class ExtraNetworksPage: def __init__(self, title): self.title = title self.name = title.lower() - self.id_page = self.name.replace(" ", "_") + # This is the actual name of the extra networks tab (not txt2img/img2img). + self.extra_networks_tabname = self.name.replace(" ", "_") + self.allow_prompt = True + self.allow_negative_prompt = False + self.metadata = {} + self.items = {} + self.lister = util.MassFileLister() + # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") self.tree_tpl = shared.html("extra-networks-tree.html") self.card_tpl = shared.html("extra-networks-card.html") @@ -163,11 +170,6 @@ class ExtraNetworksPage: self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html") self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html") - self.allow_prompt = True - self.allow_negative_prompt = False - self.metadata = {} - self.items = {} - self.lister = util.MassFileLister() def refresh(self): pass @@ -202,15 +204,17 @@ class ExtraNetworksPage: item: dict, template: Optional[str] = None, ) -> Union[str, dict]: - """Generates HTML for a single ExtraNetworks Item + """Generates HTML for a single ExtraNetworks Item. Args: tabname: The name of the active tab. item: Dictionary containing item information. + template: Optional template string to use. Returns: - HTML string generated for this item. - Can be empty if the item is not meant to be shown. + If a template is passed: HTML string generated for this item. + Can be empty if the item is not meant to be shown. + If no template is passed: A dictionary containing the generated item's attributes. """ metadata = item.get("metadata") if metadata: @@ -244,14 +248,14 @@ class ExtraNetworksPage: if metadata: btn_metadata = self.btn_metadata_tpl.format( **{ - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) btn_edit_item = self.btn_edit_item_tpl.format( **{ "tabname": tabname, - "page_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "name": html.escape(item["name"]), } ) @@ -307,9 +311,9 @@ class ExtraNetworksPage: "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", + "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, } if template: @@ -317,7 +321,32 @@ class ExtraNetworksPage: else: return args - def create_tree_dir_item_html(self, tabname: str, dir_path: str, content: Optional[str] = None) -> Optional[str]: + def create_tree_dir_item_html( + self, + tabname: str, + dir_path: str, + content: Optional[str] = None, + ) -> Optional[str]: + """Generates HTML for a directory item in the tree. + + The generated HTML is of the format: + ```html +
            • +
              +
                + {content} +
              +
            • + ``` + + Args: + tabname: The name of the active tab. + dir_path: Path to the directory for this item. + content: Optional HTML string that will be wrapped by this
                . + + Returns: + HTML formatted string. + """ if not content: return None @@ -326,7 +355,7 @@ class ExtraNetworksPage: "search_terms": "", "subclass": "tree-list-content-dir", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "onclick_extra": "", "data_path": dir_path, "data_hash": "", @@ -337,10 +366,32 @@ class ExtraNetworksPage: "action_list_item_action_trailing": "", } ) - ul = f"
                  {content}
                " - return f"
              • {btn + ul}
              • " + ul = f"" + return ( + "
              • " + f"{btn + ul}" + "
              • " + ) - def create_tree_file_item_html(self, tabname: str, item_name: str, item: dict) -> str: + def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str: + """Generates HTML for a file item in the tree. + + The generated HTML is of the format: + ```html +
              • + +
                +
              • + ``` + + Args: + tabname: The name of the active tab. + file_path: The path to the file for this item. + item: Dictionary containing the item information. + + Returns: + HTML formatted string. + """ item_html_args = self.create_item_html(tabname, item) action_buttons = "".join( [ @@ -355,9 +406,9 @@ class ExtraNetworksPage: "search_terms": "", "subclass": "tree-list-content-file", "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "onclick_extra": item_html_args["card_clicked"], - "data_path": item_name, + "data_path": file_path, "data_hash": item["shorthash"], "action_list_item_action_leading": "", "action_list_item_visual_leading": "🗎", @@ -366,11 +417,17 @@ class ExtraNetworksPage: "action_list_item_action_trailing": action_buttons, } ) - return f"
              • {btn}
              • " + return ( + "
              • " + f"{btn}" + "
              • " + ) def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. + The generated HTML uses `extra-networks-tree.html` as a template. + Args: tabname: The name of the active tab. @@ -379,7 +436,7 @@ class ExtraNetworksPage: """ res = "" - # Generate HTML for the tree. + # Setup the tree dictionary. roots = self.allowed_directories_for_previews() tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()} tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items) @@ -388,7 +445,17 @@ class ExtraNetworksPage: return res def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]: - """Recursively builds HTML for a tree.""" + """Recursively builds HTML for a tree. + + Args: + data: Dictionary representing a directory tree. Can be NoneType. + Data keys should be absolute paths from the root and values + should be subdirectory trees or an ExtraNetworksItem. + + Returns: + If data is not None: HTML string + Else: None + """ if not data: return None @@ -402,25 +469,36 @@ class ExtraNetworksPage: else: _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v))) - # Directories should always be displayed before files. + # Directories should always be displayed before files so we order them here. return "".join(_dir_li) + "".join(_file_li) # Add each root directory to the tree. for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])): - # If root is empty, append the "disabled" attribute to the template details tag. item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v)) - if item_html: + # Only add non-empty entries to the tree. + if item_html is not None: res += item_html return self.tree_tpl.format( **{ "tabname": tabname, - "tab_id": self.id_page, + "extra_networks_tabname": self.extra_networks_tabname, "tree": f"
                  {res}
                " } ) - def create_card_view_html(self, tabname): + def create_card_view_html(self, tabname: str) -> str: + """Generates HTML for the network Card View section for a tab. + + This HTML goes into the `extra-networks-pane.html`
                with + `id='{tabname}_{extra_networks_tabname}_cards`. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ res = "" self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): @@ -433,20 +511,26 @@ class ExtraNetworksPage: return res def create_html(self, tabname): + """Generates an HTML string for the current pane. + + The generated HTML uses `extra-networks-pane.html` as a template. + + Args: + tabname: The name of the active tab. + + Returns: + HTML formatted string. + """ self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} - tree_view_html = self.create_tree_view_html(tabname) - card_view_html = self.create_card_view_html(tabname) - network_type_id = self.id_page - return self.pane_tpl.format( **{ "tabname": tabname, - "network_type_id": network_type_id, - "tree_html": tree_view_html, - "items_html": card_view_html, + "extra_networks_tabname": self.extra_networks_tabname, + "tree_html": self.create_tree_view_html(tabname), + "items_html": self.create_card_view_html(tabname), } ) @@ -561,16 +645,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) for page in ui.stored_extra_pages: - with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab: - with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]): + with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab: + with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]): pass - elem_id = f"{tabname}_{page.id_page}_cards_html" + elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) page_elem.change( fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.id_page}_extra_search); return []}}", + _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}", inputs=[], outputs=[], ) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 989a649b7..2ca937fd1 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -14,7 +14,7 @@ class UserMetadataEditor: self.ui = ui self.tabname = tabname self.page = page - self.id_part = f"{self.tabname}_{self.page.id_page}_edit_user_metadata" + self.id_part = f"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata" self.box = None diff --git a/style.css b/style.css index 085732486..1090e4368 100644 --- a/style.css +++ b/style.css @@ -879,13 +879,6 @@ footer { margin-bottom: 1em; } -.extra-network-pane{ - height: calc(100vh - 24rem); - overflow: clip scroll; - resize: vertical; - min-height: 52rem; -} - .extra-networks > div.tab-nav{ min-height: 3.4rem; } @@ -1182,23 +1175,63 @@ body.resizing .resize-handle { /* ========================= */ .extra-network-pane { display: flex; -} - -.extra-network-pane .extra-network-cards { - display: block; + height: calc(100vh - 24rem); + resize: vertical; + min-height: 52rem; } .extra-network-pane .extra-network-tree { - display: block; + flex: 1; + flex-direction: column; + display: flex; font-size: 1rem; - min-width: 25%; border: 1px solid var(--block-border-color); +} + +.extra-network-pane .extra-network-cards { + flex: 3; + overflow: clip auto !important; + border: 1px solid var(--block-border-color); +} + +.extra-network-pane .extra-network-tree .tree-list { + flex: 1; + display: flex; + flex-direction: column; + padding: 0; + width: 100%; overflow: hidden; } -.extra-network-tree .tree-list { - margin: 0 0.25rem; - padding: 0; +.extra-network-pane .extra-network-tree .tree-list .tree-list-container { + flex: 1; + overflow: clip auto !important; + width: 100%; +} + + +.extra-network-pane .extra-network-cards::-webkit-scrollbar, +.extra-network-pane .tree-list-container::-webkit-scrollbar { + background-color: transparent; + width: 16px; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-track, +.extra-network-pane .tree-list-container::-webkit-scrollbar-track { + background-color: transparent; + background-clip: content-box; +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb, +.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb { + background-color: var(--border-color-primary); + border-radius: 16px; + border: 4px solid var(--background-fill-primary); +} + +.extra-network-pane .extra-network-cards::-webkit-scrollbar-button, +.extra-network-pane .tree-list-container::-webkit-scrollbar-button { + display: none; } .extra-network-tree .tree-list .tree-list-controls { @@ -1244,17 +1277,15 @@ body.resizing .resize-handle { background-color: transparent; } -/* Directory
                  visibility based on expanded attribute. */ -.extra-network-tree .tree-list-content[expanded=false]+.tree-list--subgroup { +/* Directory
                    visibility based on data-expanded attribute. */ +.extra-network-tree .tree-list-content+.tree-list--subgroup { height: 0; - overflow: hidden; visibility: hidden; opacity: 0; } -.extra-network-tree .tree-list-content[expanded=true]+.tree-list--subgroup { +.extra-network-tree .tree-list-content[data-expanded]+.tree-list--subgroup { height: auto; - overflow: visible; visibility: visible; opacity: 1; } @@ -1307,7 +1338,7 @@ body.resizing .resize-handle { background-color: var(--neutral-800); } -.dark .extra-network-tree div.tree-list-content[selected] { +.dark .extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-700); } @@ -1317,20 +1348,20 @@ body.resizing .resize-handle { background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[selected] { +.extra-network-tree div.tree-list-content[data-selected] { background-color: var(--neutral-300); } /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. */ -.extra-network-tree .tree-list-content-dir[expanded=false] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { -ms-transform: rotate(135deg); -webkit-transform: rotate(135deg); transform: rotate(135deg); transition: transform 0.2s; } -.extra-network-tree .tree-list-content-dir[expanded=true] .tree-list-item-action-chevron { +.extra-network-tree .tree-list-content-dir[data-expanded] .tree-list-item-action-chevron { -ms-transform: rotate(225deg); -webkit-transform: rotate(225deg); transform: rotate(225deg); From ccee26b0653b4f6778c107d68df52da27446abd2 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Tue, 16 Jan 2024 14:54:07 -0500 Subject: [PATCH 261/311] fix bugs --- javascript/extraNetworks.js | 28 ++++++++----------- modules/ui_extra_networks.py | 35 ++++++++++++------------ modules/ui_extra_networks_checkpoints.py | 3 +- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index a3f003bf4..caaa3fae0 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -31,11 +31,12 @@ function setupExtraNetworksForTab(tabname) { var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs'); this_tab.classList.add('extra-networks'); this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) { - var extra_networks_tabname = elem.id; - var search = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_search"); - var sort_mode = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort"); - var sort_dir = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_sort_dir"); - var refresh = gradioApp().querySelector("#" + extra_networks_tabname + "_extra_refresh"); + // tabname_full = {tabname}_{extra_networks_tabname} + var tabname_full = elem.id; + var search = gradioApp().querySelector("#" + tabname_full + "_extra_search"); + var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort"); + var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir"); + var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh"); // If any of the buttons above don't exist, we want to skip this iteration of the loop. if (!search || !sort_mode || !sort_dir || !refresh) { @@ -44,16 +45,13 @@ function setupExtraNetworksForTab(tabname) { var applyFilter = function() { var searchTerm = search.value.toLowerCase(); - gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) { var searchOnly = elem.querySelector('.search_only'); - var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) { return t.textContent.toLowerCase(); }).join(" "); var visible = text.indexOf(searchTerm) != -1; - if (searchOnly && searchTerm.length < 4) { visible = false; } @@ -66,7 +64,6 @@ function setupExtraNetworksForTab(tabname) { var applySort = function() { var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); - var reverse = sort_dir.dataset.sortdir == "Descending"; var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name"; sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1); @@ -104,9 +101,8 @@ function setupExtraNetworksForTab(tabname) { search.addEventListener("input", applyFilter); applySort(); applyFilter(); - - extraNetworksApplySort[extra_networks_tabname] = applySort; - extraNetworksApplyFilter[extra_networks_tabname] = applyFilter; + extraNetworksApplySort[tabname_full] = applySort; + extraNetworksApplyFilter[tabname_full] = applyFilter; }); registerPrompt(tabname, tabname + "_prompt"); @@ -147,12 +143,12 @@ function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt); } -function applyExtraNetworkFilter(tabname) { - setTimeout(extraNetworksApplyFilter[tabname], 1); +function applyExtraNetworkFilter(tabname_full) { + setTimeout(extraNetworksApplyFilter[tabname_full], 1); } -function applyExtraNetworkSort(tabname) { - setTimeout(extraNetworksApplySort[tabname], 1); +function applyExtraNetworkSort(tabname_full) { + setTimeout(extraNetworksApplySort[tabname_full], 1); } var extraNetworksApplyFilter = {}; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index a03207b2b..55cd1da27 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -237,7 +237,7 @@ class ExtraNetworksPage: "tabname": tabname, "prompt": item["prompt"], "neg_prompt": item.get("negative_prompt", ""), - "allow_neg": "true" if self.allow_negative_prompt else "false" + "allow_neg": str(self.allow_negative_prompt).lower(), } ) onclick = html.escape(onclick) @@ -291,7 +291,7 @@ class ExtraNetworksPage: search_terms_html += search_term_template.format( **{ "style": "display: none;", - "class": "search_terms" + (" search_only" if search_only else ""), + "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } ) @@ -307,7 +307,7 @@ class ExtraNetworksPage: "metadata_button": btn_metadata, "name": html.escape(item["name"]), "prompt": item.get("prompt", None), - "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', + "save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"), "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, @@ -369,7 +369,7 @@ class ExtraNetworksPage: ul = f"" return ( "
                  • " - f"{btn + ul}" + f"{btn}{ul}" "
                  • " ) @@ -561,7 +561,7 @@ class ExtraNetworksPage: Find a preview PNG for a given path (without extension) and call link_preview on it. """ - potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], []) + potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], []) for file in potential_files: if self.lister.exists(file): @@ -642,7 +642,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs = [] - button_refresh = gr.Button("Refresh", elem_id=tabname+"_extra_refresh_internal", visible=False) + button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False) for page in ui.stored_extra_pages: with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab: @@ -652,24 +652,25 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html" page_elem = gr.HTML('Loading...', elem_id=elem_id) ui.pages.append(page_elem) - page_elem.change( - fn=lambda: None, - _js=f"function(){{applyExtraNetworkFilter({tabname}_{page.extra_networks_tabname}_extra_search); return []}}", - inputs=[], - outputs=[], - ) - editor = page.create_user_metadata_editor(ui, tabname) editor.create_ui() ui.user_metadata_editors.append(editor) - related_tabs.append(tab) - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) - ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False) for tab in unrelated_tabs: - tab.select(fn=None, _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=[], show_progress=False) + tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False) + + for page, tab in zip(ui.stored_extra_pages, related_tabs): + jscode = ( + "function(){{" + f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()});" + f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');" + "}}" + ) + tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False) def create_html(): ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index e7976ba12..a8c336719 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -2,7 +2,6 @@ import html import os from modules import shared, ui_extra_networks, sd_models -from modules.ui_extra_networks import quote_js from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor @@ -31,7 +30,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "preview": self.find_preview(path), "description": self.find_description(path), "search_terms": search_terms, - "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', + "onclick": html.escape(f"return selectCheckpoint('{name}');"), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, From 0b83f4c26376c87505769a3687cb06e321b413a1 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 13 Jan 2024 18:38:05 +0900 Subject: [PATCH 262/311] reuse seed from infotexts --- modules/processing_scripts/seed.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 2d3cbb97f..7b727d17c 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -6,6 +6,7 @@ from modules import scripts, ui, errors from modules.infotext_utils import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton +from modules import infotext_utils class ScriptSeed(scripts.ScriptBuiltinUI): @@ -77,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI): p.seed_resize_from_h = seed_resize_from_h - def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed): """ Connects a 'reuse (sub)seed' button's click event so that it copies last used (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength @@ -85,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def copy_seed(gen_info_string: str, index): res = -1 - try: gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError: + infotext = gen_info.get('infotexts')[index] + gen_parameters = infotext_utils.parse_generation_parameters(infotext) + res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) + except Exception: if gen_info_string: - errors.report(f"Error parsing JSON generation info: {gen_info_string}") + errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True) return [res, gr.update()] From 6acd8e28fceccbbea0c09c91aed0eff5a3504315 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 14 Jan 2024 01:58:01 +0900 Subject: [PATCH 263/311] save_files info base on infotexts --- modules/ui_common.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index f17259c29..6a4465e93 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -40,8 +40,9 @@ def save_files(js_data, images, do_make_zip, index): import csv filenames = [] fullfns = [] + parsed_infotexts = [] - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: @@ -49,16 +50,14 @@ def save_files(js_data, images, do_make_zip, index): setattr(self, key, value) data = json.loads(js_data) - p = MyObject(data) + path = shared.opts.outdir_save save_to_dirs = shared.opts.use_save_to_dirs_for_ui extension: str = shared.opts.samples_format start_index = 0 - only_one = False if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - only_one = True images = [images[index]] start_index = index @@ -74,10 +73,12 @@ def save_files(js_data, images, do_make_zip, index): image = image_from_url_text(filedata) is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) p.batch_index = image_index-1 - fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parsed_infotexts.append(parameters) + fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) @@ -86,12 +87,12 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) # Make Zip if do_make_zip: - zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0] - namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True) + p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts] + namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True) zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") zip_filepath = os.path.join(path, f"{zip_filename}.zip") From d224fed0ce16e6f1507b0da29e7495ab2b0035e4 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 264/311] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf24..1049c6c3c 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate @@ -240,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From 45a51c07e2ce5a36dc3cddde9a666a4c7b61af82 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:21:58 +0900 Subject: [PATCH 265/311] parse_generation_parameters with no skip_fields --- modules/processing_scripts/seed.py | 2 +- modules/ui_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 7b727d17c..7a4c01598 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -88,7 +88,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: try: gen_info = json.loads(gen_info_string) infotext = gen_info.get('infotexts')[index] - gen_parameters = infotext_utils.parse_generation_parameters(infotext) + gen_parameters = infotext_utils.parse_generation_parameters(infotext, []) res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1)) except Exception: if gen_info_string: diff --git a/modules/ui_common.py b/modules/ui_common.py index 6a4465e93..6d7f3a675 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -76,7 +76,7 @@ def save_files(js_data, images, do_make_zip, index): p.batch_index = image_index-1 - parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index]) + parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], []) parsed_infotexts.append(parameters) fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) From 6916de5c0bd8df3835d450caa3327d1924db081c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:16:07 +0900 Subject: [PATCH 266/311] parse_generation_parameters skip_fields --- modules/infotext_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py index 9a02cdf24..1049c6c3c 100644 --- a/modules/infotext_utils.py +++ b/modules/infotext_utils.py @@ -230,7 +230,7 @@ def restore_old_hires_fix_params(res): res['Hires resize-2'] = height -def parse_generation_parameters(x: str): +def parse_generation_parameters(x: str, skip_fields: list[str] | None = None): """parses generation parameters string, the one you see in text field under the picture in UI: ``` girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate @@ -240,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model returns a dict with field values """ + if skip_fields is None: + skip_fields = shared.opts.infotext_skip_pasting res = {} @@ -356,8 +358,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model infotext_versions.backcompat(res) - skip = set(shared.opts.infotext_skip_pasting) - res = {k: v for k, v in res.items() if k not in skip} + for key in skip_fields: + res.pop(key, None) return res From e1dfd452c0447e729a7341c434b4aab6063aa654 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 16 Jan 2024 20:27:29 +0900 Subject: [PATCH 267/311] parse_generation_parameters with no skip_fields --- modules/txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/txt2img.py b/modules/txt2img.py index e617cb1c5..92a160d61 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -70,7 +70,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0] p.firstpass_image = infotext_utils.image_from_url_text(image_info) - parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index]) + parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], []) p.seed = parameters.get('Seed', -1) p.subseed = parameters.get('Variation seed', -1) From 2cf23099ebb81832a27c8016f14062885f5a9c98 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 18 Jan 2024 04:44:21 +0900 Subject: [PATCH 268/311] fix console total progress bar when using txt2img_upscale add p.txt2img_upscale as indicator --- modules/processing.py | 7 +++++-- modules/txt2img.py | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index dcc807fe3..547ab2f86 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1227,8 +1227,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter - - shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) + if getattr(self, 'txt2img_upscale', False): + total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count + else: + total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count + shared.total_tqdm.updateTotal(total_steps) state.job_count = state.job_count * 2 state.processing_has_refined_job_count = True diff --git a/modules/txt2img.py b/modules/txt2img.py index 92a160d61..4efcb4c3d 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -64,6 +64,7 @@ def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, g p.enable_hr = True p.batch_size = 1 p.n_iter = 1 + p.txt2img_upscale = True geninfo = json.loads(generation_info) From f25c81a74462554890ac7327a30629b332db1084 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Wed, 17 Jan 2024 22:38:51 -0500 Subject: [PATCH 269/311] Fix embeddings add/remove to/from prompt on click bugs. --- javascript/extraNetworks.js | 13 +++---------- modules/ui_extra_networks.py | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index caaa3fae0..1e2786ab0 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -169,8 +169,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; + var extraTextBeforeNet = opts.extra_networks_add_text_separator; if (m) { - var extraTextBeforeNet = opts.extra_networks_add_text_separator; var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; @@ -183,7 +183,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } return found; }); - if (foundAtPosition >= 0) { if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); @@ -193,13 +192,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } } } else { - newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) { - if (found == text) { - replaced = true; - return ""; - } - return found; - }); + newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, "g"), ""); + replaced = (newTextareaText != textarea.value); } if (replaced) { @@ -211,7 +205,6 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { } function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 55cd1da27..5dd4e443f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -236,7 +236,7 @@ class ExtraNetworksPage: **{ "tabname": tabname, "prompt": item["prompt"], - "neg_prompt": item.get("negative_prompt", ""), + "neg_prompt": item.get("negative_prompt", "''"), "allow_neg": str(self.allow_negative_prompt).lower(), } ) From 0181c1f76b97162c42401f1e6286ae73d8aa6033 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 19 Jan 2024 00:14:03 +0800 Subject: [PATCH 270/311] Fix nested manual cast --- modules/devices.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/devices.py b/modules/devices.py index 0321d12c6..370286297 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -165,6 +165,8 @@ def manual_cast_forward(target_dtype): @contextlib.contextmanager def manual_cast(target_dtype): for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + continue org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -175,7 +177,9 @@ def manual_cast(target_dtype): yield None finally: for module_type in patch_module_list: - module_type.forward = module_type.org_forward + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From 50e444fa1d1c779b0c9c915e44ed2d78b587f277 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:09 -0500 Subject: [PATCH 271/311] Fix missing important style. --- style.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 1090e4368..57c52354a 100644 --- a/style.css +++ b/style.css @@ -28,7 +28,7 @@ div.gradio-container{ } .hidden{ - display: none; + display: none !important; } .compact{ From 69f4f148dce0868b748b700f96942d4036e848c9 Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Thu, 18 Jan 2024 12:13:33 -0500 Subject: [PATCH 272/311] Fix various bugs including refresh bug. --- javascript/extraNetworks.js | 7 +++++-- modules/ui_extra_networks.py | 31 ++++++++++++++++--------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1e2786ab0..3029afec8 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -55,8 +55,11 @@ function setupExtraNetworksForTab(tabname) { if (searchOnly && searchTerm.length < 4) { visible = false; } - - elem.style.display = visible ? "" : "none"; + if (visible) { + elem.classList.remove("hidden"); + } else { + elem.classList.add("hidden"); + } }); applySort(); diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 5dd4e443f..656e7f181 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -216,22 +216,17 @@ class ExtraNetworksPage: Can be empty if the item is not meant to be shown. If no template is passed: A dictionary containing the generated item's attributes. """ - metadata = item.get("metadata") - if metadata: - self.metadata[item["name"]] = metadata - - if "user_metadata" not in item: - self.read_user_metadata(item) - preview = item.get("preview", None) - height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' - width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' + style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' + style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;" + card_style = style_height + style_width + style_font_size background_image = f'' if preview else '' onclick = item.get("onclick", None) if onclick is None: # Don't quote prompt/neg_prompt since they are stored as js strings already. - onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, '{allow_neg}');" + onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});" onclick = onclick_js_tpl.format( **{ "tabname": tabname, @@ -286,11 +281,10 @@ class ExtraNetworksPage: ).strip() search_terms_html = "" - search_term_template = "{search_term}" + search_term_template = "" for search_term in item.get("search_terms", []): search_terms_html += search_term_template.format( **{ - "style": "display: none;", "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } @@ -301,7 +295,7 @@ class ExtraNetworksPage: "background_image": background_image, "card_clicked": onclick, "copy_path_button": btn_copy_path, - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""), "edit_button": btn_edit_item, "local_preview": quote_js(item["local_preview"]), "metadata_button": btn_metadata, @@ -311,7 +305,7 @@ class ExtraNetworksPage: "search_only": " search_only" if search_only else "", "search_terms": search_terms_html, "sort_keys": sort_keys, - "style": f"display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%", + "style": card_style, "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, } @@ -500,7 +494,6 @@ class ExtraNetworksPage: HTML formatted string. """ res = "" - self.items = {x["name"]: x for x in self.list_items()} for item in self.items.values(): res += self.create_item_html(tabname, item, self.card_tpl) @@ -524,6 +517,14 @@ class ExtraNetworksPage: self.lister.reset() self.metadata = {} self.items = {x["name"]: x for x in self.list_items()} + # Populate the instance metadata for each item. + for item in self.items.values(): + metadata = item.get("metadata") + if metadata: + self.metadata[item["name"]] = metadata + + if "user_metadata" not in item: + self.read_user_metadata(item) return self.pane_tpl.format( **{ From a97147bc8a43ade7c18bb755f0cfac111fc1a619 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:10:02 +0100 Subject: [PATCH 273/311] Add support for DAT upscaler models --- modules/dat_model.py | 81 +++++++++++++++++++++++++++++++++++++++ modules/shared_items.py | 5 +++ modules/shared_options.py | 3 ++ 3 files changed, 89 insertions(+) create mode 100644 modules/dat_model.py diff --git a/modules/dat_model.py b/modules/dat_model.py new file mode 100644 index 000000000..8637351c5 --- /dev/null +++ b/modules/dat_model.py @@ -0,0 +1,81 @@ +import os +import sys + +from modules import modelloader, devices +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model + +from icecream import ic + +class UpscalerDAT(Upscaler): + def __init__(self, user_path): + self.name = "DAT" + self.user_path = user_path + self.scalers = [] + super().__init__() + + for file in self.find_models(ext_filter=[".pt", ".pth"]): + name = modelloader.friendly_name(file) + scaler_data = UpscalerData(name, file, upscaler=self, scale=None) + self.scalers.append(scaler_data) + + for model in get_dat_models(self): + if model.name in opts.dat_enabled_models: + self.scalers.append(model) + + def do_upscale(self, img, selected_model): + try: + info = self.load_model(selected_model) + except Exception as e: + errors.report(f"Unable to load DAT model {path}", exc_info=True) + return img + + model_descriptor = modelloader.load_spandrel_model( + info.local_data_path, + device=self.device, + prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="DAT", + ) + return upscale_with_model( + model_descriptor, + img, + tile_size=opts.DAT_tile, + tile_overlap=opts.DAT_tile_overlap, + ) + + def load_model(self, path): + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") + + +def get_dat_models(scaler): + return [ + UpscalerData( + name="DAT x2", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", + scale=2, + upscaler=scaler, + ), + UpscalerData( + name="DAT x3", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", + scale=3, + upscaler=scaler, + ), + UpscalerData( + name="DAT x4", + path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", + scale=4, + upscaler=scaler, + ), + ] diff --git a/modules/shared_items.py b/modules/shared_items.py index 13fb2814f..88f636452 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -8,6 +8,11 @@ def realesrgan_models_names(): return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] +def dat_models_names(): + import modules.dat_model + return [x.name for x in modules.dat_model.get_dat_models(None)] + + def postprocessing_scripts(): import modules.scripts diff --git a/modules/shared_options.py b/modules/shared_options.py index 63488f4e7..48a206ce9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), + "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), + "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), })) From 2e7efe47b6c78a59f9f3d64c1d30f54751a31a56 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:39:14 +0100 Subject: [PATCH 274/311] Minor cleanup --- modules/dat_model.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/dat_model.py b/modules/dat_model.py index 8637351c5..495d5f493 100644 --- a/modules/dat_model.py +++ b/modules/dat_model.py @@ -1,12 +1,10 @@ import os -import sys -from modules import modelloader, devices +from modules import modelloader, errors from modules.shared import cmd_opts, opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -from icecream import ic class UpscalerDAT(Upscaler): def __init__(self, user_path): @@ -24,13 +22,13 @@ class UpscalerDAT(Upscaler): if model.name in opts.dat_enabled_models: self.scalers.append(model) - def do_upscale(self, img, selected_model): + def do_upscale(self, img, path): try: - info = self.load_model(selected_model) - except Exception as e: + info = self.load_model(path) + except Exception: errors.report(f"Unable to load DAT model {path}", exc_info=True) return img - + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, From 1ddb886a804dc69f542ebc71bdd7baec48f677b6 Mon Sep 17 00:00:00 2001 From: n0kovo Date: Fri, 19 Jan 2024 00:48:46 +0100 Subject: [PATCH 275/311] Fix wrong options value --- modules/shared_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared_options.py b/modules/shared_options.py index 48a206ce9..74a2a67f9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -97,7 +97,7 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), - "dat_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), + "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}), "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), From d31dc7a73996ee611a72ef641237606fc9eddca5 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 00:40:03 +0400 Subject: [PATCH 276/311] fix extras big batch crashes --- modules/postprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 7449b0dc5..f14882321 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, else: image_data = image_placeholder - shared.state.assign_current_image(image_data) - parameters, existing_pnginfo = images.read_info_from_image(image_data) if parameters: existing_pnginfo["parameters"] = parameters @@ -92,6 +90,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, pp.image.info = existing_pnginfo pp.image.info["postprocessing"] = infotext + shared.state.assign_current_image(pp.image) + if save_output: fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix) From 4dde96109bb9621bb5608f7f792dc569d692ecf5 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:38:34 -0800 Subject: [PATCH 277/311] Update zoom.js --- extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index dd6028c1e..df60c1a17 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -218,8 +218,8 @@ onUiLoaded(async() => { canvas_hotkey_fullscreen: "KeyS", canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", - canvas_hotkey_shrink_brush: "BracketLeft", - canvas_hotkey_grow_brush: "BracketRight", + canvas_hotkey_shrink_brush: "KeyQ", + canvas_hotkey_grow_brush: "KeyW", canvas_disabled_functions: [], canvas_show_tooltip: true, canvas_auto_expand: true, From bcdcc8be7d5fb923d64b7eb8a9a831f26f179d97 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:07 -0800 Subject: [PATCH 278/311] Update hotkey_config.py --- .../canvas-zoom-and-pan/scripts/hotkey_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 9d9accce8..6d9e34d8b 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -4,8 +4,8 @@ from modules import shared shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), - "canvas_hotkey_shrink_brush": shared.OptionInfo("[", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("]", "Enlarge the brush size"), + "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 30f31d23c6d212389aef035137c381dd79d44a26 Mon Sep 17 00:00:00 2001 From: WebDev <6970043+WebDev9000@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:39:33 -0800 Subject: [PATCH 279/311] Update hotkey_config.py --- extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 6d9e34d8b..89b7c31f2 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -5,7 +5,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), - "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), + "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"), "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), From 56676ff923497901d56fcbca1ac549095e71d72d Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 11:49:05 +0400 Subject: [PATCH 280/311] fix tab indexes reset after restart ui --- modules/ui.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index a716a0405..ebd33a856 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -532,7 +532,7 @@ def create_ui(): if category == "image": with gr.Tabs(elem_id="mode_img2img"): - img2img_selected_tab = gr.State(0) + img2img_selected_tab = gr.Number(value=0, visible=False) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) @@ -613,7 +613,7 @@ def create_ui(): elif category == "dimensions": with FormRow(): with gr.Column(elem_id="img2img_column_size", scale=4): - selected_scale_tab = gr.State(value=0) + selected_scale_tab = gr.Number(value=0, visible=False) with gr.Tabs(): with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to: diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 7a132ac22..ff22a1780 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -5,7 +5,7 @@ import modules.infotext_utils as parameters_copypaste def create_ui(): dummy_component = gr.Label(visible=False) - tab_index = gr.State(value=0) + tab_index = gr.Number(value=0, visible=False) with gr.Row(equal_height=False, variant='compact'): with gr.Column(variant='compact'): From 81126027f5226e7ee58e1a99194eb9ec7b8ec6e7 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:31:12 +0800 Subject: [PATCH 281/311] Avoid early disable --- modules/devices.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index 370286297..3bde16990 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -164,9 +164,11 @@ def manual_cast_forward(target_dtype): @contextlib.contextmanager def manual_cast(target_dtype): + applied = False for module_type in patch_module_list: if hasattr(module_type, "org_forward"): continue + applied = True org_forward = module_type.forward if module_type == torch.nn.MultiheadAttention and has_xpu(): module_type.forward = manual_cast_forward(torch.float32) @@ -176,6 +178,8 @@ def manual_cast(target_dtype): try: yield None finally: + if not applied: + return for module_type in patch_module_list: if hasattr(module_type, "org_forward"): module_type.forward = module_type.org_forward From 4a66d2fb228584bb38dc22db6a3e657561834c7a Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sat, 20 Jan 2024 16:33:59 +0800 Subject: [PATCH 282/311] Avoid exceptions to be silenced --- modules/devices.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 3bde16990..dfffaf24f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -178,12 +178,11 @@ def manual_cast(target_dtype): try: yield None finally: - if not applied: - return - for module_type in patch_module_list: - if hasattr(module_type, "org_forward"): - module_type.forward = module_type.org_forward - delattr(module_type, "org_forward") + if applied: + for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): From ed383eb5a0db78d0da60420bb464943e4cc9b847 Mon Sep 17 00:00:00 2001 From: Andray Date: Sat, 20 Jan 2024 15:37:49 +0400 Subject: [PATCH 283/311] keep postprocessing upscale selected tab after restart --- scripts/postprocessing_upscale.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index a57f9d4a4..e269682d0 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -15,7 +15,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): order = 1000 def ui(self): - selected_tab = gr.State(value=0) + selected_tab = gr.Number(value=0, visible=False) with gr.Column(): with FormRow(): From 2310cd66e5381fbe6b966894381c6ee7b762898f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 11:43:45 -0500 Subject: [PATCH 284/311] Add toggle button for tree view. Use default settings for sortmode and direction. --- html/extra-networks-pane.html | 55 ++++++++++++++++++++-- html/extra-networks-tree.html | 37 --------------- javascript/extraNetworks.js | 20 ++++++-- modules/ui_extra_networks.py | 7 +++ style.css | 88 +++++++++++++++++++++++------------ 5 files changed, 133 insertions(+), 74 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index bf46ca163..73dad2ab2 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,8 +1,55 @@
                    -
                    - {tree_html} +
                    + +
                    + +
                    +
                    + +
                    +
                    + +
                    +
                    + +
                    -
                    - {items_html} +
                    +
                    + {tree_html} +
                    +
                    + {items_html} +
                    \ No newline at end of file diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 23f6af105..39649e860 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -1,41 +1,4 @@
                    -
                    - -
                    - -
                    -
                    - -
                    -
                    - -
                    -
                    {tree}
                    diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 3029afec8..ce788328c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -348,7 +348,7 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) { } } -function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Mode button. * @@ -382,7 +382,7 @@ function extraNetworksTreeSortOnClick(event, tabname, extra_networks_tabname) { applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Sort Direction button. * @@ -403,7 +403,21 @@ function extraNetworksTreeSortDirOnClick(event, tabname, extra_networks_tabname) applyExtraNetworkSort(tabname + "_" + extra_networks_tabname); } -function extraNetworksTreeRefreshOnClick(event, tabname, extra_networks_tabname) { +function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) { + /** + * Handles `onclick` events for the Tree View button. + * + * Toggles the tree view in the extra networks pane. + * + * @param event The generated event. + * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc. + * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc. + */ + gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree").classList.toggle("hidden"); + event.currentTarget.classList.toggle("extra-network-control--enabled"); +} + +function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) { /** * Handles `onclick` events for the Refresh Page button. * diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 656e7f181..4c8a40744 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -526,10 +526,17 @@ class ExtraNetworksPage: if "user_metadata" not in item: self.read_user_metadata(item) + data_sortdir = shared.opts.extra_networks_card_order + data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() + data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + return self.pane_tpl.format( **{ "tabname": tabname, "extra_networks_tabname": self.extra_networks_tabname, + "data_sortmode": data_sortmode, + "data_sortkey": data_sortkey, + "data_sortdir": data_sortdir, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } diff --git a/style.css b/style.css index 57c52354a..f3fd1571b 100644 --- a/style.css +++ b/style.css @@ -1178,6 +1178,13 @@ body.resizing .resize-handle { height: calc(100vh - 24rem); resize: vertical; min-height: 52rem; + flex-direction: column; +} + +.extra-network-pane .extra-network-pane-content { + display: flex; + flex: 1; + flex-direction: row; } .extra-network-pane .extra-network-tree { @@ -1234,7 +1241,7 @@ body.resizing .resize-handle { display: none; } -.extra-network-tree .tree-list .tree-list-controls { +.extra-network-pane .extra-network-control { position: relative; display: grid; width: 100%; @@ -1248,8 +1255,7 @@ body.resizing .resize-handle { border: none; transition: background 33.333ms linear; grid-template-rows: min-content; - grid-template-areas: "tree-list-controls-col-0 tree-list-controls-col-1 tree-list-controls-col-2 tree-list-controls-col-3"; - grid-template-columns: minmax(0, auto) min-content min-content min-content; + grid-template-columns: minmax(0, auto) repeat(4, min-content); grid-gap: 0.1rem; align-items: start; } @@ -1342,16 +1348,16 @@ body.resizing .resize-handle { background-color: var(--neutral-700); } +.extra-network-tree div.tree-list-content[data-selected] { + background-color: var(--neutral-300); +} + .extra-network-tree div.tree-list-content:hover { -webkit-transition: all 0.05s ease-in-out; transition: all 0.05s ease-in-out; background-color: var(--neutral-200); } -.extra-network-tree div.tree-list-content[data-selected] { - background-color: var(--neutral-300); -} - /* ==== CHEVRON ICON ACTIONS ==== */ /* Define the animation for the arrow when it is clicked. */ .extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron { @@ -1378,7 +1384,7 @@ body.resizing .resize-handle { /* ==== SEARCH INPUT ACTIONS ==== */ /* Add icon to left side of */ -.extra-network-tree .tree-list-controls .tree-list-search::before { +.extra-network-pane .extra-network-control .extra-network-control--search::before { content: "🔎︎"; position: absolute; margin: 0.5rem; @@ -1386,14 +1392,12 @@ body.resizing .resize-handle { color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-controls .tree-list-search { +.extra-network-pane .extra-network-control .extra-network-control--search { display: inline-flex; - grid-area: tree-list-controls-col-0; position: relative; - margin: 0.5rem; } -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text { border: 1px solid var(--button-secondary-border-color); border-radius: 0.5rem; color: var(--button-secondary-text-color); @@ -1404,7 +1408,7 @@ body.resizing .resize-handle { } /* clear button (x on right side) styling */ -.extra-network-tree .tree-list-controls .tree-list-search .tree-list-search-text::-webkit-search-cancel-button { +.extra-network-pane .extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button { -webkit-appearance: none; appearance: none; cursor: pointer; @@ -1418,8 +1422,7 @@ body.resizing .resize-handle { } /* ==== SORT ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort { - grid-area: tree-list-controls-col-1; +.extra-network-pane .extra-network-control .extra-network-control--sort { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1427,7 +1430,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort .extra-network-control--sort-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1436,25 +1439,24 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort[data-sortmode="path"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="path"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="name"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="name"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_created"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_created"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort[data-sortmode="date_modified"] .tree-list-sort-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort[data-sortmode="date_modified"] .extra-network-control--sort-icon { mask-image: url('data:image/svg+xml,'); } /* ==== SORT DIRECTION ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-sort-dir { - grid-area: tree-list-controls-col-2; +.extra-network-pane .extra-network-control .extra-network-control--sort-dir { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1462,7 +1464,7 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-sort-dir .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir .extra-network-control--sort-dir-icon { height: 1.5rem; width: 1.5rem; mask-repeat: no-repeat; @@ -1471,17 +1473,16 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Ascending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Ascending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } -.extra-network-tree .tree-list-sort-dir[data-sortdir="Descending"] .tree-list-sort-dir-icon { +.extra-network-pane .extra-network-control .extra-network-control--sort-dir[data-sortdir="Descending"] .extra-network-control--sort-dir-icon { mask-image: url('data:image/svg+xml,'); } -/* ==== REFRESH ICON ACTIONS ==== */ -.extra-network-tree .tree-list-controls .tree-list-refresh { - grid-area: tree-list-controls-col-3; +/* ==== TREE VIEW ICON ACTIONS ==== */ +.extra-network-pane .extra-network-control .extra-network-control--tree-view { padding: 0.25rem; display: inline-flex; cursor: pointer; @@ -1489,7 +1490,34 @@ body.resizing .resize-handle { align-self: center; } -.extra-network-tree .tree-list-controls .tree-list-refresh .tree-list-refresh-icon { +.extra-network-pane .extra-network-control .extra-network-control--tree-view .extra-network-control--tree-view-icon { + height: 1.5rem; + width: 1.5rem; + mask-image: url('data:image/svg+xml,'); + mask-repeat: no-repeat; + mask-position: center center; + mask-size: 100%; + background-color: var(--input-placeholder-color); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-700); +} + +.dark .extra-network-pane .extra-network-control .extra-network-control--enabled { + background-color: var(--neutral-300); +} + +/* ==== REFRESH ICON ACTIONS ==== */ +.extra-network-pane .extra-network-control .extra-network-control--refresh { + padding: 0.25rem; + display: inline-flex; + cursor: pointer; + justify-self: center; + align-self: center; +} + +.extra-network-pane .extra-network-control .extra-network-control--refresh .extra-network-control--refresh-icon { height: 1.5rem; width: 1.5rem; mask-image: url('data:image/svg+xml,'); @@ -1499,7 +1527,7 @@ body.resizing .resize-handle { background-color: var(--input-placeholder-color); } -.extra-network-tree .tree-list-refresh-icon:active { +.extra-network-pane .extra-network-control .extra-network-control--refresh-icon:active { -ms-transform: rotate(180deg); -webkit-transform: rotate(180deg); transform: rotate(180deg); From b67a49441fc420f37c6bef1172a0b1ad5c42f30f Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sat, 20 Jan 2024 13:28:37 -0500 Subject: [PATCH 285/311] Add option in settings to enable/disable tree view by default. --- html/extra-networks-pane.html | 4 ++-- modules/shared_options.py | 1 + modules/ui_extra_networks.py | 7 +++++++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 73dad2ab2..9f5b3ecee 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -29,7 +29,7 @@
                    @@ -45,7 +45,7 @@
                    -
                    +
                    {tree_html}
                    diff --git a/modules/shared_options.py b/modules/shared_options.py index 63488f4e7..e0a6d977c 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -251,6 +251,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(), "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(), + "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c8a40744..80160b847 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -529,6 +529,11 @@ class ExtraNetworksPage: data_sortdir = shared.opts.extra_networks_card_order data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip() data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}" + tree_view_btn_extra_class = "" + tree_view_div_extra_class = "hidden" + if shared.opts.extra_networks_tree_view_default_enabled: + tree_view_btn_extra_class = "extra-network-control--enabled" + tree_view_div_extra_class = "" return self.pane_tpl.format( **{ @@ -537,6 +542,8 @@ class ExtraNetworksPage: "data_sortmode": data_sortmode, "data_sortkey": data_sortkey, "data_sortdir": data_sortdir, + "tree_view_btn_extra_class": tree_view_btn_extra_class, + "tree_view_div_extra_class": tree_view_div_extra_class, "tree_html": self.create_tree_view_html(tabname), "items_html": self.create_card_view_html(tabname), } From 25e8273d2f6481deca221b29d35093f6d0c9da6a Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 02:21:36 +0900 Subject: [PATCH 286/311] re-work multi --styles-file --styles-file change to append str --styles-file is [] then defaults to [styles.csv] --styles-file accepts paths or paths with wildcard "*" the first `--styles-file` entry is use as the default styles file path if filename a wildcard then the first matching file is used if no match is found, create a new "styles.csv" in the same dir as the first path when saving a new style it will be save in the default styles file when saving a existing style, it will be saved to file it belongs to order of the styles files in the styles dropdown can be controlled to a certain degree by the order of --styles-file --- modules/cmd_args.py | 2 +- modules/shared.py | 3 +- modules/styles.py | 79 +++++++++++++++++++------------------ modules/ui_prompt_styles.py | 9 +++-- 4 files changed, 49 insertions(+), 44 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e58059a1f..f1251b6c8 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -88,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[]) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/shared.py b/modules/shared.py index 636619391..ccdca4e70 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,3 +1,4 @@ +import os import sys import gradio as gr @@ -11,7 +12,7 @@ parser = shared_cmd_options.parser batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond parallel_processing_allowed = True -styles_filename = cmd_opts.styles_file +styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')] config_filename = cmd_opts.ui_settings_file hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} diff --git a/modules/styles.py b/modules/styles.py index 026c43001..9edcc7e44 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,16 +1,15 @@ +from pathlib import Path import csv -import fnmatch import os -import os.path import typing import shutil class PromptStyle(typing.NamedTuple): name: str - prompt: str - negative_prompt: str - path: str = None + prompt: str | None + negative_prompt: str | None + path: str | None = None def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt): class StyleDatabase: - def __init__(self, path: str): + def __init__(self, paths: list[str | Path]): self.no_style = PromptStyle("None", "", "", None) self.styles = {} - self.path = path + self.paths = paths + self.all_styles_files: list[Path] = [] - folder, file = os.path.split(self.path) - filename, _, ext = file.partition('*') - self.default_path = os.path.join(folder, filename + ext) + folder, file = os.path.split(self.paths[0]) + if '*' in file or '?' in file: + # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path + self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv'))) + self.paths.insert(0, self.default_path) + else: + self.default_path = Path(self.paths[0]) self.prompt_fields = [field for field in PromptStyle._fields if field != "path"] @@ -99,33 +103,31 @@ class StyleDatabase: """ self.styles.clear() - path, filename = os.path.split(self.path) + # scans for all styles files + all_styles_files = [] + for pattern in self.paths: + folder, file = os.path.split(pattern) + if '*' in file or '?' in file: + found_files = Path(folder).glob(file) + [all_styles_files.append(file) for file in found_files] + else: + # if os.path.exists(pattern): + all_styles_files.append(Path(pattern)) - if "*" in filename: - fileglob = filename.split("*")[0] + "*.csv" - filelist = [] - for file in os.listdir(path): - if fnmatch.fnmatch(file, fileglob): - filelist.append(file) - # Add a visible divider to the style list - half_len = round(len(file) / 2) - divider = f"{'-' * (20 - half_len)} {file.upper()}" - divider = f"{divider} {'-' * (40 - len(divider))}" - self.styles[divider] = PromptStyle( - f"{divider}", None, None, "do_not_save" - ) - # Add styles from this CSV file - self.load_from_csv(os.path.join(path, file)) - if len(filelist) == 0: - print(f"No styles found in {path} matching {fileglob}") - return - elif not os.path.exists(self.path): - print(f"Style database not found: {self.path}") - return - else: - self.load_from_csv(self.path) + # Remove any duplicate entries + seen = set() + self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))] - def load_from_csv(self, path: str): + for styles_file in self.all_styles_files: + if len(all_styles_files) > 1: + # add divider when more than styles file + # '---------------- STYLES ----------------' + divider = f' {styles_file.stem.upper()} '.center(40, '-') + self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save") + if styles_file.is_file(): + self.load_from_csv(styles_file) + + def load_from_csv(self, path: str | Path): with open(path, "r", encoding="utf-8-sig", newline="") as file: reader = csv.DictReader(file, skipinitialspace=True) for row in reader: @@ -137,7 +139,7 @@ class StyleDatabase: negative_prompt = row.get("negative_prompt", "") # Add style to database self.styles[row["name"]] = PromptStyle( - row["name"], prompt, negative_prompt, path + row["name"], prompt, negative_prompt, str(path) ) def get_style_paths(self) -> set: @@ -145,11 +147,11 @@ class StyleDatabase: # Update any styles without a path to the default path for style in list(self.styles.values()): if not style.path: - self.styles[style.name] = style._replace(path=self.default_path) + self.styles[style.name] = style._replace(path=str(self.default_path)) # Create a list of all distinct paths, including the default path style_paths = set() - style_paths.add(self.default_path) + style_paths.add(str(self.default_path)) for _, style in self.styles.items(): if style.path: style_paths.add(style.path) @@ -177,7 +179,6 @@ class StyleDatabase: def save_styles(self, path: str = None) -> None: # The path argument is deprecated, but kept for backwards compatibility - _ = path style_paths = self.get_style_paths() diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py index 0d74c23fa..d67e3f17e 100644 --- a/modules/ui_prompt_styles.py +++ b/modules/ui_prompt_styles.py @@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt): if not name: return gr.update(visible=False) - style = styles.PromptStyle(name, prompt, negative_prompt) + existing_style = shared.prompt_styles.styles.get(name) + path = existing_style.path if existing_style is not None else None + + style = styles.PromptStyle(name, prompt, negative_prompt, path) shared.prompt_styles.styles[style.name] = style - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return gr.update(visible=True) @@ -34,7 +37,7 @@ def delete_style(name): return shared.prompt_styles.styles.pop(name, None) - shared.prompt_styles.save_styles(shared.styles_filename) + shared.prompt_styles.save_styles() return '', '', '' From 8459015017fde8833a4159b8d56a32f92ebf4186 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:19:53 +0100 Subject: [PATCH 287/311] skip if headers haven't changed --- modules/ui_common.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 92b007192..2c67834c6 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -1,3 +1,4 @@ +import csv import dataclasses import json import html @@ -37,8 +38,6 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): - import csv - with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) @@ -47,6 +46,10 @@ def update_logfile(logfile_path, fields): if not rows: return + # file is already synced, do nothing + if len(rows[0]) == len(fields): + return + rows[0] = fields # append new fields to each row as empty values @@ -60,7 +63,6 @@ def update_logfile(logfile_path, fields): def save_files(js_data, images, do_make_zip, index): - import csv filenames = [] fullfns = [] parsed_infotexts = [] From f190b85182a89f873fa99d897ac20310047131ea Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 21:27:38 +0100 Subject: [PATCH 288/311] restore saving fields --- modules/ui_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index 2c67834c6..6eb740f16 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -132,7 +132,7 @@ def save_files(js_data, images, do_make_zip, index): filenames.append(os.path.basename(txt_fullfn)) fullfns.append(txt_fullfn) - writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt']]) + writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]]) # Make Zip if do_make_zip: From 4aa99f77abd439469f20e6e2941dd553031ed2d0 Mon Sep 17 00:00:00 2001 From: Arturo Albacete Date: Sat, 20 Jan 2024 22:04:53 +0100 Subject: [PATCH 289/311] add docstring --- modules/ui_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui_common.py b/modules/ui_common.py index 6eb740f16..29fe7d0e9 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -38,6 +38,7 @@ def plaintext_to_html(text, classname=None): def update_logfile(logfile_path, fields): + """Update a logfile from old format to new format to maintain CSV integrity.""" with open(logfile_path, "r", encoding="utf8", newline="") as file: reader = csv.reader(file) rows = list(reader) From e36827af3254f7bac9f8c78d6d56c709959b40b6 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 21 Jan 2024 07:20:52 +0900 Subject: [PATCH 290/311] improve get_crop_region --- modules/masking.py | 43 +++++++++---------------------------------- modules/processing.py | 2 +- 2 files changed, 10 insertions(+), 35 deletions(-) diff --git a/modules/masking.py b/modules/masking.py index be9f84c76..29a394527 100644 --- a/modules/masking.py +++ b/modules/masking.py @@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps def get_crop_region(mask, pad=0): """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. - For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" - - h, w = mask.shape - - crop_left = 0 - for i in range(w): - if not (mask[:, i] == 0).all(): - break - crop_left += 1 - - crop_right = 0 - for i in reversed(range(w)): - if not (mask[:, i] == 0).all(): - break - crop_right += 1 - - crop_top = 0 - for i in range(h): - if not (mask[i] == 0).all(): - break - crop_top += 1 - - crop_bottom = 0 - for i in reversed(range(h)): - if not (mask[i] == 0).all(): - break - crop_bottom += 1 - - return ( - int(max(crop_left-pad, 0)), - int(max(crop_top-pad, 0)), - int(min(w - crop_right + pad, w)), - int(min(h - crop_bottom + pad, h)) - ) + For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)""" + mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask) + box = mask_img.getbbox() + if box: + x1, y1, x2, y2 = box + else: # when no box is found + x1, y1 = mask_img.size + x2 = y2 = 0 + return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1]) def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height): diff --git a/modules/processing.py b/modules/processing.py index 6b6317951..72d8093ba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1562,7 +1562,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpaint_full_res: self.mask_for_overlay = image_mask mask = image_mask.convert('L') - crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) + crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) x1, y1, x2, y2 = crop_region From 2974b9cee94dc474ffbc9e9617d14c9aaf9e1e63 Mon Sep 17 00:00:00 2001 From: Stefan Benten Date: Sun, 21 Jan 2024 14:05:47 +0100 Subject: [PATCH 291/311] modules/api/api.py: add api endpoint to refresh embeddings list --- modules/api/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index b3d74e513..b6bb9d06a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -230,6 +230,7 @@ class Api: self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) + self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) @@ -643,6 +644,10 @@ class Api: "skipped": convert_embeddings(db.skipped_embeddings), } + def refresh_embeddings(self): + with self.queue_lock: + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() From d7d3166a2749fe04f0ba1d8d0f88c8e8819a379d Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sun, 21 Jan 2024 11:27:24 -0500 Subject: [PATCH 292/311] Fix broken scrollbars --- html/extra-networks-tree.html | 4 +--- style.css | 20 +++++++------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html index 39649e860..beec888cd 100644 --- a/html/extra-networks-tree.html +++ b/html/extra-networks-tree.html @@ -1,5 +1,3 @@
                    -
                    - {tree} -
                    + {tree}
                    \ No newline at end of file diff --git a/style.css b/style.css index f3fd1571b..ff1d90722 100644 --- a/style.css +++ b/style.css @@ -1179,20 +1179,20 @@ body.resizing .resize-handle { resize: vertical; min-height: 52rem; flex-direction: column; + overflow: hidden; } .extra-network-pane .extra-network-pane-content { display: flex; flex: 1; - flex-direction: row; + overflow: hidden; } .extra-network-pane .extra-network-tree { flex: 1; - flex-direction: column; - display: flex; font-size: 1rem; border: 1px solid var(--block-border-color); + overflow: clip auto !important; } .extra-network-pane .extra-network-cards { @@ -1210,34 +1210,28 @@ body.resizing .resize-handle { overflow: hidden; } -.extra-network-pane .extra-network-tree .tree-list .tree-list-container { - flex: 1; - overflow: clip auto !important; - width: 100%; -} - .extra-network-pane .extra-network-cards::-webkit-scrollbar, -.extra-network-pane .tree-list-container::-webkit-scrollbar { +.extra-network-pane .extra-network-tree::-webkit-scrollbar { background-color: transparent; width: 16px; } .extra-network-pane .extra-network-cards::-webkit-scrollbar-track, -.extra-network-pane .tree-list-container::-webkit-scrollbar-track { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-track { background-color: transparent; background-clip: content-box; } .extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb, -.extra-network-pane .tree-list-container::-webkit-scrollbar-thumb { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-thumb { background-color: var(--border-color-primary); border-radius: 16px; border: 4px solid var(--background-fill-primary); } .extra-network-pane .extra-network-cards::-webkit-scrollbar-button, -.extra-network-pane .tree-list-container::-webkit-scrollbar-button { +.extra-network-pane .extra-network-tree::-webkit-scrollbar-button { display: none; } From 26e1cd7ec47c8d234d2ea3f189b1147329c9059c Mon Sep 17 00:00:00 2001 From: Sj-Si Date: Sun, 21 Jan 2024 11:34:08 -0500 Subject: [PATCH 293/311] Remove unnecessary template and simplify tree list. --- html/extra-networks-tree.html | 3 --- modules/ui_extra_networks.py | 11 +---------- 2 files changed, 1 insertion(+), 13 deletions(-) delete mode 100644 html/extra-networks-tree.html diff --git a/html/extra-networks-tree.html b/html/extra-networks-tree.html deleted file mode 100644 index beec888cd..000000000 --- a/html/extra-networks-tree.html +++ /dev/null @@ -1,3 +0,0 @@ -
                    - {tree} -
                    \ No newline at end of file diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 80160b847..157b3a6d4 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -164,7 +164,6 @@ class ExtraNetworksPage: self.lister = util.MassFileLister() # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") - self.tree_tpl = shared.html("extra-networks-tree.html") self.card_tpl = shared.html("extra-networks-card.html") self.btn_tree_tpl = shared.html("extra-networks-tree-button.html") self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html") @@ -420,8 +419,6 @@ class ExtraNetworksPage: def create_tree_view_html(self, tabname: str) -> str: """Generates HTML for displaying folders in a tree view. - The generated HTML uses `extra-networks-tree.html` as a template. - Args: tabname: The name of the active tab. @@ -473,13 +470,7 @@ class ExtraNetworksPage: if item_html is not None: res += item_html - return self.tree_tpl.format( - **{ - "tabname": tabname, - "extra_networks_tabname": self.extra_networks_tabname, - "tree": f"
                      {res}
                    " - } - ) + return f"
                      {res}
                    " def create_card_view_html(self, tabname: str) -> str: """Generates HTML for the network Card View section for a tab. From fd383140cf405100f3c619f106472273a7545beb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 22 Jan 2024 02:52:34 -0800 Subject: [PATCH 294/311] fix: wrong devices for eye and constraint --- extensions-builtin/Lora/network_oft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 342fcd0dc..d1c46a4b2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,12 +57,12 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): oft_blocks = self.oft_blocks.to(orig_weight.device) - eye = torch.eye(self.block_size, device=self.oft_blocks.device) + eye = torch.eye(self.block_size, device=oft_blocks.device) if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device)) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) From f4e931f18fa4f94aece1f4dabd4dd0d635ecec13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 22 Jan 2024 23:20:30 +0300 Subject: [PATCH 295/311] put extra networks controls row into the tabs UI element for #14588 --- html/extra-networks-pane.html | 2 +- javascript/extraNetworks.js | 28 ++++++++++++++++--- modules/ui.py | 4 +-- modules/ui_extra_networks.py | 2 +- style.css | 51 +++++++++++++++++++---------------- 5 files changed, 56 insertions(+), 31 deletions(-) diff --git a/html/extra-networks-pane.html b/html/extra-networks-pane.html index 9f5b3ecee..0c763f710 100644 --- a/html/extra-networks-pane.html +++ b/html/extra-networks-pane.html @@ -1,5 +1,5 @@
                    -
                    +