diff --git a/.eslintrc.js b/.eslintrc.js index f33aca09f..e3b4fb769 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -87,5 +87,9 @@ module.exports = { modalNextImage: "readonly", // token-counters.js setupTokenCounters: "readonly", + // localStorage.js + localSet: "readonly", + localGet: "readonly", + localRemove: "readonly" } }; diff --git a/README.md b/README.md index 2fd6e425d..940176d04 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ Alternatively, use online services (like Google Colab): 1. Install the dependencies: ```bash # Debian-based: -sudo apt install wget git python3 python3-venv +sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: sudo dnf install wget git python3 # Arch-based: @@ -123,7 +123,7 @@ sudo pacman -S wget git python3 ``` 2. Navigate to the directory you would like the webui to be installed and execute the following command: ```bash -bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh) +wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh ``` 3. Run `webui.sh`. 4. Check `webui-user.sh` for options. diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2ca997f7c..390d9dde3 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) with gr.Column(scale=1, min_width=120): - generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg") + generate_random_prompt = gr.Button('Generate', size="lg", scale=1) self.edit_notes = gr.TextArea(label='Notes', lines=4) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a05e10d86..588b64d23 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,5 +1,7 @@ +import math + import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste from modules.ui_components import FormColumn @@ -19,18 +21,37 @@ class ExtraOptionsSection(scripts.Script): def ui(self, is_img2img): self.comps = [] self.setting_names = [] + self.infotext_fields = [] + + mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): - for setting_name in shared.opts.extra_options: - with FormColumn(): - comp = ui_settings.create_setting_component(setting_name) + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(): - self.comps.append(comp) - self.setting_names.append(setting_name) + row_count = math.ceil(len(shared.opts.extra_options) / shared.opts.extra_options_cols) + + for row in range(row_count): + with gr.Row(): + for col in range(shared.opts.extra_options_cols): + index = row * shared.opts.extra_options_cols + col + if index >= len(shared.opts.extra_options): + break + + setting_name = 
shared.opts.extra_options[index] + + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + setting_infotext_name = mapping.get(setting_name) + if setting_infotext_name is not None: + self.infotext_fields.append((comp, setting_infotext_name)) def get_settings_values(): - return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + res = [ui_settings.get_value_for_setting(key) for key in self.setting_names] + return res[0] if len(res) == 1 else res interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) @@ -43,6 +64,9 @@ class ExtraOptionsSection(scripts.Script): shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), - "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") + "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() })) + + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5582a6e5d..897ebeba7 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,20 +1,38 @@ +function toggleCss(key, css, enable) { + var style = document.getElementById(key); + if (enable && !style) { + style = document.createElement('style'); + style.id = key; + style.type = 'text/css'; + document.head.appendChild(style); + } + if (style && !enable) { + document.head.removeChild(style); + } + if (style) { + style.innerHTML = ''; + style.appendChild(document.createTextNode(css)); + } +} + function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); + var search = searchDiv.querySelector('textarea'); var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); + var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); - search.classList.add('search'); - sort.classList.add('sort'); - sortOrder.classList.add('sortorder'); sort.dataset.sortkey = 'sortDefault'; - tabs.appendChild(search); + tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); + tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -80,6 +98,15 @@ function
setupExtraNetworksForTab(tabname) { }); extraNetworksApplyFilter[tabname] = applyFilter; + + var showDirsUpdate = function() { + var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; + toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); + localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); + }; + showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; + showDirs.addEventListener("change", showDirsUpdate); + showDirsUpdate(); } function applyExtraNetworkFilter(tabname) { @@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksSearchButton(tabs_id, event) { - var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); + var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 677e95c1b..c21d396ee 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -136,6 +136,11 @@ function setupImageForLightbox(e) { var event = isFirefox ? 'mousedown' : 'click'; e.addEventListener(event, function(evt) { + if (evt.button == 1) { + open(evt.target.src); + evt.preventDefault(); + return; + } if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/javascript/localStorage.js b/javascript/localStorage.js new file mode 100644 index 000000000..dc1a36c32 --- /dev/null +++ b/javascript/localStorage.js @@ -0,0 +1,26 @@ + +function localSet(k, v) { + try { + localStorage.setItem(k, v); + } catch (e) { + console.warn(`Failed to save ${k} to localStorage: ${e}`); + } +} + +function localGet(k, def) { + try { + return localStorage.getItem(k) ?? def; + } catch (e) { + console.warn(`Failed to load ${k} from localStorage: ${e}`); + } + + return def; +} + +function localRemove(k) { + try { + return localStorage.removeItem(k); + } catch (e) { + console.warn(`Failed to remove ${k} from localStorage: ${e}`); + } +} diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e..0c9032f9b 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,11 +11,11 @@ var ignore_ids_for_localization = { train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', - setting_random_artist_categories: 'SPAN', - setting_face_restoration_model: 'SPAN', - setting_realesrgan_enabled_models: 'SPAN', - extras_upscaler_1: 'SPAN', - extras_upscaler_2: 'SPAN', + setting_random_artist_categories: 'OPTION', + setting_face_restoration_model: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + extras_upscaler_1: 'OPTION', + extras_upscaler_2: 'OPTION', }; var re_num = /^[.\d]+$/; diff --git a/javascript/ui.js b/javascript/ui.js index abf23a78c..bade30893 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -152,15 +152,11 @@ function submit() { showSubmitButtons('txt2img', false); var id = randomId(); - try { - localStorage.setItem("txt2img_task_id", id); - } catch (e) { - console.warn(`Failed to save txt2img task id to localStorage: ${e}`); - } + localSet("txt2img_task_id", id); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); - localStorage.removeItem("txt2img_task_id"); + 
localRemove("txt2img_task_id"); showRestoreProgressButton('txt2img', false); }); @@ -175,15 +171,11 @@ function submit_img2img() { showSubmitButtons('img2img', false); var id = randomId(); - try { - localStorage.setItem("img2img_task_id", id); - } catch (e) { - console.warn(`Failed to save img2img task id to localStorage: ${e}`); - } + localSet("img2img_task_id", id); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); - localStorage.removeItem("img2img_task_id"); + localRemove("img2img_task_id"); showRestoreProgressButton('img2img', false); }); @@ -197,7 +189,7 @@ function submit_img2img() { function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); - var id = localStorage.getItem("txt2img_task_id"); + var id = localGet("txt2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { @@ -211,7 +203,7 @@ function restoreProgressTxt2img() { function restoreProgressImg2img() { showRestoreProgressButton("img2img", false); - var id = localStorage.getItem("img2img_task_id"); + var id = localGet("img2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { @@ -224,8 +216,8 @@ function restoreProgressImg2img() { onUiLoaded(function() { - showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id")); - showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id")); + showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); + showRestoreProgressButton('img2img', localGet("img2img_task_id")); }); diff --git a/modules/cmd_args.py b/modules/cmd_args.py index cb4ec5f7d..64f21e011 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/devices.py b/modules/devices.py index 57e51da30..00a00b18a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors +from modules import errors, rng_philox if sys.platform == "darwin": from modules import mac_specific @@ -71,14 +71,17 @@ def enable_tf32(): torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None 
+device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -90,23 +93,87 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input +nv_rng = None + + def randn(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + from modules.shared import opts - torch.manual_seed(seed) + manual_seed(seed) + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + from modules.shared import opts + + if opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=device) + + local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized generator. + + Use either randn() or manual_seed() to initialize the generator.""" + + from modules.shared import opts + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=cpu).to(x.device) + + return torch.randn_like(x) + + def randn_without_seed(shape): + """Generate a tensor with random numbers from a normal distribution using the previously initialized generator. + + Use either randn() or manual_seed() to initialize the generator.""" + from modules.shared import opts + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + def autocast(disable=False): from modules import shared diff --git a/modules/errors.py b/modules/errors.py index dffabe45c..192cd8ffd 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -84,3 +84,53 @@ def run(code, task): code() except Exception as e: display(task, e) + + +def check_versions(): + from packaging import version + from modules import shared + + import torch + import gradio + + expected_torch_version = "2.0.0" + expected_xformers_version = "0.0.20" + expected_gradio_version = "3.39.0" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. 
+To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, and that +there are reports of issues with the training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you might have a mismatched gradio version include: + - you use the --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs an incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/modules/extensions.py b/modules/extensions.py index 3ad5ed531..e4633af40 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True) def active(): - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": return [] - elif shared.opts.disable_all_extensions == "extra": + elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -141,8 +141,12 @@ def list_extensions(): if not os.path.isdir(extensions_dir): return - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions: + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": print("*** \"Disable all extensions\" option was set, will not load any extensions ***") + elif shared.cmd_opts.disable_extra_extensions: + print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***") elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91b..fa28ac752 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,3 +1,6 @@ +import json +import os import re from collections import defaultdict +from modules import errors @@ -177,3 +180,20 @@ def parse_prompts(prompts): return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/generation_parameters_copypaste.py 
b/modules/generation_parameters_copypaste.py index a3448be9d..5758e6f30 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" @@ -304,6 +307,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Schedule rho" not in res: res["Schedule rho"] = 0 + if "VAE Encoder" not in res: + res["VAE Encoder"] = "Full" + + if "VAE Decoder" not in res: + res["VAE Decoder"] = "Full" + return res @@ -319,6 +328,10 @@ infotext_to_setting_name_mapping = [ ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), + ('Sigma churn', 's_churn'), + ('Sigma tmin', 's_tmin'), + ('Sigma tmax', 's_tmax'), + ('Sigma noise', 's_noise'), ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), ('UniPC variant', 'uni_pc_variant'), ('UniPC skip type', 'uni_pc_skip_type'), @@ -329,6 +342,8 @@ infotext_to_setting_name_mapping = [ ('RNG', 'randn_source'), ('NGMS', 's_min_uncond'), ('Pad conds', 'pad_cond_uncond'), + ('VAE Encoder', 'sd_vae_encode_method'), + ('VAE Decoder', 'sd_vae_decode_method'), ] @@ -399,10 +414,15 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, return res if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + def paste_settings(params): vals = {} for param_name, setting_name in infotext_to_setting_name_mapping: + if param_name in already_handled_fields: + continue + v = params.get(param_name, None) if v is None: continue diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py new file mode 100644 index 000000000..5af7fd8ec --- /dev/null +++ b/modules/gradio_extensons.py @@ -0,0 +1,60 @@ +import gradio as gr + +from modules import scripts + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + return config + + +def BlockContext_init(self, *args, **kwargs): + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + return res + + +original_IOComponent_init = gr.components.IOComponent.__init__ +original_Block_get_config = gr.blocks.Block.get_config +original_BlockContext_init = 
gr.blocks.BlockContext.__init__ + +gr.components.IOComponent.__init__ = IOComponent_init +gr.blocks.Block.get_config = Block_get_config +gr.blocks.BlockContext.__init__ = BlockContext_init diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21a..70f1cbd26 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images allows training previews to have infotext. Importing it at the top causes a circular import problem. - from modules import images + from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/images.py b/modules/images.py index 38aa933d6..ba3c43a45 100644 --- a/modules/images.py +++ b/modules/images.py @@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' 
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') diff --git a/modules/img2img.py b/modules/img2img.py index 68e415ef5..d8e1c534c 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,7 +3,7 @@ from contextlib import closing from pathlib import Path import numpy as np -from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr from modules import sd_samplers, images as imgutil @@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask).convert('L') + mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch diff --git a/modules/launch_utils.py b/modules/launch_utils.py index f77b577a5..5be30a183 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -139,6 +139,27 @@ def check_run_python(code: str) -> bool: return result.returncode == 0 +def git_fix_workspace(dir, name): + run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True) + run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True) + return + + +def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True): + try: + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + except RuntimeError: + pass + + if not autofix: + return None + + print(f"{errdesc}, attempting autofix...") + git_fix_workspace(dir, name) + + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + + def git_clone(url, dir, name, commithash=None): # TODO clone into temporary dir and move if successful @@ -146,12 +167,14 @@ if commithash is None: return - current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() + current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() if current_hash == commithash: return - run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) + run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + + run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) + return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) diff --git a/modules/lowvram.py b/modules/lowvram.py index 3f8306643..96f52b7b4 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,9 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, 
use_medvram): + if getattr(sd_model, 'lowvram', False): + return + sd_model.lowvram = True parents = {} diff --git a/modules/processing.py b/modules/processing.py index b0992ee15..317450065 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack +from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -83,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -109,7 +110,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): 
if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -147,8 +148,8 @@ class StableDiffusionProcessing: self.s_min_uncond = s_min_uncond or opts.s_min_uncond self.s_churn = s_churn or opts.s_churn self.s_tmin = s_tmin or opts.s_tmin - self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option - self.s_noise = s_noise or opts.s_noise + self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf') + self.s_noise = s_noise if s_noise is not None else opts.s_noise self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False @@ -202,7 +203,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -215,7 +216,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) return conditioning_image @@ -294,7 +295,7 @@ class StableDiffusionProcessing: self.sampler = None self.c = None self.uc = None - if not opts.experimental_persistent_cond_cache: + if not opts.persistent_cond_cache: StableDiffusionProcessing.cached_c = [None, None] StableDiffusionProcessing.cached_uc = [None, None] @@ -318,6 +319,21 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] + def cached_params(self, required_prompts, steps, extra_network_data): + """Returns parameters that invalidate the cond cache if changed""" + + return ( + required_prompts, + steps, + opts.CLIP_stop_at_last_layers, + shared.sd_model.sd_checkpoint_info, + extra_network_data, + opts.sdxl_crop_left, + opts.sdxl_crop_top, + self.width, + self.height, + ) + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) @@ -331,17 +347,7 @@ class StableDiffusionProcessing: caches is a list with items described above. 
""" - cached_params = ( - required_prompts, - steps, - opts.CLIP_stop_at_last_layers, - shared.sd_model.sd_checkpoint_info, - extra_network_data, - opts.sdxl_crop_left, - opts.sdxl_crop_top, - self.width, - self.height, - ) + cached_params = self.cached_params(required_prompts, steps, extra_network_data) for cache in caches: if cache[0] is not None and cached_params == cache[0]: @@ -367,6 +373,10 @@ class StableDiffusionProcessing: def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) + def save_samples(self) -> bool: + """Returns whether generated images need to be written to disk""" + return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped) + class Processed: def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""): @@ -492,7 +502,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) subnoise = None - if subseeds is not None: + if subseeds is not None and subseed_strength != 0: subseed = 0 if i >= len(subseeds) else subseeds[i] subnoise = devices.randn(subseed, noise_shape) @@ -524,7 +534,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see cnt = p.sampler.number_of_needed_noises(p) if eta_noise_seed_delta > 0: - torch.manual_seed(seed + eta_noise_seed_delta) + devices.manual_seed(seed + eta_noise_seed_delta) for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -538,8 +548,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +class DecodedSamples(list): + already_decoded = True + + def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): - samples = [] + samples = DecodedSamples() for i in range(batch.shape[0]): sample = decode_first_stage(model, batch[i:i + 1])[0] @@ -572,12 +586,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): return samples -def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x - - def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) @@ -636,7 +644,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio, "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), - "RNG": opts.randn_source if opts.randn_source != "GPU" else None, + "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, @@ -793,7 +801,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, 
subseed_strength=p.subseed_strength, prompts=p.prompts) - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + if getattr(samples_ddim, 'already_decoded', False): + x_samples_ddim = samples_ddim + else: + if opts.sd_vae_decode_method != 'Full': + p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -817,6 +832,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(index=0, use_main_prompt=False): return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + save_samples = p.save_samples() + for i, x_sample in enumerate(x_samples_ddim): p.batch_index = i @@ -824,7 +841,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_sample = x_sample.astype(np.uint8) if p.restore_faces: - if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: + if save_samples and opts.save_images_before_face_restoration: images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration") devices.torch_gc() @@ -838,16 +855,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image - if p.color_corrections is not None and i < len(p.color_corrections): - if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: + if save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) image = apply_overlay(image, p.paste_to, i, p.overlay_images) - if opts.samples_save and not p.do_not_save_samples: + if save_samples: images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) @@ -855,8 +871,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): image_mask = p.mask_for_overlay.convert('RGB') image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') @@ -892,7 +907,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 - if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, 
info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) @@ -935,7 +949,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): cached_hr_uc = [None, None] cached_hr_c = [None, None] - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength @@ -946,11 +960,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y self.hr_upscale_to_x = hr_resize_x self.hr_upscale_to_y = hr_resize_y + self.hr_checkpoint_name = hr_checkpoint_name + self.hr_checkpoint_info = None self.hr_sampler_name = hr_sampler_name self.hr_prompt = hr_prompt self.hr_negative_prompt = hr_negative_prompt self.all_hr_prompts = None self.all_hr_negative_prompts = None + self.latent_scale_mode = None if firstphase_width != 0 or firstphase_height != 0: self.hr_upscale_to_x = self.width @@ -973,6 +990,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if self.hr_checkpoint_name: + self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) + + if self.hr_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}') + + self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title + if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name @@ -982,6 +1007,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") + if self.enable_hr and self.latent_scale_mode is None: + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): self.hr_resize_x = self.width self.hr_resize_y = self.height @@ -1020,14 +1050,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f - # special case: the user has chosen to do nothing - if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height: - self.enable_hr = False - self.denoising_strength = None - self.extra_generation_params.pop("Hires upscale", None) - 
self.extra_generation_params.pop("Hires resize", None) - return - if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter @@ -1045,17 +1067,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") - if self.enable_hr and latent_scale_mode is None: - if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): - raise Exception(f"could not find upscaler named {self.hr_upscaler}") - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + del x if not self.enable_hr: return samples + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None + + current = shared.sd_model.sd_checkpoint_info + try: + if self.hr_checkpoint_info is not None: + self.sampler = None + sd_models.reload_model_weights(info=self.hr_checkpoint_info) + devices.torch_gc() + + return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) + finally: + self.sampler = None + sd_models.reload_model_weights(info=current) + devices.torch_gc() + + def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): self.is_hr_pass = True target_width = self.hr_upscale_to_x @@ -1064,7 +1101,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def save_intermediate(image, index): """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" - if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: + if not self.save_samples() or not opts.save_images_before_highres_fix: return if not isinstance(image, Image.Image): @@ -1073,11 +1110,18 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix") - if latent_scale_mode is not None: + img2img_sampler_name = self.hr_sampler_name or self.sampler_name + + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' + + self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) + + if self.latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) + samples = 
torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -1086,7 +1130,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: image_conditioning = self.txt2img_image_conditioning(samples) else: - decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) batch_images = [] @@ -1103,28 +1146,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): batch_images.append(image) decoded_samples = torch.from_numpy(np.array(batch_images)) - decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. + decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae) - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method)) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) shared.state.nextjob() - img2img_sampler_name = self.hr_sampler_name or self.sampler_name - - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory - x = None devices.torch_gc() if not self.disable_extra_networks: @@ -1143,15 +1179,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + self.is_hr_pass = False - return samples + return decoded_samples def close(self): super().close() self.hr_c = None self.hr_uc = None - if not opts.experimental_persistent_cond_cache: + if not opts.persistent_cond_cache: StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None] StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None] @@ -1184,8 +1222,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_c is not None: return - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) + hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, 
is_negative_prompt=True) + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() @@ -1193,7 +1234,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = None self.hr_c = None - if self.enable_hr: + if self.enable_hr and self.hr_checkpoint_info is None: if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() @@ -1344,10 +1385,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) - image = 2. * image - 1. image = image.to(shared.device, dtype=devices.dtype_vae) - self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) + if opts.sd_vae_encode_method != 'Full': + self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method + + self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) + devices.torch_gc() if self.resize_mode == 3: self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 8169a4596..32d214e3a 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -20,7 +20,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* | "(" prompt ":" prompt ")" | "[" prompt "]" scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]" -alternate: "[" prompt ("|" prompt)+ "]" +alternate: "[" prompt ("|" [prompt])+ "]" WHITESPACE: /\s+/ plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER @@ -53,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[3, '((a][:b:c '], [10, '((a][:b:c d']] >>> g("[a|(b:1.1)]") [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] + >>> g("[fe|]male") + [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] + >>> g("[fe|||]male") + [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] """ def collect_steps(steps, tree): @@ -78,7 +82,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): before, after, _, when, _ = args yield before or () if step <= when else after def alternate(self, args): - yield next(args[(step - 1)%len(args)]) + args = ["" if not arg else arg for arg in args] + yield args[(step - 1) % len(args)] def start(self, args): def flatten(x): if type(x) == str: diff --git a/modules/rng_philox.py b/modules/rng_philox.py new file mode 100644 index 000000000..5532cf9dd --- /dev/null +++ b/modules/rng_philox.py @@ -0,0 +1,102 @@ +"""RNG imitiating torch cuda randn on CPU. You are welcome. 
+ +Usage: + +``` +g = Generator(seed=0) +print(g.randn(shape=(3, 4))) +``` + +Expected output: +``` +[[-0.92466259 -0.42534415 -2.6438457 0.14518388] + [-0.12086647 -0.57972564 -0.62285122 -0.32838709] + [-1.07454231 -0.36314407 -1.67105067 2.26550497]] +``` +""" + +import numpy as np + +philox_m = [0xD2511F53, 0xCD9E8D57] +philox_w = [0x9E3779B9, 0xBB67AE85] + +two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32) +two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32) + + +def uint32(x): + """Converts (N,) np.uint64 array into (2, N) np.uint32 array.""" + return x.view(np.uint32).reshape(-1, 2).transpose(1, 0) + + +def philox4_round(counter, key): + """A single round of the Philox 4x32 random number generator.""" + + v1 = uint32(counter[0].astype(np.uint64) * philox_m[0]) + v2 = uint32(counter[2].astype(np.uint64) * philox_m[1]) + + counter[0] = v2[1] ^ counter[1] ^ key[0] + counter[1] = v2[0] + counter[2] = v1[1] ^ counter[3] ^ key[1] + counter[3] = v1[0] + + +def philox4_32(counter, key, rounds=10): + """Generates 32-bit random numbers using the Philox 4x32 random number generator. + + Parameters: + counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation). + key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed). + rounds (int): The number of rounds to perform. + + Returns: + numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers. + """ + + for _ in range(rounds - 1): + philox4_round(counter, key) + + key[0] = key[0] + philox_w[0] + key[1] = key[1] + philox_w[1] + + philox4_round(counter, key) + return counter + + +def box_muller(x, y): + """Returns just the first out of two numbers generated by Box–Muller transform algorithm.""" + u = x * two_pow32_inv + two_pow32_inv / 2 + v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2 + + s = np.sqrt(-2.0 * np.log(u)) + + r1 = s * np.sin(v) + return r1.astype(np.float32) + + +class Generator: + """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU""" + + def __init__(self, seed): + self.seed = seed + self.offset = 0 + + def randn(self, shape): + """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform.""" + + n = 1 + for x in shape: + n *= x + + counter = np.zeros((4, n), dtype=np.uint32) + counter[0] = self.offset + counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3] + self.offset += 1 + + key = np.empty(n, dtype=np.uint64) + key.fill(self.seed) + key = uint32(key) + + g = philox4_32(counter, key) + + return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3] diff --git a/modules/scripts.py b/modules/scripts.py index edf7347ea..f7d060aa5 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -631,63 +631,3 @@ def reload_script_body_only(): reload_scripts = load_scripts # compatibility alias - - -def add_classes_to_gradio_component(comp): - """ - this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others - """ - - comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] - - if getattr(comp, 'multiselect', False): - comp.elem_classes.append('multiselect') - - - -def IOComponent_init(self, *args, **kwargs): - self.webui_tooltip = kwargs.pop('tooltip', None) - - if scripts_current is not None: - scripts_current.before_component(self,
**kwargs) - - script_callbacks.before_component_callback(self, **kwargs) - - res = original_IOComponent_init(self, *args, **kwargs) - - add_classes_to_gradio_component(self) - - script_callbacks.after_component_callback(self, **kwargs) - - if scripts_current is not None: - scripts_current.after_component(self, **kwargs) - - return res - - -def Block_get_config(self): - config = original_Block_get_config(self) - - webui_tooltip = getattr(self, 'webui_tooltip', None) - if webui_tooltip: - config["webui_tooltip"] = webui_tooltip - - return config - - -original_IOComponent_init = gr.components.IOComponent.__init__ -original_Block_get_config = gr.components.Block.get_config -gr.components.IOComponent.__init__ = IOComponent_init -gr.components.Block.get_config = Block_get_config - - -def BlockContext_init(self, *args, **kwargs): - res = original_BlockContext_init(self, *args, **kwargs) - - add_classes_to_gradio_component(self) - - return res - - -original_BlockContext_init = gr.blocks.BlockContext.__init__ -gr.blocks.BlockContext.__init__ = BlockContext_init diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfa5f0ebb..9ad981998 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,11 +2,10 @@ import torch from torch.nn.functional import silu from types import MethodType -import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -30,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 -ldm.modules.attention.print = lambda *args: None -ldm.modules.diffusionmodules.model.print = lambda *args: None +ldm.modules.attention.print = shared.ldm_print +ldm.modules.diffusionmodules.model.print = shared.ldm_print +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print + +sd_hijack_inpainting.do_inpainting_hijack() optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None @@ -164,12 +167,13 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def __init__(self): + import modules.textual_inversion.textual_inversion + self.extra_generation_params = {} self.comments = [] + self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 2f9d569b1..8f29057a9 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): hashes.append(f"{name}: {shorthash}") if hashes: + if self.hijack.extra_generation_params.get("TI hashes"): + hashes.append(self.hijack.extra_generation_params.get("TI hashes")) self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) if getattr(self.wrapped, 'return_pooled', 
False): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index c1977b194..2d44b8566 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F def do_inpainting_hijack(): - # p_sample_plms is needed because PLMS can't work with dicts as conditionings - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b5f85ba51..0e810eec8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + slice_size = q.shape[1] // steps for i in range(0, q.shape[1], slice_size): - end = i + slice_size + end = min(i + slice_size, q.shape[1]) s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) s2 = s1.softmax(dim=-1, dtype=q.dtype) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f72f21d1..53c1df54c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -67,8 +66,9 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' - self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -87,6 +87,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -107,14 +108,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -137,11 +132,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in 
model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -151,6 +149,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None @@ -303,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) - del state_dict timer.record("apply weights to model") if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_info] = state_dict + + del state_dict if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) @@ -352,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() sd_vae.load_vae(model, vae_file, vae_source) timer.record("load VAE") @@ -429,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -443,6 +448,7 @@ class SdModelData: try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -454,11 +460,24 @@ class SdModelData: def set_sd_model(self, v): self.sd_model = v + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + self.loaded_sd_models.insert(0, v) + model_data = SdModelData() def get_empty_cond(sd_model): + from modules import extra_networks, processing + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -466,19 +485,43 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) +def send_model_to_cpu(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def send_model_to_device(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) + else: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") + devices.torch_gc() + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import lowvram, sd_hijack + from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + timer = Timer() + if model_data.sd_model: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + send_model_to_trash(model_data.sd_model) model_data.sd_model = None - gc.collect() devices.torch_gc() - do_inpainting_hijack() - - timer = Timer() + timer.record("unload existing model") if 
already_loaded_state_dict is not None: state_dict = already_loaded_state_dict @@ -518,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + timer.record("load weights from state dict") - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - sd_model.to(shared.device) - + send_model_to_device(sd_model) timer.record("move model to device") sd_hijack.model_hijack.hijack(sd_model) @@ -531,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("hijack") sd_model.eval() - model_data.sd_model = sd_model + model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -552,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit). + """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename: + already_loaded = loaded_model + continue + + if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: + print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") + model_data.loaded_sd_models.pop() + send_model_to_trash(loaded_model) + timer.record("send model to trash") + + if shared.opts.sd_checkpoints_keep_in_cpu: + send_model_to_cpu(sd_model) + timer.record("send model to cpu") + + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + + model_data.set_sd_model(already_loaded) + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + return model_data.sd_model + elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: + print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") + + model_data.sd_model = None + load_model(checkpoint_info) + return model_data.sd_model + elif len(model_data.loaded_sd_models) > 0: + sd_model = model_data.loaded_sd_models.pop() + model_data.sd_model = sd_model + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model + else: + return None + + def reload_model_weights(sd_model=None, info=None): - from modules import lowvram, devices, sd_hijack + from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() + timer = Timer() + if not sd_model: sd_model =
model_data.sd_model @@ -564,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None): else: current_checkpoint_info = sd_model.sd_checkpoint_info if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return + return sd_model + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + + if sd_model is not None: sd_unet.apply_unet("None") - - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) - + send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) - timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) @@ -584,7 +673,9 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + send_model_to_trash(sd_model) + load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -607,6 +698,9 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + sd_unet.apply_unet() + return sd_model diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index bc2195087..011233216 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -98,10 +98,10 @@ def extend_sdxl(model): model.conditioner.wrapped = torch.nn.Module() -sgm.modules.attention.print = lambda *args: None -sgm.modules.diffusionmodules.model.print = lambda *args: None -sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None -sgm.modules.encoders.modules.print = lambda *args: None +sgm.modules.attention.print = shared.ldm_print +sgm.modules.diffusionmodules.model.print = shared.ldm_print +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 763829f1c..39586b40e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,10 +2,8 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd - +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state -import modules.shared as shared SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -25,19 +23,29 @@ def setup_img2img_steps(p, steps=None): approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} -def single_sample_to_image(sample, approximation=None): +def samples_to_images_tensor(sample, approximation=None, model=None): + '''latents -> images [-1, 1]''' if approximation is None: approximation = approximation_indexes.get(opts.show_progress_type, 0) if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 + x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 + x_sample = 
sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: x_sample = sample * 1.5 - x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach() + x_sample = x_sample * 2 - 1 else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 + if model is None: + model = shared.sd_model + x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) + + return x_sample + + +def single_sample_to_image(sample, approximation=None): + x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -46,6 +54,12 @@ def single_sample_to_image(sample, approximation=None): return Image.fromarray(x_sample) +def decode_first_stage(model, x): + x = x.to(devices.dtype_vae) + approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0) + return samples_to_images_tensor(x, approx_index, model) + + def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) @@ -54,6 +68,24 @@ def samples_to_image_grid(samples, approximation=None): return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) +def images_tensor_to_samples(image, approximation=None, model=None): + '''image[0, 1] -> latent''' + if approximation is None: + approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0) + + if approximation == 3: + image = image.to(devices.device, devices.dtype) + x_latent = sd_vae_taesd.encoder_model()(image) + else: + if model is None: + model = shared.sd_model + image = image.to(shared.device, dtype=devices.dtype_vae) + image = image * 2 - 1 + x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) + + return x_latent + + def store_latent(decoded): state.current_latent = decoded @@ -85,11 +117,13 @@ class InterruptedException(BaseException): pass -if opts.randn_source == "CPU": +def replace_torchsde_brownian(): import torchsde._brownian.brownian_interval def torchsde_randn(size, dtype, device, seed): - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) + return devices.randn_local(seed, size).to(device=device, dtype=dtype) torchsde._brownian.brownian_interval._randn = torchsde_randn + + +replace_torchsde_brownian() diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index e0da34259..db71a549a 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -4,6 +4,7 @@ import inspect import k_diffusion.sampling from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules.processing import StableDiffusionProcessing from modules.shared import opts, state import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback @@ -30,6 +31,7 @@ samplers_k_diffusion = [ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise":
True}), + ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}), ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] @@ -260,10 +262,7 @@ class TorchHijack: if noise.shape == x.shape: return noise - if opts.randn_source == "CPU" or x.device.type == 'mps': - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) + return devices.randn_like(x) class KDiffusionSampler: @@ -282,6 +281,14 @@ class KDiffusionSampler: self.last_latent = None self.s_min_uncond = None + # NOTE: These are also defined in the StableDiffusionProcessing class. + # They should have been here to begin with but we're going to + # leave that class __init__ signature alone. + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + self.conditioning_key = sd_model.model.conditioning_key def callback_state(self, d): @@ -316,7 +323,7 @@ class KDiffusionSampler: def number_of_needed_noises(self, p): return p.steps - def initialize(self, p): + def initialize(self, p: StableDiffusionProcessing): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 @@ -337,6 +344,29 @@ class KDiffusionSampler: extra_params_kwargs['eta'] = self.eta + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + return extra_params_kwargs def get_sigmas(self, p, steps): @@ -378,6 +408,9 @@ class KDiffusionSampler: sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff29946..38bcb8402 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,8 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from dataclasses import dataclass + +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +18,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info 
== model.sd_checkpoint_info and model: return base_vae @@ -50,6 +53,7 @@ def get_filename(filepath): def refresh_vae_list(): + global vae_dict vae_dict.clear() paths = [ @@ -83,6 +87,8 @@ def refresh_vae_list(): name = get_filename(filepath) vae_dict[name] = filepath + vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))) + def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] @@ -93,27 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file): return None -def resolve_vae(checkpoint_file): - if shared.cmd_opts.vae_path is not None: - return shared.cmd_opts.vae_path, 'from commandline argument' +@dataclass +class VaeResolution: + vae: str = None + source: str = None + resolved: bool = True - is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + def tuple(self): + return self.vae, self.source - vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): - return vae_near_checkpoint, 'found near the checkpoint' +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: if shared.opts.sd_vae == "None": - return None, None + return VaeResolution() vae_from_options = vae_dict.get(shared.opts.sd_vae, None) if vae_from_options is not None: - return vae_from_options, 'specified in settings' + return VaeResolution(vae_from_options, 'specified in settings') - if not is_automatic: + if not is_automatic(): print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") - return None, None + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return VaeResolution() + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return VaeResolution(vae_from_metadata, "from user metadata") + + return VaeResolution(resolved=False) + + +def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic()): + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') + + return VaeResolution(resolved=False) + + +def resolve_vae(checkpoint_file) -> VaeResolution: + if shared.cmd_opts.vae_path is not None: + return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument') + + if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic(): + return resolve_vae_from_setting() + + res = resolve_vae_from_user_metadata(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_near_checkpoint(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_from_setting() + + return res def load_vae_dict(filename, map_location): @@ -187,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_file = checkpoint_info.filename if vae_file == unspecified: - vae_file, vae_source = resolve_vae(checkpoint_file) + vae_file, vae_source = resolve_vae(checkpoint_file).tuple() else: vae_source = "from function argument" diff --git a/modules/sd_vae_approx.py
b/modules/sd_vae_approx.py index 86bd658ad..3965e223e 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -81,6 +81,6 @@ def cheap_approximation(sample): coefs = torch.tensor(coeffs).to(sample.device) - x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) + x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs) return x_sample diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5bf7c76e1..808eb3624 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -44,7 +44,17 @@ def decoder(): ) -class TAESD(nn.Module): +def encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + + +class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 @@ -55,21 +65,28 @@ class TAESD(nn.Module): self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - @staticmethod - def unscale_latents(x): - """[0, 1] -> raw latents""" - return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + +class TAESDEncoder(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = encoder() + self.encoder.load_state_dict( + torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) - print(f'Downloading TAESD decoder to: {model_path}') + print(f'Downloading TAESD model to: {model_path}') torch.hub.download_url_to_file(model_url, model_path) -def model(): +def decoder_model(): model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name) @@ -78,7 +95,7 @@ def model(): download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - loaded_model = TAESD(model_path) + loaded_model = TAESDDecoder(model_path) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) sd_vae_taesd_models[model_name] = loaded_model @@ -86,3 +103,22 @@ def model(): raise FileNotFoundError('TAESD model not found') return loaded_model.decoder + + +def encoder_model(): + model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) + + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) + + if os.path.exists(model_path): + loaded_model = TAESDEncoder(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model + else: + raise FileNotFoundError('TAESD model not found') + + return loaded_model.encoder diff --git a/modules/shared.py b/modules/shared.py index 5a7be85b4..97f1eab52 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -51,15 +51,34 @@ restricted_opts = { # 
https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json gradio_hf_hub_themes = [ + "gradio/base", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", - "freddyaboulton/dracula_revamped", "gradio/dracula_test", "abidlabs/dracula_test", + "abidlabs/Lime", "abidlabs/pakistan", + "Ama434/neutral-barlow", "dawood/microsoft_windows", + "finlaymacklon/smooth_slate", + "Franklisi/darkmode", + "freddyaboulton/dracula_revamped", + "freddyaboulton/test-blue", + "gstaff/xkcd", + "Insuz/Mocha", + "Insuz/SimpleIndigo", + "JohnSmith9982/small_and_pretty", + "nota-ai/theme", + "nuttea/Softblue", + "ParityError/Anime", + "reilnuud/polite", + "remilia/Ghostly", + "rottenlittlecreature/Moon_Goblin", + "step-3-profit/Midnight-Deep", + "Taithrah/Minimal", + "ysharma/huggingface", "ysharma/steampunk" ] @@ -220,12 +239,19 @@ class State: return import modules.sd_samplers - if opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) - else: - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - self.current_image_sampling_step = self.sampling_step + try: + if opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during generation, VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() def assign_current_image(self, image): self.current_image = image @@ -252,6 +278,7 @@ class OptionInfo: self.onchange = onchange self.section = section self.refresh = refresh + self.do_not_save = False self.comment_before = comment_before """HTML text that will be added after label in UI""" @@ -279,8 +306,17 @@ class OptionInfo: self.comment_after += " (requires restart)" return self + def needs_reload_ui(self): + self.comment_after += " (requires Reload UI)" + return self +class OptionHTML(OptionInfo): + def __init__(self, text): + super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs)) + + self.do_not_save = True + def options_section(section_identifier, options_dict): for v in options_dict.values(): @@ -349,6 +385,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that have been interrupted in mid-generation; even if not saved, they will still show up in webui output."), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { @@ -386,13 +423,15 @@ options_templates.update(options_section(('face-restoration', "Face restoration" options_templates.update(options_section(('system', "System"), { "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), - "show_warnings": OptionInfo(False, "Show warnings in console."), + "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), + "show_gradio_deprecation_warnings": OptionInfo(True, "Show
gradio deprecation warnings in console.").needs_reload_ui(), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), })) options_templates.update(options_section(('training', "Training"), { @@ -412,24 +451,17 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), - "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"), - "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), - "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), - "enable_quantization": OptionInfo(False, "Enable quantization in K 
samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), + "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(), "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), - "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), - "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { @@ -439,6 +471,35 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) +options_templates.update(options_section(('vae', "VAE"), { + "sd_vae_explanation": OptionHTML(""" +VAE is a neural network that transforms a standard RGB +image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling +(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished. +For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling. 
+"""), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), + "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), + "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), + "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"), +})) + +options_templates.update(options_section(('img2img', "img2img"), { + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), + "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), + "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}), + "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(), + "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(), + "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(), + "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), + "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), + "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), +})) + options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), @@ -446,7 +507,7 
@@ options_templates.update(options_section(('optimizations', "Optimizations"), { "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"), - "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."), + "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("Do not recalculate conds from prompts if prompts have not changed since previous calculation"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -458,7 +519,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), })) -options_templates.update(options_section(('interrogate', "Interrogate Options"), { +options_templates.update(options_section(('interrogate', "Interrogate"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"), "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"), "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), @@ -482,19 +543,17 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), - "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) options_templates.update(options_section(('ui', "User interface"), { - "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(), - "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(), - "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(), + "localization": OptionInfo("None", "Localization", gr.Dropdown, 
lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), + "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).info("you can also manually enter any of themes from the gallery.").needs_reload_ui(), + "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), "return_grid": OptionInfo(True, "Show grid in results for web"), - "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), - "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), @@ -503,21 +562,22 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(), - "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing ", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), - "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), - "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), - "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), - "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), - "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), - "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), + "quicksettings_list": 
OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), + "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(), + "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_reload_ui(), + "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), + "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), + "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), })) + options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), @@ -545,17 +605,18 @@ options_templates.update(options_section(('ui', "Live previews"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(), + "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_reload_ui(), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"), "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}).info('amount of stochasticity; only applies to Euler, Heun, and DPM2'), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}).info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'), + 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}).info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), + 's_noise': OptionInfo(1.0, "sigma noise", 
gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}).info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'), + 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), - 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"), - 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), + 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), @@ -595,6 +656,9 @@ class Options: assert not cmd_opts.freeze_settings, "changing settings is disabled" info = opts.data_labels.get(key, None) + if info.do_not_save: + return + comp_args = info.component_args if info else None if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: raise RuntimeError(f"not possible to set {key} because it is restricted") @@ -624,6 +688,9 @@ class Options: if oldval == value: return False + if self.data_labels[key].do_not_save: + return False + try: setattr(self, key, value) except RuntimeError: @@ -667,6 +734,10 @@ class Options: with open(filename, "r", encoding="utf8") as file: self.data = json.load(file) + # 1.6.0 VAE defaults + if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: + self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') + # 1.1.1 quicksettings list migration if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] @@ -800,13 +871,19 @@ def reload_gradio_theme(theme_name=None): gradio_theme = gr.themes.Default(**default_theme_args) else: try: - gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) + theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes') + theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json') + if opts.gradio_themes_cache and os.path.exists(theme_cache_path): + gradio_theme = gr.themes.ThemeClass.load(theme_cache_path) + else: + os.makedirs(theme_cache_dir, exist_ok=True) + gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) + gradio_theme.dump(theme_cache_path) except Exception as e: errors.display(e, "changing gradio theme") 
gradio_theme = gr.themes.Default(**default_theme_args) - class TotalTQDM: def __init__(self): self._tqdm = None @@ -890,3 +967,10 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) + + +def ldm_print(*args, **kwargs): + if opts.hide_ldm_prints: + return + + print(*args, **kwargs) diff --git a/modules/styles.py b/modules/styles.py index ec0e1bc51..0740fe1b1 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -106,10 +106,7 @@ class StyleDatabase: if os.path.exists(path): shutil.copy(path, f"{path}.bak") - fd = os.open(path, os.O_RDWR | os.O_CREAT) - with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: - # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, - # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() + with open(path, "w", encoding="utf-8-sig", newline='') as file: writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer.writeheader() writer.writerows(style._asdict() for k, style in self.styles.items()) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4713bc2d9..aa79dc098 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes +from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + from modules import processing + save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) diff --git a/modules/txt2img.py b/modules/txt2img.py index 29d94e8cb..935ed4181 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, 
steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, diff --git a/modules/ui.py b/modules/ui.py index ac2787eb4..5150dae41 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,39 +12,38 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger +from modules import gradio_extensons # noqa: F401 +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.ui_common import create_refresh_button from modules.ui_gradio_extensions import reload_javascript - from modules.shared import opts, cmd_opts -import modules.codeformer_model import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts +import modules.hypernetworks.ui as hypernetworks_ui +import modules.textual_inversion.ui as textual_inversion_ui +import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui +import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text -import modules.extras create_setting_component = ui_settings.create_setting_component warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) +warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning) # this is a fix for Windows users. 
Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +# Likewise, add explicit content-type header for certain missing image types +mimetypes.add_type('image/webp', '.webp') + if not cmd_opts.share and not cmd_opts.listen: # fix gradio phoning home gradio.utils.version_check = lambda: None @@ -92,19 +91,6 @@ def send_gradio_gallery_to_image(x): return image_from_url_text(x[0]) -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] - - def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): from modules import processing, devices @@ -129,13 +115,6 @@ def resize_from_to_html(width, height, scale_by): return f"resize: from {width}x{height} to {target_width}x{target_height}" -def apply_styles(prompt, prompt_neg, styles): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] - - def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): if mode in {0, 1, 3, 4}: return [interrogation_function(ii_singles[mode]), None] @@ -172,7 +151,6 @@ def interrogate_deepbooru(image): def create_seed_inputs(target_interface): with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"): seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed") - seed.style(container=False) random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed') reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed') @@ -184,7 +162,6 @@ def create_seed_inputs(target_interface): with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1: seed_extras.append(seed_extra_row_1) subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed") - subseed.style(container=False) random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed") reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed") subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength") @@ -267,71 +244,76 @@ def update_token_counter(text, steps): return f"{token_count}/{max_length}" -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" +class Toprow: + """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation""" - with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): - with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): - with gr.Row(): - with gr.Column(scale=80): - 
with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + def __init__(self, is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + self.id_part = id_part - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False) - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_classes="interrogate-col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") - skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + self.button_interrogate = None + self.button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_classes="interrogate-col"): + self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) + with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): + with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): + self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") + self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") + self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) + self.skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) - with gr.Row(elem_id=f"{id_part}_tools"): - paste = ToolButton(value=paste_symbol, elem_id="paste") - clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") - prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") - save_style = ToolButton(value=save_style_symbol, 
elem_id=f"{id_part}_style_create") - restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) + self.interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) - token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) - negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") + with gr.Row(elem_id=f"{id_part}_tools"): + self.paste = ToolButton(value=paste_symbol, elem_id="paste") + self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) + self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) + self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) + self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - with gr.Row(elem_id=f"{id_part}_styles_row"): - prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) - create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") + self.clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[self.prompt, self.negative_prompt], + outputs=[self.prompt, self.negative_prompt], + ) - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button + self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) + + self.prompt_img.change( + fn=modules.images.image_data, + inputs=[self.prompt_img], + outputs=[self.prompt, self.prompt_img], + show_progress=False, + ) def setup_progressbar(*args, **kwargs): @@ -415,22 +397,20 @@ def create_ui(): parameters_copypaste.reset() - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + scripts.scripts_current = scripts.scripts_txt2img + scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False) + toprow = Toprow(is_img2img=False) dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", 
type="binary", visible=False) - with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: - from modules import ui_extra_networks - extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs") + extra_tabs.__enter__() - with gr.Row().style(equal_height=False): + with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): - modules.scripts.scripts_txt2img.prepare_ui() + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": @@ -476,6 +456,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: @@ -498,10 +482,10 @@ def create_ui(): elif category == "scripts": with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + custom_inputs = scripts.scripts_txt2img.setup_ui() else: - modules.scripts.scripts_txt2img.setup_ui_for_section(category) + scripts.scripts_txt2img.setup_ui_for_section(category) hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -532,9 +516,9 @@ def create_ui(): _js="submit", inputs=[ dummy_component, - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_styles, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, steps, sampler_index, restore_faces, @@ -553,6 +537,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + hr_checkpoint_name, hr_sampler_index, hr_prompt, hr_negative_prompt, @@ -569,12 +554,12 @@ def create_ui(): show_progress=False, ) - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) + toprow.prompt.submit(**txt2img_args) + toprow.submit.click(**txt2img_args) res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) - restore_progress_button.click( + toprow.restore_progress_button.click( fn=progress.restore_progress, _js="restoreProgressTxt2img", inputs=[dummy_component], @@ -587,18 +572,6 @@ def create_ui(): show_progress=False, ) - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ], - show_progress=False, - ) - enable_hr.change( fn=lambda x: gr_show(x), inputs=[enable_hr], @@ -607,8 +580,8 @@ def create_ui(): ) txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), + (toprow.prompt, 
"Prompt"), + (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), (sampler_index, "Sampler"), (restore_faces, "Face restoration"), @@ -617,34 +590,36 @@ def create_ui(): (width, "Size-1"), (height, "Size-2"), (batch_size, "Batch size"), + (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), (subseed, "Variation seed"), (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), - (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), (hr_second_pass_steps, "Hires steps"), (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), + (hr_checkpoint_name, "Hires checkpoint"), (hr_sampler_index, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), - *modules.scripts.scripts_txt2img.infotext_fields + *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, + paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None, )) txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, + toprow.prompt, + toprow.negative_prompt, steps, sampler_index, cfg_scale, @@ -653,24 +628,25 @@ def create_ui(): height, ] - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) - ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') + 
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + extra_tabs.__exit__() + + scripts.scripts_current = scripts.scripts_img2img + scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) + toprow = Toprow(is_img2img=True) - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs") + extra_tabs.__enter__() - with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: - from modules import ui_extra_networks - extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') - - with FormRow().style(equal_height=False): + with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] copy_image_destinations = {} @@ -692,19 +668,19 @@ def create_ui(): img2img_selected_tab = gr.State(0) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height) + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height) add_copy_image_controls('img2img', init_img) with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height) + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color) add_copy_image_controls('sketch', sketch) with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height) + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) add_copy_image_controls('inpaint', init_img_with_mask) with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: - inpaint_color_sketch 
= gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height) + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color) inpaint_color_sketch_orig = gr.State(None) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) @@ -764,7 +740,7 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - modules.scripts.scripts_img2img.prepare_ui() + scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": @@ -845,7 +821,7 @@ def create_ui(): elif category == "scripts": with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() + custom_inputs = scripts.scripts_img2img.setup_ui() elif category == "inpaint": with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: @@ -876,34 +852,22 @@ def create_ui(): outputs=[inpaint_controls, mask_alpha], ) else: - modules.scripts.scripts_img2img.setup_ui_for_section(category) + scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ], - show_progress=False, - ) - img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ dummy_component, dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_styles, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, init_img, sketch, init_img_with_mask, @@ -962,11 +926,11 @@ def create_ui(): inpaint_color_sketch, init_img_inpaint, ], - outputs=[img2img_prompt, dummy_component], + outputs=[toprow.prompt, dummy_component], ) - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) + toprow.prompt.submit(**img2img_args) + toprow.submit.click(**img2img_args) res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) @@ -978,7 +942,7 @@ def create_ui(): show_progress=False, ) - restore_progress_button.click( + toprow.restore_progress_button.click( fn=progress.restore_progress, _js="restoreProgressImg2img", inputs=[dummy_component], @@ -991,46 +955,22 @@ def create_ui(): show_progress=False, ) - img2img_interrogate.click( + toprow.button_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, ) - img2img_deepbooru.click( + toprow.button_deepbooru.click( fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args, ) - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = 
[txt2img_prompt_styles, img2img_prompt_styles] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_styles, img2img_prompt_styles], - ) - - for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, styles], - outputs=[prompt, negative_prompt, styles], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) - - ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), + (toprow.prompt, "Prompt"), + (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), (sampler_index, "Sampler"), (restore_faces, "Face restoration"), @@ -1040,28 +980,35 @@ def create_ui(): (width, "Size-1"), (height, "Size-2"), (batch_size, "Batch size"), + (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), (subseed, "Variation seed"), (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), - (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields + *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, )) - modules.scripts.scripts_current = None + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + extra_tabs.__exit__() + + scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with 
gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): with gr.Column(variant='panel'): image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") @@ -1086,10 +1033,10 @@ def create_ui(): modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger() with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") - with gr.Row(variant="compact").style(equal_height=False): + with gr.Row(variant="compact", equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding", id="create_embedding"): @@ -1109,7 +1056,7 @@ def create_ui(): new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") @@ -1249,12 +1196,12 @@ def create_ui(): with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4) + gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4) gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, + fn=textual_inversion_ui.create_embedding, inputs=[ new_embedding_name, initialization_text, @@ -1269,7 +1216,7 @@ def create_ui(): ) create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, + fn=hypernetworks_ui.create_hypernetwork, inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, @@ -1289,7 +1236,7 @@ def create_ui(): ) run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1325,7 +1272,7 @@ def create_ui(): ) train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1359,7 +1306,7 @@ def create_ui(): ) train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ 
dummy_component, diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py index 4863d8610..f9c5dd6bb 100644 --- a/modules/ui_checkpoint_merger.py +++ b/modules/ui_checkpoint_merger.py @@ -29,7 +29,7 @@ def modelmerger(*args): class UiCheckpointMerger: def __init__(self): with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): + with gr.Row(equal_height=False): with gr.Column(variant='compact'): self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") diff --git a/modules/ui_common.py b/modules/ui_common.py index 11eb2a4b2..303af9cd7 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -134,7 +134,7 @@ Requested path was: {f} with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Group(elem_id=f"{tabname}_gallery_container"): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4) + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4) generation_info = None with gr.Column(): @@ -223,20 +223,44 @@ Requested path was: {f} def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component] + + label = None + for comp in refresh_components: + label = getattr(comp, 'label', None) + if label is not None: + break + def refresh(): refresh_method() args = refreshed_args() if callable(refreshed_args) else refreshed_args for k, v in args.items(): - setattr(refresh_component, k, v) + for comp in refresh_components: + setattr(comp, k, v) - return gr.update(**(args or {})) + return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {})) - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh") refresh_button.click( fn=refresh, inputs=[], - outputs=[refresh_component] + outputs=refresh_components ) return refresh_button + +def setup_dialog(button_show, dialog, *, button_close=None): + """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when button_show is clicked, in a fullscreen modal window.""" + + dialog.visible = False + + button_show.click( + fn=lambda: gr.update(visible=True), + inputs=[], + outputs=[dialog], + ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }") + + if button_close: + button_close.click(fn=None, _js="closePopup") + diff --git a/modules/ui_components.py b/modules/ui_components.py index 64451df7a..8f8a70885 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column): class FormGroup(FormComponent, gr.Group): - """Same as gr.Row but fits inside gradio forms""" + """Same as gr.Group but fits inside gradio forms""" def get_block_name(self): return "group" diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3e4fba7e..15a8b0bf4 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -164,7 +164,7 @@ def extension_table(): ext_status = ext.status style = "" - if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_extra_extensions and not 
ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": style = STYLE_PRIMARY version_link = ext.version @@ -533,16 +533,20 @@ def create_ui(): apply = gr.Button(value=apply_label, variant="primary") check = gr.Button(value="Check for updates") extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") - extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) - extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False) html = "" - if shared.opts.disable_all_extensions != "none": - html = """ - - "Disable all extensions" was set, change it to "none" to load all extensions again - - """ + + if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none": + if shared.cmd_opts.disable_all_extensions: + msg = '"--disable-all-extensions" was used, remove it to load all extensions again' + elif shared.opts.disable_all_extensions != "none": + msg = '"Disable all extensions" was set, change it to "none" to load all extensions again' + elif shared.cmd_opts.disable_extra_extensions: + msg = '"--disable-extra-extensions" was used, remove it to load all extensions again' + html = f'{msg}' + info = gr.HTML(html) extensions_table = gr.HTML('Loading...') ui.load(fn=extension_table, inputs=[], outputs=[extensions_table]) @@ -565,7 +569,7 @@ def create_ui(): with gr.Row(): refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json") - available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False) + available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) @@ -574,7 +578,7 @@ def create_ui(): sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): - search_extensions_text = gr.Text(label="Search").style(container=False) + search_extensions_text = gr.Text(label="Search", container=False) install_result = gr.HTML() available_extensions_table = gr.HTML() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f2752f107..e0b932b94 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -101,16 +101,7 @@ class 
ExtraNetworksPage: def read_user_metadata(self, item): filename = item.get("filename", None) - basename, ext = os.path.splitext(filename) - metadata_filename = basename + '.json' - - metadata = {} - try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) - except Exception as e: - errors.display(e, f"reading extra network user metadata from {metadata_filename}") + metadata = extra_networks.get_user_metadata(filename) desc = metadata.get("description", None) if desc is not None: @@ -164,7 +155,7 @@ class ExtraNetworksPage: subdirs = {"": 1, **subdirs} subdirs_html = "".join([f""" - """ for subdir in subdirs]) @@ -356,7 +347,7 @@ def pages_in_preferred_order(pages): return sorted(pages, key=lambda x: tab_scores[x.name]) -def create_ui(container, button, tabname): +def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -364,48 +355,42 @@ def create_ui(container, button, tabname): ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname - with gr.Tabs(elem_id=tabname+"_extra_tabs"): - for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page): - elem_id = f"{tabname}_{page.id_page}_cards_html" - page_elem = gr.HTML('Loading...', elem_id=elem_id) - ui.pages.append(page_elem) + related_tabs = [] - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + for page in ui.stored_extra_pages: + with gr.Tab(page.title, id=page.id_page) as tab: + elem_id = f"{tabname}_{page.id_page}_cards_html" + page_elem = gr.HTML('Loading...', elem_id=elem_id) + ui.pages.append(page_elem) - editor = page.create_user_metadata_editor(ui, tabname) - editor.create_ui() - ui.user_metadata_editors.append(editor) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) - gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) - ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) + + related_tabs.append(tab) + + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") + button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) + checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = 
gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - def toggle_visibility(is_visible): - is_visible = not is_visible + for tab in unrelated_tabs: + tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) - return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) - - def fill_tabs(is_empty): - """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time.""" + for tab in related_tabs: + tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) + def pages_html(): if not ui.pages_contents: - refresh() + return refresh() - if is_empty: - return True, *ui.pages_contents - - return True, *[gr.update() for _ in ui.pages_contents] - - state_visible = gr.State(value=False) - button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False) - - state_empty = gr.State(value=True) - button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False) + return ui.pages_contents def refresh(): for pg in ui.stored_extra_pages: @@ -415,6 +400,7 @@ def create_ui(container, button, tabname): return ui.pages_contents + interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 2bb0a222e..778850222 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models from modules.ui_extra_networks import quote_js +from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.refresh_checkpoints() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = os.path.splitext(checkpoint.filename) return { @@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + def create_user_metadata_editor(self, ui, tabname): + return CheckpointUserMetadataEditor(ui, tabname, self) diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py new file mode 100644 index 000000000..25df0a807 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -0,0 +1,66 @@ +import gradio as gr + +from modules import ui_extra_networks_user_metadata, sd_vae, shared +from modules.ui_common import create_refresh_button + + +class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.select_vae = None + + def save_user_metadata(self, name, desc, notes, 
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb428..514a45624 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.reload_hypernetworks()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
 
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e501..73134698e 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
         path, ext = os.path.splitext(embedding.filename)
 
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index c7dc11540..802e1ce71 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
 def create_ui():
     tab_index = gr.State(value=0)
 
-    with gr.Row().style(equal_height=False, variant='compact'):
+    with gr.Row(equal_height=False, variant='compact'):
         with gr.Column(variant='compact'):
             with gr.Tabs(elem_id="mode_extras"):
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
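The `gr.Row` change above follows the Gradio 3.32 → 3.39 migration that this patch's requirements bump implies: styling arguments move from the deprecated `.style()` call into component constructors. A minimal before/after sketch, assuming Gradio 3.39:

```python
import gradio as gr

with gr.Blocks() as demo:
    # Gradio 3.32: with gr.Row().style(equal_height=False, variant='compact'):
    # Gradio 3.39: .style() is gone; the same options are constructor kwargs.
    with gr.Row(equal_height=False, variant='compact'):
        gr.Textbox(label='Example field')

demo.launch()
```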
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 000000000..85eb3a641
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
+styles_materialize_symbol = '\U0001f4cb'  # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to the prompt. Use the {prompt} token in the style text, and it will be replaced with the user's prompt when the style is applied; otherwise, the style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in the main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
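The info text in the new editor above describes the `{prompt}` substitution rule that style application follows. A standalone sketch of that rule; illustrative only, the webui's own implementation presumably lives with `PromptStyle` in modules/styles.py:

```python
def apply_style_text(style_text: str, prompt: str) -> str:
    if "{prompt}" in style_text:
        # the token is replaced with the user's prompt
        return style_text.replace("{prompt}", prompt)

    # otherwise the style's text is appended to the end of the prompt
    parts = filter(None, (prompt.strip(), style_text.strip()))
    return ", ".join(parts)

assert apply_style_text("a photo of {prompt}, 35mm", "a cat") == "a photo of a cat, 35mm"
assert apply_style_text("35mm, film grain", "a cat") == "a cat, 35mm, film grain"
```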
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index a6076bf30..6dde4b6aa 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -158,7 +158,7 @@ class UiSettings:
                     loadsave.create_ui()
 
                 with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
-                    gr.HTML('Download system info (or open as text in a new page)', elem_id="sysinfo_download")
+                    gr.HTML('Download system info (or open as text in a new page)', elem_id="sysinfo_download")
 
                     with gr.Row():
                         with gr.Column(scale=1):
diff --git a/requirements.txt b/requirements.txt
index b3f8a7f41..9a47d6d0d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ blendmodes
 clean-fid
 einops
 gfpgan
-gradio==3.32.0
+gradio==3.39.0
 inflection
 jsonmerge
 kornia
@@ -30,4 +30,4 @@ tomesd
 torch
 torchdiffeq
 torchsde
-transformers==4.25.1
+transformers==4.30.2
diff --git a/requirements_versions.txt b/requirements_versions.txt
index d07ab456c..dec45df38 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,13 +1,13 @@
-GitPython==3.1.30
+GitPython==3.1.32
 Pillow==9.5.0
-accelerate==0.18.0
+accelerate==0.21.0
 basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1
 fastapi==0.94.0
 gfpgan==1.3.8
-gradio==3.32.0
+gradio==3.39.0
 httpcore==0.15
 inflection==0.5.1
 jsonmerge==1.8.0
@@ -22,10 +22,10 @@ pytorch_lightning==1.9.4
 realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
-scikit-image==0.20.0
-timm==0.6.7
-tomesd==0.1.2
+scikit-image==0.21.0
+timm==0.9.2
+tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.5
-transformers==4.25.1
+transformers==4.30.2
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 103bf1048..d37b428fc 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -67,14 +67,6 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt
 
 
-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):
     for x in xs:
         if x.lower() not in sd_samplers.samplers_map:
@@ -224,8 +216,9 @@ axis_options = [
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
-    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
     AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
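The dedicated `apply_sampler` is dropped above because the generic `apply_field` closure does the same job once sampler names are validated up front by `confirm_samplers`. `apply_field`'s definition is not shown in this hunk; a sketch consistent with its usage here:

```python
def apply_field(field):
    # returns an axis-apply function that writes this axis's value
    # onto the processing object's attribute of the given name
    def fun(p, x, xs):
        # p: processing object; x: value for this cell; xs: all axis values
        setattr(p, field, x)

    return fun

# e.g. the new "Hires sampler" axis simply sets p.hr_sampler_name
apply_hr_sampler = apply_field("hr_sampler_name")
```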
diff --git a/style.css b/style.css
index 6c92d6e78..8c1f273c0 100644
--- a/style.css
+++ b/style.css
@@ -8,6 +8,7 @@
     --checkbox-label-gap: 0.25em 0.1em;
     --section-header-text-size: 12pt;
     --block-background-fill: transparent;
+
 }
 
 .block.padded:not(.gradio-accordion) {
@@ -42,7 +43,8 @@ div.form{
 .block.gradio-radio,
 .block.gradio-checkboxgroup,
 .block.gradio-number,
-.block.gradio-colorpicker
+.block.gradio-colorpicker,
+div.gradio-group
 {
     border-width: 0 !important;
     box-shadow: none !important;
@@ -133,6 +135,15 @@ a{
     cursor: pointer;
 }
 
+div.styler{
+    border: none;
+    background: var(--background-fill-primary);
+}
+
+.block.gradio-textbox{
+    overflow: visible !important;
+}
+
 
 /* general styled components */
 
@@ -164,7 +175,7 @@ a{
 .checkboxes-row > div{
     flex: 0;
     white-space: nowrap;
-    min-width: auto;
+    min-width: auto !important;
 }
 
 button.custom-button{
@@ -388,6 +399,7 @@ div#extras_scale_to_tab div.form{
 #quicksettings > div, #quicksettings > fieldset{
     max-width: 24em;
     min-width: 24em;
+    width: 24em;
     padding: 0;
     border: none;
     box-shadow: none;
@@ -482,6 +494,13 @@ table.popup-table .link{
     font-size: 18pt;
 }
 
+#settings .settings-info{
+    max-width: 48em;
+    border: 1px dotted #777;
+    margin: 0;
+    padding: 1em;
+}
+
 /* live preview */
 
 .progressDiv{
@@ -767,9 +786,14 @@ footer {
 /* extra networks UI */
 
 .extra-network-cards{
-    height: 725px;
-    overflow: scroll;
+    height: calc(100vh - 24rem);
+    overflow: clip scroll;
     resize: vertical;
+    min-height: 52rem;
+}
+
+.extra-networks > div.tab-nav{
+    min-height: 3.4rem;
 }
 
 .extra-networks > div > [id *= '_extra_']{
@@ -784,10 +808,12 @@ footer {
     margin: 0 0.15em;
 }
 .extra-networks .tab-nav .search,
-.extra-networks .tab-nav .sort{
-    display: inline-block;
+.extra-networks .tab-nav .sort,
+.extra-networks .tab-nav .show-dirs
+{
     margin: 0.3em;
     align-self: center;
+    width: auto;
 }
 
 .extra-networks .tab-nav .search {
@@ -972,3 +998,16 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
 }
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}
diff --git a/webui.py b/webui.py
index 844e25481..6d36f8806 100644
--- a/webui.py
+++ b/webui.py
@@ -14,7 +14,6 @@ from typing import Iterable
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
 
 import logging
 
@@ -50,6 +49,7 @@ startup_timer.record("setup paths")
 
 import ldm.modules.encoders.modules  # noqa: F401
 startup_timer.record("import ldm")
+
 from modules import extra_networks
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
 
@@ -58,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__long_version__ = torch.__version__
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
 
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+    errors.check_versions()
+
 import modules.codeformer_model as codeformer
-import modules.face_restoration
 import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
 import modules.img2img
 
 import modules.lowvram
@@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy():
     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
 
 
-def check_versions():
-    if shared.cmd_opts.skip_version_check:
-        return
-
-    expected_torch_version = "2.0.0"
-
-    if version.parse(torch.__version__) < version.parse(expected_torch_version):
-        errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
-        """.strip())
-
-    expected_xformers_version = "0.0.20"
-    if shared.xformers_available:
-        import xformers
-
-        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
-            errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
-            """.strip())
-
-
 def restore_config_state_file():
     config_state_file = shared.opts.restore_config_state_file
     if config_state_file == "":
@@ -237,7 +211,7 @@ def configure_sigint_handler():
 def configure_opts_onchange():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
@@ -248,7 +222,6 @@ def initialize():
     fix_asyncio_event_loop_policy()
     validate_tls_options()
     configure_sigint_handler()
-    check_versions()
     modelloader.cleanup_models()
     configure_opts_onchange()
 
@@ -368,6 +341,7 @@ def api_only():
     setup_middleware(app)
     api = create_api(app)
 
+    modules.script_callbacks.before_ui_callback()
     modules.script_callbacks.app_started_callback(None, app)
 
     print(f"Startup time: {startup_timer.summary()}.")
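The inline `check_versions` deleted above is replaced by the `errors.check_versions()` call added near the top of webui.py, gated by the same `--skip-version-check` flag. A standalone sketch of the comparison it performs, based on the removed code (the expected versions are the ones the removed code named; the message is abbreviated):

```python
import torch
from packaging import version

EXPECTED_TORCH = "2.0.0"

def check_torch_version():
    # warn, but do not abort, when the installed torch is older than tested
    if version.parse(torch.__version__) < version.parse(EXPECTED_TORCH):
        print(
            f"You are running torch {torch.__version__}; "
            f"the program is tested to work with torch {EXPECTED_TORCH}. "
            "Use --skip-version-check to disable this check."
        )
```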