From 24f21583cdba2ae6cc51773b956c6ce068d3dfe4 Mon Sep 17 00:00:00 2001 From: AnyISalIn Date: Fri, 4 Aug 2023 11:43:27 +0800 Subject: [PATCH 01/18] fix: prevent cache model.state_dict() after model hijack Signed-off-by: AnyISalIn --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1d93d8935..ba15b4518 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) - del state_dict timer.record("apply weights to model") if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_info] = state_dict + + del state_dict if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) From ac8dfd9386785127c2a71ee2c1ae4f950a46f4fd Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 03:52:22 -0400 Subject: [PATCH 02/18] Toggle extras checkbox for infotext paste --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 822a76602..fde79a8ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -603,6 +603,7 @@ def create_ui(): (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), + (seed_checkbox, lambda d: any(x in ["Variation seed", "Variation seed strength", "Seed resize from-1", "Seed resize from-2"] for x in d)), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) @@ -979,6 +980,7 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), + (seed_checkbox, lambda d: any(x in ["Variation seed", "Variation seed strength", "Seed resize from-1", "Seed resize from-2"] for x in d)), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From d89a915b74cd999e13e559d6a702c7da9404db5e Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 04:03:37 -0400 Subject: [PATCH 03/18] Only enable hr fix if hr scale or upscale in infotext on paste --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 822a76602..4628bc891 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -591,8 +591,8 @@ def create_ui(): (seed_resize_from_h, "Seed resize from-2"), (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (enable_hr, lambda d: "Denoising strength" in d and any(x in ["Hires upscale", "Hires upscaler"] for x in d)), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and any(x in ["Hires upscale", "Hires upscaler"] for x in d))), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), (hr_second_pass_steps, "Hires steps"), From 82b415c9c141d8616e1e9ccb55e47a1884d652ba Mon Sep 17 00:00:00 2001 From: daxijiu <127850313+daxijiu@users.noreply.github.com> Date: Fri, 4 Aug 2023 16:03:49 +0800 Subject: [PATCH 04/18] fix some content are ignore by localization in setting "Face restoration model" and "Select which Real-ESRGAN models" and in extras "upscaler 1 & 2" are ignore by localization --- javascript/localization.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e..0c9032f9b 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,11 +11,11 @@ var ignore_ids_for_localization = { train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', - setting_random_artist_categories: 'SPAN', - setting_face_restoration_model: 'SPAN', - setting_realesrgan_enabled_models: 'SPAN', - extras_upscaler_1: 'SPAN', - extras_upscaler_2: 'SPAN', + setting_random_artist_categories: 'OPTION', + setting_face_restoration_model: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + extras_upscaler_1: 'OPTION', + extras_upscaler_2: 'OPTION', }; var re_num = /^[.\d]+$/; From 67312653d766fbbebaecae87b26d5e49789970d3 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 04:40:56 -0400 Subject: [PATCH 05/18] Cleanup hr infotext paste check --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 4628bc891..cc2980313 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -592,7 +592,7 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d and any(x in ["Hires upscale", "Hires upscaler"] for x in d)), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and any(x in ["Hires upscale", "Hires upscaler"] for x in d))), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), (hr_second_pass_steps, "Hires steps"), From 7c5480eb969a421786089c6f8bc664a7444e75ee Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 04:42:35 -0400 Subject: [PATCH 06/18] Cleanup hr infotext paste check mk2 --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index cc2980313..8b38d2136 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -591,7 +591,7 @@ def create_ui(): (seed_resize_from_h, "Seed resize from-2"), (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d and any(x in ["Hires upscale", "Hires upscaler"] for x in d)), + (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), From f5994e84a213e40f9f6f0eb24df1d4fe38a45a2e Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 04:57:01 -0400 Subject: [PATCH 07/18] Cleanup extras checkbox infotext paste check --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index fde79a8ab..2f3f74b5b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -585,6 +585,7 @@ def create_ui(): (width, "Size-1"), (height, "Size-2"), (batch_size, "Batch size"), + (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), (subseed, "Variation seed"), (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), @@ -603,7 +604,6 @@ def create_ui(): (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), - (seed_checkbox, lambda d: any(x in ["Variation seed", "Variation seed strength", "Seed resize from-1", "Seed resize from-2"] for x in d)), *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) @@ -973,6 +973,7 @@ def create_ui(): (width, "Size-1"), (height, "Size-2"), (batch_size, "Batch size"), + (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), (subseed, "Variation seed"), (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), @@ -980,7 +981,6 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - (seed_checkbox, lambda d: any(x in ["Variation seed", "Variation seed strength", "Seed resize from-1", "Seed resize from-2"] for x in d)), *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) From df9fd1d3ae5e01a9de67ba4aa34e96e9f789704a Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 05:31:38 -0400 Subject: [PATCH 08/18] Fix inpaint mask for Gradio 3.39.0 --- modules/img2img.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 68e415ef5..4ae2ba722 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask).convert('L') + mask = mask.convert('RGBA').split()[3].convert('L').point(lambda x: 255 if x > 0 else 0) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch From e219211ff6ffcdb4094334dbc4bff9a2d33af55c Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 05:35:47 -0400 Subject: [PATCH 09/18] Remove unused import in img2img --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/img2img.py b/modules/img2img.py index 4ae2ba722..ed21e82c7 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,7 +3,7 @@ from contextlib import closing from pathlib import Path import numpy as np -from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr from modules import sd_samplers, images as imgutil From 2dc2bc4ab53837e326a9c70ae250031ff6e8c929 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 05:40:13 -0400 Subject: [PATCH 10/18] Fix string quotes --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/img2img.py b/modules/img2img.py index ed21e82c7..b50678a6c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,7 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - mask = mask.convert('RGBA').split()[3].convert('L').point(lambda x: 255 if x > 0 else 0) + mask = mask.convert("RGBA").split()[3].convert("L").point(lambda x: 255 if x > 0 else 0) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch From cd4e053e5e73b90a129d5ebe5a0334a07765598f Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 05:43:53 -0400 Subject: [PATCH 11/18] Simply img2img mask conversion, fix threshold --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/img2img.py b/modules/img2img.py index b50678a6c..85d920648 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,7 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - mask = mask.convert("RGBA").split()[3].convert("L").point(lambda x: 255 if x > 0 else 0) + mask = mask.split()[-1].convert('L').point(lambda x: 255 if x > 128 else 0) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch From 99f5f8e76b31c86d3091b92414a1586c29508086 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 05:47:25 -0400 Subject: [PATCH 12/18] Fix string quotes --- modules/img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/img2img.py b/modules/img2img.py index 85d920648..d8e1c534c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,7 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - mask = mask.split()[-1].convert('L').point(lambda x: 255 if x > 128 else 0) + mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch From daee41e0d64e51adaebbd0d6ba4ba85e0b59d0ae Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 06:45:12 -0400 Subject: [PATCH 13/18] Fix Gradio 3.39.0 textbox overflow --- style.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/style.css b/style.css index 86b4f61ef..14e6c0114 100644 --- a/style.css +++ b/style.css @@ -140,6 +140,10 @@ div.styler{ background: var(--background-fill-primary); } +.block.gradio-textbox{ + overflow: visible !important; +} + /* general styled components */ From fadbab378183c654f3af35865022acbac877de24 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 14:56:27 +0300 Subject: [PATCH 14/18] Curse you, gradio!!! fixes broken refresh button #12309 --- modules/ui_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_common.py b/modules/ui_common.py index eefe0c0ec..1dda16272 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -239,13 +239,13 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele for comp in refresh_components: setattr(comp, k, v) - return [gr.update(**(args or {})) for _ in refresh_components] + return (gr.update(**(args or {})) for _ in refresh_components) if len(refresh_components) > 1 else gr.update(**(args or {})) refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh") refresh_button.click( fn=refresh, inputs=[], - outputs=[*refresh_components] + outputs=refresh_components ) return refresh_button From 682ff8936df018330e0d2a259794a262dc3251b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 18:51:15 +0300 Subject: [PATCH 15/18] glorious, glorious wonderful clear milky white butter smooth color for inpainting you are the best, gradio how I yearned for this day i always believed in you i knew you had it in you this day marks a new beginning thank you, everyone thank you --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 586174b86..6cf3dff88 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -667,11 +667,11 @@ def create_ui(): add_copy_image_controls('sketch', sketch) with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height) + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') add_copy_image_controls('inpaint', init_img_with_mask) with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height) + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color='#ffffff') inpaint_color_sketch_orig = gr.State(None) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) From 9213d5cb3b5614b85f4752a0ba54f0cdf1282857 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 12:26:37 -0400 Subject: [PATCH 16/18] Open raw sysinfo link in new page --- modules/ui_settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui_settings.py b/modules/ui_settings.py index a6076bf30..6dde4b6aa 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -158,7 +158,7 @@ class UiSettings: loadsave.create_ui() with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"): - gr.HTML('Download system info
(or open as text in a new page)', elem_id="sysinfo_download") + gr.HTML('Download system info
(or open as text in a new page)', elem_id="sysinfo_download") with gr.Row(): with gr.Column(scale=1): From b6596cdb19414cdb31a762e4c4ffbdce17d2d6e9 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:26:37 -0400 Subject: [PATCH 17/18] Prompt parser: account for empty field in alternating words syntax --- modules/prompt_parser.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 8169a4596..32d214e3a 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* | "(" prompt ":" prompt ")" | "[" prompt "]" scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]" -alternate: "[" prompt ("|" prompt)+ "]" +alternate: "[" prompt ("|" [prompt])+ "]" WHITESPACE: /\s+/ plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER @@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[3, '((a][:b:c '], [10, '((a][:b:c d']] >>> g("[a|(b:1.1)]") [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] + >>> g("[fe|]male") + [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] + >>> g("[fe|||]male") + [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] """ def collect_steps(steps, tree): @@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): before, after, _, when, _ = args yield before or () if step <= when else after def alternate(self, args): - yield next(args[(step - 1)%len(args)]) + args = ["" if not arg else arg for arg in args] + yield args[(step - 1) % len(args)] def start(self, args): def flatten(x): if type(x) == str: From 45601766409e531d2b4ee512bf1433600f140183 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 22:05:40 +0300 Subject: [PATCH 18/18] added VAE selection to checkpoint user metadata --- modules/extra_networks.py | 19 ++++++ modules/sd_vae.py | 13 +++- modules/ui_extra_networks.py | 13 +--- modules/ui_extra_networks_checkpoints.py | 3 + ...xtra_networks_checkpoints_user_metadata.py | 60 +++++++++++++++++++ 5 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 modules/ui_extra_networks_checkpoints_user_metadata.py diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91b..fa28ac752 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,3 +1,5 @@ +import json +import os import re from collections import defaultdict @@ -177,3 +179,20 @@ def parse_prompts(prompts): return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 84271db0b..0bd5e19bb 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +16,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f2752f107..c6390db79 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -101,16 +101,7 @@ class ExtraNetworksPage: def read_user_metadata(self, item): filename = item.get("filename", None) - basename, ext = os.path.splitext(filename) - metadata_filename = basename + '.json' - - metadata = {} - try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) - except Exception as e: - errors.display(e, f"reading extra network user metadata from {metadata_filename}") + metadata = extra_networks.get_user_metadata(filename) desc = metadata.get("description", None) if desc is not None: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 891d8f2cf..778850222 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models from modules.ui_extra_networks import quote_js +from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + def create_user_metadata_editor(self, ui, tabname): + return CheckpointUserMetadataEditor(ui, tabname, self) diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py new file mode 100644 index 000000000..2c69aab86 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -0,0 +1,60 @@ +import gradio as gr + +from modules import ui_extra_networks_user_metadata, sd_vae +from modules.ui_common import create_refresh_button + + +class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.select_vae = None + + def save_user_metadata(self, name, desc, notes, vae): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["notes"] = notes + user_metadata["vae"] = vae + + self.write_user_metadata(name, user_metadata) + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + return [ + *values[0:5], + user_metadata.get('vae', ''), + ] + + def create_editor(self): + self.create_default_editor_elems() + + with gr.Row(): + self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae") + create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae") + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + self.html_preview, + self.edit_notes, + self.select_vae, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.edit_notes, + self.select_vae, + ] + + self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)