diff --git a/.eslintrc.js b/.eslintrc.js index f33aca09f..4777c276e 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -87,5 +87,11 @@ module.exports = { modalNextImage: "readonly", // token-counters.js setupTokenCounters: "readonly", + // localStorage.js + localSet: "readonly", + localGet: "readonly", + localRemove: "readonly", + // resizeHandle.js + setupResizeHandle: "writable" } }; diff --git a/CHANGELOG.md b/CHANGELOG.md index 461fef9a7..07798b5a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,142 @@ +## 1.6.0 + +### Features: + * refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371) + * add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards + * add style editor dialog + * hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181)) + * option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227)) + * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542)) + * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers: + * makes all of them work with img2img + * makes prompt composition posssible (AND) + * makes them available for SDXL + * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808)) + * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599)) + * textual inversion inference support for SDXL + * extra networks UI: show metadata for SD checkpoints + * checkpoint merger: add metadata support + * prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177)) + * VAE: allow selecting own VAE for each checkpoint (in user metadata editor) + * VAE: add selected VAE to infotext + * options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551)) + * add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723)) + * change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it + * show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707)) + * add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models + * prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457)) + +### Minor: + * img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515)) + * postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479)) + * XYZ: in the axis labels, remove pathnames from model filenames + * XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298)) + * XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491)) + * add gradio version warning + * sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297)) + * use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326)) + * move some settings to their own section: img2img, VAE + * add checkbox to show/hide dirs for extra networks + * Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311)) + * gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355)) + * sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521)) + * update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352)) + * option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338)) + * enable cond cache by default + * git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230)) + * allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379)) + * automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254)) + * put commonly used samplers on top, make DPM++ 2M Karras the default choice + * zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727)) + * option to cache Lora networks in memory + * rework hires fix UI to use accordion + * face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back + * change quicksettings items to have variable width + * Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503)) + * Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console + * support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510)) + * add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564)) + * support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584)) + * make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634)) + * configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648)) + * make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645)) + * more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639)) + * make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635)) + * make progress bar work independently from live preview display which results in it being updated a lot more often + * forbid Full live preview method for medvram and add a setting to undo the forbidding + * make it possible to localize tooltips and placeholders + +### Extensions and API: + * gradio 3.41.2 + * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd + * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world') + * properly clear the total console progressbar when using txt2img and img2img from API + * add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294)) + * shared.py and webui.py split into many files + * add --loglevel commandline argument for logging + * add a custom UI element that combines accordion and checkbox + * avoid importing gradio in tests because it spams warnings + * put infotext label for setting into OptionInfo definition rather than in a separate list + * make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470)) + * option to make scripts UI without gr.Group + * add a way for scripts to register a callback for before/after just a single component's creation + * use dataclass for StableDiffusionProcessing + * store patches for Lora in a specialized module instead of inside torch + * support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698)) + * add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616)) + * dump current stack traces when exiting with SIGINT + * add type annotations for extra fields of shared.sd_model + +### Bug Fixes: + * Don't crash if out of local storage quota for javascriot localStorage + * XYZ plot do not fail if an exception occurs + * fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269)) + * localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307)) + * fix sdxl model invalid configuration after the hijack + * correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304)) + * open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318)) + * prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319)) + * add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327)) + * fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387)) + * fix options in main UI misbehaving when there's just one element + * make it possible to use a sampler from infotext even if it's hidden in the dropdown + * fix styles missing from the prompt in infotext when making a grid of batch of multiplie images + * prevent bogus progress output in console when calculating hires fix dimensions + * fix --use-textbox-seed + * fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466)) + * properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463)) + * MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526)) + * add second_order to samplers that mistakenly didn't have it + * when refreshing cards in extra networks UI, do not discard user's custom resolution + * fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509)) + * fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588)) + * fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586)) + * fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596)) + * auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603)) + * fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607)) + * fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638)) + * fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633)) + * attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630)) + * implement missing undo hijack for SDXL + * fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684)) + * fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689)) + * fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685)) + * fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667)) + * create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717)) + * prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version + * set devices.dtype_unet correctly + * run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745)) + * fix error that causes some extra networks to be disabled if both and are present in the prompt + * fix defaults settings page breaking when any of main UI tabs are hidden + * fix incorrect save/display of new values in Defaults page in settings + * fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working + * fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * fix style editing dialog breaking if it's opened in both img2img and txt2img tabs + + ## 1.5.2 ### Bug Fixes: diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 000000000..2c781aff4 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,7 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: + - given-names: AUTOMATIC1111 +title: "Stable Diffusion Web UI" +date-released: 2022-08-22 +url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui" diff --git a/README.md b/README.md index 007d6fde2..c7a4e363f 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Clip skip - Hypernetworks - Loras (same as Hypernetworks but more pretty) -- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt +- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt - Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API @@ -93,7 +93,10 @@ A browser interface based on Gradio library for Stable Diffusion. - Reorder elements in the UI from settings screen ## Installation and Running -Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: +- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) +- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page) Alternatively, use online services (like Google Colab): @@ -115,7 +118,7 @@ Alternatively, use online services (like Google Colab): 1. Install the dependencies: ```bash # Debian-based: -sudo apt install wget git python3 python3-venv +sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0 # Red Hat-based: sudo dnf install wget git python3 # Arch-based: @@ -123,7 +126,7 @@ sudo pacman -S wget git python3 ``` 2. Navigate to the directory you would like the webui to be installed and execute the following command: ```bash -bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh) +wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh ``` 3. Run `webui.sh`. 4. Check `webui-user.sh` for options. @@ -169,5 +172,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf +- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ba2945c6f..005ff32cb 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def __init__(self): super().__init__('lora') + self.errors = {} + """mapping of network names to the number of errors the network had during operation""" + def activate(self, p, params_list): additional = shared.opts.sd_lora + self.errors.clear() + if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) def deactivate(self, p): - pass + if self.errors: + p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) + + self.errors.clear() diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py new file mode 100644 index 000000000..b394d8e9e --- /dev/null +++ b/extensions-builtin/Lora/lora_patches.py @@ -0,0 +1,31 @@ +import torch + +import networks +from modules import patches + + +class LoraPatches: + def __init__(self): + self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) + self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) + self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) + self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) + self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) + self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) + self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) + self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict) + self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) + self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) + + def undo(self): + self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') + self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') + self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') + self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') + self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') + self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') + self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') + self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') + self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') + self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a18d69eb..d8e8dfb7f 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -133,7 +133,7 @@ class NetworkModule: return 1.0 - def finalize_updown(self, updown, orig_weight, output_shape): + def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) @@ -145,7 +145,10 @@ class NetworkModule: if orig_weight.size().numel() == updown.size().numel(): updown = updown.reshape(orig_weight.shape) - return updown * self.calc_scale() * self.multiplier() + if ex_bias is not None: + ex_bias = ex_bias * self.multiplier() + + return updown * self.calc_scale() * self.multiplier(), ex_bias def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index 109b4c2c5..bf6930e96 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule): super().__init__(net, weights) self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") def calc_updown(self, orig_weight): output_shape = self.weight.shape updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None - return self.finalize_updown(updown, orig_weight, output_shape) + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 000000000..ce4501580 --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 17cbe1bb7..96f935b23 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,12 +1,15 @@ +import logging import os import re +import lora_patches import network import network_lora import network_hada import network_ia3 import network_lokr import network_full +import network_norm import torch from typing import Union @@ -19,6 +22,7 @@ module_types = [ network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), ] @@ -31,6 +35,8 @@ suffix_conversion = { "resnets": { "conv1": "in_layers_2", "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", "time_emb_proj": "emb_layers_1", "conv_shortcut": "skip_connection", } @@ -190,11 +196,19 @@ def load_network(name, network_on_disk): net.modules[key] = net_module if keys_failed_to_match: - print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") return net +def purge_networks_from_memory(): + while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0: + name = next(iter(networks_in_memory)) + networks_in_memory.pop(name, None) + + devices.torch_gc() + + def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): already_loaded = {} @@ -212,15 +226,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No failed_to_load_networks = [] - for i, name in enumerate(names): + for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)): net = already_loaded.get(name, None) - network_on_disk = networks_on_disk[i] - if network_on_disk is not None: + if net is None: + net = networks_in_memory.get(name) + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: try: net = load_network(name, network_on_disk) + + networks_in_memory.pop(name, None) + networks_in_memory[name] = net except Exception as e: errors.display(e, f"loading network {network_on_disk.filename}") continue @@ -231,7 +249,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if net is None: failed_to_load_networks.append(name) - print(f"Couldn't find network with name {name}") + logging.info(f"Couldn't find network with name {name}") continue net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 @@ -240,23 +258,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No loaded_networks.append(net) if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + + purge_networks_from_memory() -def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): weights_backup = getattr(self, "network_weights_backup", None) + bias_backup = getattr(self, "network_bias_backup", None) - if weights_backup is None: + if weights_backup is None and bias_backup is None: return - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) + if weights_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + if bias_backup is not None: + if isinstance(self, torch.nn.MultiheadAttention): + self.out_proj.bias.copy_(bias_backup) + else: + self.bias.copy_(bias_backup) else: - self.weight.copy_(weights_backup) + if isinstance(self, torch.nn.MultiheadAttention): + self.out_proj.bias = None + else: + self.bias = None -def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): """ Applies the currently selected set of networks to the weights of torch layer self. If weights already have this particular set of networks applied, does nothing. @@ -271,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) weights_backup = getattr(self, "network_weights_backup", None) - if weights_backup is None: + if weights_backup is None and wanted_names != (): + if current_names != (): + raise RuntimeError("no backup weights found and current weights are not unchanged") + if isinstance(self, torch.nn.MultiheadAttention): weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) else: @@ -279,21 +315,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_weights_backup = weights_backup + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None: + if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: + bias_backup = self.out_proj.bias.to(devices.cpu, copy=True) + elif getattr(self, 'bias', None) is not None: + bias_backup = self.bias.to(devices.cpu, copy=True) + else: + bias_backup = None + self.network_bias_backup = bias_backup + if current_names != wanted_names: network_restore_weights_from_backup(self) for net in loaded_networks: module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight'): - with torch.no_grad(): - updown = module.calc_updown(self.weight) + try: + with torch.no_grad(): + updown, ex_bias = module.calc_updown(self.weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: - # inpainting model. zero pad updown to make channel[1] 4 to 9 - updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown - continue + self.weight += updown + if ex_bias is not None and hasattr(self, 'bias'): + if self.bias is None: + self.bias = torch.nn.Parameter(ex_bias) + else: + self.bias += ex_bias + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + + continue module_q = net.modules.get(network_layer_name + "_q_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None) @@ -301,21 +357,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn module_out = net.modules.get(network_layer_name + "_out_proj", None) if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: - with torch.no_grad(): - updown_q = module_q.calc_updown(self.in_proj_weight) - updown_k = module_k.calc_updown(self.in_proj_weight) - updown_v = module_v.calc_updown(self.in_proj_weight) - updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - updown_out = module_out.calc_updown(self.out_proj.weight) + try: + with torch.no_grad(): + updown_q, _ = module_q.calc_updown(self.in_proj_weight) + updown_k, _ = module_k.calc_updown(self.in_proj_weight) + updown_v, _ = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight) - self.in_proj_weight += updown_qkv - self.out_proj.weight += updown_out - continue + self.in_proj_weight += updown_qkv + self.out_proj.weight += updown_out + if ex_bias is not None: + if self.out_proj.bias is None: + self.out_proj.bias = torch.nn.Parameter(ex_bias) + else: + self.out_proj.bias += ex_bias + + except RuntimeError as e: + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 + + continue if module is None: continue - print(f'failed to calculate network weights for layer {network_layer_name}') + logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 self.network_current_names = wanted_names @@ -342,7 +410,7 @@ def network_forward(module, input, original_forward): if module is None: continue - y = module.forward(y, input) + y = module.forward(input, y) return y @@ -354,44 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def network_Linear_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Linear_forward_before_network) + return network_forward(self, input, originals.Linear_forward) network_apply_weights(self) - return torch.nn.Linear_forward_before_network(self, input) + return originals.Linear_forward(self, input) def network_Linear_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + return originals.Linear_load_state_dict(self, *args, **kwargs) def network_Conv2d_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + return network_forward(self, input, originals.Conv2d_forward) network_apply_weights(self) - return torch.nn.Conv2d_forward_before_network(self, input) + return originals.Conv2d_forward(self, input) def network_Conv2d_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + return originals.Conv2d_load_state_dict(self, *args, **kwargs) + + +def network_GroupNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, originals.GroupNorm_forward) + + network_apply_weights(self) + + return originals.GroupNorm_forward(self, input) + + +def network_GroupNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return originals.GroupNorm_load_state_dict(self, *args, **kwargs) + + +def network_LayerNorm_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, originals.LayerNorm_forward) + + network_apply_weights(self) + + return originals.LayerNorm_forward(self, input) + + +def network_LayerNorm_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return originals.LayerNorm_load_state_dict(self, *args, **kwargs) def network_MultiheadAttention_forward(self, *args, **kwargs): network_apply_weights(self) - return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_forward(self, *args, **kwargs) def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) def list_available_networks(): @@ -459,9 +557,14 @@ def infotext_pasted(infotext, params): params["Prompt"] += "\n" + "".join(added) +originals: lora_patches.LoraPatches = None + +extra_network_lora = None + available_networks = {} available_network_aliases = {} loaded_networks = [] +networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index cd28afc92..ef23968c5 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,57 +1,30 @@ import re -import torch import gradio as gr from fastapi import FastAPI import network import networks import lora # noqa:F401 +import lora_patches import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared + def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_network - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network + networks.originals.undo() def before_ui(): ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_network = extra_networks_lora.ExtraNetworkLora() - extra_networks.register_extra_network(extra_network) - extra_networks.register_extra_network_alias(extra_network, "lyco") + networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora() + extra_networks.register_extra_network(networks.extra_network_lora) + extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco") -if not hasattr(torch.nn, 'Linear_forward_before_network'): - torch.nn.Linear_forward_before_network = torch.nn.Linear.forward - -if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): - torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict - -if not hasattr(torch.nn, 'Conv2d_forward_before_network'): - torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward - -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): - torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict - -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): - torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward - -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): - torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict - -torch.nn.Linear.forward = networks.network_Linear_forward -torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict -torch.nn.Conv2d.forward = networks.network_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict +networks.originals = lora_patches.LoraPatches() script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -65,6 +38,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), + "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), })) @@ -121,3 +95,5 @@ def infotext_pasted(infotext, d): script_callbacks.on_infotext_pasted(infotext_pasted) + +shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2ca997f7c..390d9dde3 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) with gr.Column(scale=1, min_width=120): - generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg") + generate_random_prompt = gr.Button('Generate', size="lg", scale=1) self.edit_notes = gr.TextArea(label='Notes', lines=4) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 3629e5c0c..55409a782 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): item = { "name": name, "filename": lora_on_disk.filename, + "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename), + "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 30199dcd6..45c7600ac 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -12,8 +12,22 @@ onUiLoaded(async() => { "Sketch": elementIDs.sketch }; + // Helper functions // Get active tab + + /** + * Waits for an element to be present in the DOM. + */ + const waitForElement = (id) => new Promise(resolve => { + const checkForElement = () => { + const element = document.querySelector(id); + if (element) return resolve(element); + setTimeout(checkForElement, 100); + }; + checkForElement(); + }); + function getActiveTab(elements, all = false) { const tabs = elements.img2imgTabs.querySelectorAll("button"); @@ -34,7 +48,7 @@ onUiLoaded(async() => { // Wait until opts loaded async function waitForOpts() { - for (;;) { + for (; ;) { if (window.opts && Object.keys(window.opts).length) { return window.opts; } @@ -42,6 +56,11 @@ onUiLoaded(async() => { } } + // Detect whether the element has a horizontal scroll bar + function hasHorizontalScrollbar(element) { + return element.scrollWidth > element.clientWidth; + } + // Function for defining the "Ctrl", "Shift" and "Alt" keys function isModifierKey(event, key) { switch (key) { @@ -201,7 +220,8 @@ onUiLoaded(async() => { canvas_hotkey_overlap: "KeyO", canvas_disabled_functions: [], canvas_show_tooltip: true, - canvas_blur_prompt: false + canvas_auto_expand: true, + canvas_blur_prompt: false, }; const functionMap = { @@ -249,7 +269,7 @@ onUiLoaded(async() => { input?.addEventListener("input", () => restoreImgRedMask(elements)); } - function applyZoomAndPan(elemId) { + function applyZoomAndPan(elemId, isExtension = true) { const targetElement = gradioApp().querySelector(elemId); if (!targetElement) { @@ -361,6 +381,12 @@ onUiLoaded(async() => { panY: 0 }; + if (isExtension) { + targetElement.style.overflow = "hidden"; + } + + targetElement.isZoomed = false; + fixCanvas(); targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; @@ -371,8 +397,27 @@ onUiLoaded(async() => { toggleOverlap("off"); fullScreenMode = false; + const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']"); + if (closeBtn) { + closeBtn.addEventListener("click", resetZoom); + } + + if (canvas && isExtension) { + const parentElement = targetElement.closest('[id^="component-"]'); + if ( + canvas && + parseFloat(canvas.style.width) > parentElement.offsetWidth && + parseFloat(targetElement.style.width) > parentElement.offsetWidth + ) { + fitToElement(); + return; + } + + } + if ( canvas && + !isExtension && parseFloat(canvas.style.width) > 865 && parseFloat(targetElement.style.width) > 865 ) { @@ -381,9 +426,6 @@ onUiLoaded(async() => { } targetElement.style.width = ""; - if (canvas) { - targetElement.style.height = canvas.style.height; - } } // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements @@ -439,7 +481,7 @@ onUiLoaded(async() => { // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables function updateZoom(newZoomLevel, mouseX, mouseY) { - newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15)); + newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15)); elemData[elemId].panX += mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel; @@ -450,6 +492,10 @@ onUiLoaded(async() => { targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; toggleOverlap("on"); + if (isExtension) { + targetElement.style.overflow = "visible"; + } + return newZoomLevel; } @@ -472,10 +518,12 @@ onUiLoaded(async() => { fullScreenMode = false; elemData[elemId].zoomLevel = updateZoom( elemData[elemId].zoomLevel + - (operation === "+" ? delta : -delta), + (operation === "+" ? delta : -delta), zoomPosX - targetElement.getBoundingClientRect().left, zoomPosY - targetElement.getBoundingClientRect().top ); + + targetElement.isZoomed = true; } } @@ -489,10 +537,19 @@ onUiLoaded(async() => { //Reset Zoom targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + let parentElement; + + if (isExtension) { + parentElement = targetElement.closest('[id^="component-"]'); + } else { + parentElement = targetElement.parentElement; + } + + // Get element and screen dimensions const elementWidth = targetElement.offsetWidth; const elementHeight = targetElement.offsetHeight; - const parentElement = targetElement.parentElement; + const screenWidth = parentElement.clientWidth; const screenHeight = parentElement.clientHeight; @@ -545,8 +602,12 @@ onUiLoaded(async() => { if (!canvas) return; - if (canvas.offsetWidth > 862) { - targetElement.style.width = canvas.offsetWidth + "px"; + if (canvas.offsetWidth > 862 || isExtension) { + targetElement.style.width = (canvas.offsetWidth + 2) + "px"; + } + + if (isExtension) { + targetElement.style.overflow = "visible"; } if (fullScreenMode) { @@ -648,8 +709,48 @@ onUiLoaded(async() => { mouseY = e.offsetY; } + // Simulation of the function to put a long image into the screen. + // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element. + // We hide the image and show it to the user when it is ready. + + targetElement.isExpanded = false; + function autoExpand() { + const canvas = document.querySelector(`${elemId} canvas[key="interface"]`); + if (canvas) { + if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) { + targetElement.style.visibility = "hidden"; + setTimeout(() => { + fitToScreen(); + resetZoom(); + targetElement.style.visibility = "visible"; + targetElement.isExpanded = true; + }, 10); + } + } + } + targetElement.addEventListener("mousemove", getMousePosition); + //observers + // Creating an observer with a callback function to handle DOM changes + const observer = new MutationObserver((mutationsList, observer) => { + for (let mutation of mutationsList) { + // If the style attribute of the canvas has changed, by observation it happens only when the picture changes + if (mutation.type === 'attributes' && mutation.attributeName === 'style' && + mutation.target.tagName.toLowerCase() === 'canvas') { + targetElement.isExpanded = false; + setTimeout(resetZoom, 10); + } + } + }); + + // Apply auto expand if enabled + if (hotkeysConfig.canvas_auto_expand) { + targetElement.addEventListener("mousemove", autoExpand); + // Set up an observer to track attribute changes + observer.observe(targetElement, {attributes: true, childList: true, subtree: true}); + } + // Handle events only inside the targetElement let isKeyDownHandlerAttached = false; @@ -754,6 +855,11 @@ onUiLoaded(async() => { if (isMoving && elemId === activeElement) { updatePanPosition(e.movementX, e.movementY); targetElement.style.pointerEvents = "none"; + + if (isExtension) { + targetElement.style.overflow = "visible"; + } + } else { targetElement.style.pointerEvents = "auto"; } @@ -764,13 +870,93 @@ onUiLoaded(async() => { isMoving = false; }; + // Checks for extension + function checkForOutBox() { + const parentElement = targetElement.closest('[id^="component-"]'); + if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) { + resetZoom(); + targetElement.isExpanded = true; + } + + if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) { + resetZoom(); + } + + if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) { + resetZoom(); + } + } + + if (isExtension) { + targetElement.addEventListener("mousemove", checkForOutBox); + } + + + window.addEventListener('resize', (e) => { + resetZoom(); + + if (isExtension) { + targetElement.isExpanded = false; + targetElement.isZoomed = false; + } + }); + gradioApp().addEventListener("mousemove", handleMoveByKey); + + } - applyZoomAndPan(elementIDs.sketch); - applyZoomAndPan(elementIDs.inpaint); - applyZoomAndPan(elementIDs.inpaintSketch); + applyZoomAndPan(elementIDs.sketch, false); + applyZoomAndPan(elementIDs.inpaint, false); + applyZoomAndPan(elementIDs.inpaintSketch, false); // Make the function global so that other extensions can take advantage of this solution - window.applyZoomAndPan = applyZoomAndPan; + const applyZoomAndPanIntegration = async(id, elementIDs) => { + const mainEl = document.querySelector(id); + if (id.toLocaleLowerCase() === "none") { + for (const elementID of elementIDs) { + const el = await waitForElement(elementID); + if (!el) break; + applyZoomAndPan(elementID); + } + return; + } + + if (!mainEl) return; + mainEl.addEventListener("click", async() => { + for (const elementID of elementIDs) { + const el = await waitForElement(elementID); + if (!el) break; + applyZoomAndPan(elementID); + } + }, {once: true}); + }; + + window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image") + + window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension + + /* + The function `applyZoomAndPanIntegration` takes two arguments: + + 1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click. + If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event. + + 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument. + If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event. + + Example usage: + applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet". + */ + + // More examples + // Add integration with ControlNet txt2img One TAB + // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + + // Add integration with ControlNet txt2img Tabs + // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`)); + + // Add integration with Inpaint Anything + // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]); }); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 380176ce2..2d8d2d1c0 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -9,6 +9,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), + "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"), "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), })) diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css index 6bcc9570c..5d8054e65 100644 --- a/extensions-builtin/canvas-zoom-and-pan/style.css +++ b/extensions-builtin/canvas-zoom-and-pan/style.css @@ -61,3 +61,6 @@ to {opacity: 1;} } +.styler { + overflow:inherit !important; +} \ No newline at end of file diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a05e10d86..983f87ff0 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,5 +1,7 @@ +import math + import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste from modules.ui_components import FormColumn @@ -19,18 +21,38 @@ class ExtraOptionsSection(scripts.Script): def ui(self, is_img2img): self.comps = [] self.setting_names = [] + self.infotext_fields = [] + extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + + mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): - for setting_name in shared.opts.extra_options: - with FormColumn(): - comp = ui_settings.create_setting_component(setting_name) + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): - self.comps.append(comp) - self.setting_names.append(setting_name) + row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) + + for row in range(row_count): + with gr.Row(): + for col in range(shared.opts.extra_options_cols): + index = row * shared.opts.extra_options_cols + col + if index >= len(extra_options): + break + + setting_name = extra_options[index] + + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + setting_infotext_name = mapping.get(setting_name) + if setting_infotext_name is not None: + self.infotext_fields.append((comp, setting_infotext_name)) def get_settings_values(): - return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + res = [ui_settings.get_value_for_setting(key) for key in self.setting_names] + return res[0] if len(res) == 1 else res interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) @@ -43,6 +65,10 @@ class ExtraOptionsSection(scripts.Script): shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), - "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") + "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), + "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() })) + + diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 12cae4b75..652f07ac7 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -20,7 +20,13 @@ function reportWindowSize() { var button = gradioApp().getElementById(tab + '_generate_box'); var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column'); target.insertBefore(button, target.firstElementChild); + + gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile); } } window.addEventListener("resize", reportWindowSize); + +onUiLoaded(function() { + reportWindowSize(); +}); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5582a6e5d..493f31af2 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,20 +1,38 @@ +function toggleCss(key, css, enable) { + var style = document.getElementById(key); + if (enable && !style) { + style = document.createElement('style'); + style.id = key; + style.type = 'text/css'; + document.head.appendChild(style); + } + if (style && !enable) { + document.head.removeChild(style); + } + if (style) { + style.innerHTML == ''; + style.appendChild(document.createTextNode(css)); + } +} + function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); - var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + var searchDiv = gradioApp().getElementById(tabname + '_extra_search'); + var search = searchDiv.querySelector('textarea'); var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); + var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); - search.classList.add('search'); - sort.classList.add('sort'); - sortOrder.classList.add('sortorder'); sort.dataset.sortkey = 'sortDefault'; - tabs.appendChild(search); + tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); + tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) { }); extraNetworksApplyFilter[tabname] = applyFilter; + + var showDirsUpdate = function() { + var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; + toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); + localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); + }; + showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; + showDirs.addEventListener("change", showDirsUpdate); + showDirsUpdate(); } function applyExtraNetworkFilter(tabname) { @@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksSearchButton(tabs_id, event) { - var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); + var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); @@ -222,6 +249,15 @@ function popup(contents) { globalPopup.style.display = "flex"; } +var storedPopupIds = {}; +function popupId(id) { + if (!storedPopupIds[id]) { + storedPopupIds[id] = gradioApp().getElementById(id); + } + + popup(storedPopupIds[id]); +} + function extraNetworksShowMetadata(text) { var elem = document.createElement('pre'); elem.classList.add('popup-metadata'); @@ -305,7 +341,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) { newDiv.innerHTML = data.html; var newCard = newDiv.firstElementChild; - newCard.style = ''; + newCard.style.display = ''; card.parentElement.insertBefore(newCard, card); card.parentElement.removeChild(card); } diff --git a/javascript/hints.js b/javascript/hints.js index 4167cb28b..6de9372e8 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) { tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); } }); + +onUiLoaded(function() { + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip && comp.props.elem_id) { + var elem = gradioApp().getElementById(comp.props.elem_id); + if (elem) { + elem.title = comp.props.webui_tooltip; + } + } + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 677e95c1b..c21d396ee 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -136,6 +136,11 @@ function setupImageForLightbox(e) { var event = isFirefox ? 'mousedown' : 'click'; e.addEventListener(event, function(evt) { + if (evt.button == 1) { + open(evt.target.src); + evt.preventDefault(); + return; + } if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js new file mode 100644 index 000000000..f2839852e --- /dev/null +++ b/javascript/inputAccordion.js @@ -0,0 +1,37 @@ +var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + var elem = mutationRecord.target; + var open = elem.classList.contains('open'); + + var accordion = elem.parentNode; + accordion.classList.toggle('input-accordion-open', open); + + var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + checkbox.checked = open; + updateInput(checkbox); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + extra.style.display = open ? "" : "none"; + } + }); +}); + +function inputAccordionChecked(id, checked) { + var label = gradioApp().querySelector('#' + id + " .label-wrap"); + if (label.classList.contains('open') != checked) { + label.click(); + } +} + +onUiLoaded(function() { + for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { + var labelWrap = accordion.querySelector('.label-wrap'); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); + } + } +}); diff --git a/javascript/localStorage.js b/javascript/localStorage.js new file mode 100644 index 000000000..dc1a36c32 --- /dev/null +++ b/javascript/localStorage.js @@ -0,0 +1,26 @@ + +function localSet(k, v) { + try { + localStorage.setItem(k, v); + } catch (e) { + console.warn(`Failed to save ${k} to localStorage: ${e}`); + } +} + +function localGet(k, def) { + try { + return localStorage.getItem(k); + } catch (e) { + console.warn(`Failed to load ${k} from localStorage: ${e}`); + } + + return def; +} + +function localRemove(k) { + try { + return localStorage.removeItem(k); + } catch (e) { + console.warn(`Failed to remove ${k} from localStorage: ${e}`); + } +} diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e..8f00c1868 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,11 +11,11 @@ var ignore_ids_for_localization = { train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', - setting_random_artist_categories: 'SPAN', - setting_face_restoration_model: 'SPAN', - setting_realesrgan_enabled_models: 'SPAN', - extras_upscaler_1: 'SPAN', - extras_upscaler_2: 'SPAN', + setting_random_artist_categories: 'OPTION', + setting_face_restoration_model: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + extras_upscaler_1: 'OPTION', + extras_upscaler_2: 'OPTION', }; var re_num = /^[.\d]+$/; @@ -107,12 +107,41 @@ function processNode(node) { }); } +function localizeWholePage() { + processNode(gradioApp()); + + function elem(comp) { + var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id; + return gradioApp().getElementById(elem_id); + } + + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip) { + let e = elem(comp); + + let tl = e ? getTranslation(e.title) : undefined; + if (tl !== undefined) { + e.title = tl; + } + } + if (comp.props.placeholder) { + let e = elem(comp); + let textbox = e ? e.querySelector('[placeholder]') : null; + + let tl = textbox ? getTranslation(textbox.placeholder) : undefined; + if (tl !== undefined) { + textbox.placeholder = tl; + } + } + } +} + function dumpTranslations() { if (!hasLocalization()) { // If we don't have any localization, // we will not have traversed the app to find // original_lines, so do that now. - processNode(gradioApp()); + localizeWholePage(); } var dumped = {}; if (localization.rtl) { @@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() { }); }); - processNode(gradioApp()); + localizeWholePage(); if (localization.rtl) { // if the language is from right to left, (new MutationObserver((mutations, observer) => { // wait for the style to load diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 29299787e..777614954 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var dateStart = new Date(); var wasEverActive = false; var parentProgressbar = progressbarContainer.parentNode; - var parentGallery = gallery ? gallery.parentNode : null; var divProgress = document.createElement('div'); divProgress.className = 'progressDiv'; @@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre divProgress.appendChild(divInner); parentProgressbar.insertBefore(divProgress, progressbarContainer); - if (parentGallery) { - var livePreview = document.createElement('div'); - livePreview.className = 'livePreview'; - parentGallery.insertBefore(livePreview, gallery); - } + var livePreview = null; var removeProgressBar = function() { + if (!divProgress) return; + setTitle(""); parentProgressbar.removeChild(divProgress); - if (parentGallery) parentGallery.removeChild(livePreview); + if (gallery && livePreview) gallery.removeChild(livePreview); atEnd(); + + divProgress = null; }; - var fun = function(id_task, id_live_preview) { - request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { + var funProgress = function(id_task) { + request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) { if (res.completed) { removeProgressBar(); return; } - var rect = progressbarContainer.getBoundingClientRect(); - - if (rect.width) { - divProgress.style.width = rect.width + "px"; - } - let progressText = ""; divInner.style.width = ((res.progress || 0) * 100.0) + '%'; @@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre progressText += " ETA: " + formatTime(res.eta); } - setTitle(progressText); if (res.textinfo && res.textinfo.indexOf("\n") == -1) { @@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre return; } + if (onProgress) { + onProgress(res); + } + + setTimeout(() => { + funProgress(id_task, res.id_live_preview); + }, opts.live_preview_refresh_period || 500); + }, function() { + removeProgressBar(); + }); + }; + + var funLivePreview = function(id_task, id_live_preview) { + request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { + if (!divProgress) { + return; + } if (res.live_preview && gallery) { - rect = gallery.getBoundingClientRect(); - if (rect.width) { - livePreview.style.width = rect.width + "px"; - livePreview.style.height = rect.height + "px"; - } - var img = new Image(); img.onload = function() { + if (!livePreview) { + livePreview = document.createElement('div'); + livePreview.className = 'livePreview'; + gallery.insertBefore(livePreview, gallery.firstElementChild); + } + livePreview.appendChild(img); if (livePreview.childElementCount > 2) { livePreview.removeChild(livePreview.firstElementChild); @@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre img.src = res.live_preview; } - - if (onProgress) { - onProgress(res); - } - setTimeout(() => { - fun(id_task, res.id_live_preview); + funLivePreview(id_task, res.id_live_preview); }, opts.live_preview_refresh_period || 500); }, function() { removeProgressBar(); }); }; - fun(id_task, 0); + funProgress(id_task, 0); + + if (gallery) { + funLivePreview(id_task, 0); + } + } diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js new file mode 100644 index 000000000..8c5c51692 --- /dev/null +++ b/javascript/resizeHandle.js @@ -0,0 +1,141 @@ +(function() { + const GRADIO_MIN_WIDTH = 320; + const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr'; + const PAD = 16; + const DEBOUNCE_TIME = 100; + + const R = { + tracking: false, + parent: null, + parentWidth: null, + leftCol: null, + leftColStartWidth: null, + screenX: null, + }; + + let resizeTimer; + let parents = []; + + function setLeftColGridTemplate(el, width) { + el.style.gridTemplateColumns = `${width}px 16px 1fr`; + } + + function displayResizeHandle(parent) { + if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) { + parent.style.display = 'flex'; + if (R.handle != null) { + R.handle.style.opacity = '0'; + } + return false; + } else { + parent.style.display = 'grid'; + if (R.handle != null) { + R.handle.style.opacity = '100'; + } + return true; + } + } + + function afterResize(parent) { + if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) { + const oldParentWidth = R.parentWidth; + const newParentWidth = parent.offsetWidth; + const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]); + + const ratio = newParentWidth / oldParentWidth; + + const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(parent, newWidthL); + + R.parentWidth = newParentWidth; + } + } + + function setup(parent) { + const leftCol = parent.firstElementChild; + const rightCol = parent.lastElementChild; + + parents.push(parent); + + parent.style.display = 'grid'; + parent.style.gap = '0'; + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + + const resizeHandle = document.createElement('div'); + resizeHandle.classList.add('resize-handle'); + parent.insertBefore(resizeHandle, rightCol); + + resizeHandle.addEventListener('mousedown', (evt) => { + if (evt.button !== 0) return; + + evt.preventDefault(); + evt.stopPropagation(); + + document.body.classList.add('resizing'); + + R.tracking = true; + R.parent = parent; + R.parentWidth = parent.offsetWidth; + R.handle = resizeHandle; + R.leftCol = leftCol; + R.leftColStartWidth = leftCol.offsetWidth; + R.screenX = evt.screenX; + }); + + resizeHandle.addEventListener('dblclick', (evt) => { + evt.preventDefault(); + evt.stopPropagation(); + + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + }); + + afterResize(parent); + } + + window.addEventListener('mousemove', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + const delta = R.screenX - evt.screenX; + const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(R.parent, leftColWidth); + } + }); + + window.addEventListener('mouseup', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + R.tracking = false; + + document.body.classList.remove('resizing'); + } + }); + + + window.addEventListener('resize', () => { + clearTimeout(resizeTimer); + + resizeTimer = setTimeout(function() { + for (const parent of parents) { + afterResize(parent); + } + }, DEBOUNCE_TIME); + }); + + setupResizeHandle = setup; +})(); + +onUiLoaded(function() { + for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) { + if (!elem.querySelector('.resize-handle')) { + setupResizeHandle(elem); + } + } +}); diff --git a/javascript/ui.js b/javascript/ui.js index d70a681bf..bedcbf3e2 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -19,28 +19,11 @@ function all_gallery_buttons() { } function selected_gallery_button() { - var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected'); - var visibleCurrentButton = null; - allCurrentButtons.forEach(function(elem) { - if (elem.parentElement.offsetParent) { - visibleCurrentButton = elem; - } - }); - return visibleCurrentButton; + return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null; } function selected_gallery_index() { - var buttons = all_gallery_buttons(); - var button = selected_gallery_button(); - - var result = -1; - buttons.forEach(function(v, i) { - if (v == button) { - result = i; - } - }); - - return result; + return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected')); } function extract_image_from_gallery(gallery) { @@ -152,11 +135,11 @@ function submit() { showSubmitButtons('txt2img', false); var id = randomId(); - localStorage.setItem("txt2img_task_id", id); + localSet("txt2img_task_id", id); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); - localStorage.removeItem("txt2img_task_id"); + localRemove("txt2img_task_id"); showRestoreProgressButton('txt2img', false); }); @@ -171,11 +154,11 @@ function submit_img2img() { showSubmitButtons('img2img', false); var id = randomId(); - localStorage.setItem("img2img_task_id", id); + localSet("img2img_task_id", id); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); - localStorage.removeItem("img2img_task_id"); + localRemove("img2img_task_id"); showRestoreProgressButton('img2img', false); }); @@ -189,9 +172,7 @@ function submit_img2img() { function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); - var id = localStorage.getItem("txt2img_task_id"); - - id = localStorage.getItem("txt2img_task_id"); + var id = localGet("txt2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { @@ -205,7 +186,7 @@ function restoreProgressTxt2img() { function restoreProgressImg2img() { showRestoreProgressButton("img2img", false); - var id = localStorage.getItem("img2img_task_id"); + var id = localGet("img2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { @@ -218,8 +199,8 @@ function restoreProgressImg2img() { onUiLoaded(function() { - showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id")); - showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id")); + showRestoreProgressButton('txt2img', localGet("txt2img_task_id")); + showRestoreProgressButton('img2img', localGet("img2img_task_id")); }); diff --git a/launch.py b/launch.py index 1dbc4c6e3..e4c2ce99e 100644 --- a/launch.py +++ b/launch.py @@ -1,6 +1,5 @@ from modules import launch_utils - args = launch_utils.args python = launch_utils.python git = launch_utils.git @@ -26,8 +25,11 @@ start = launch_utils.start def main(): - if not args.skip_prepare_environment: - prepare_environment() + launch_utils.startup_timer.record("initial startup") + + with launch_utils.startup_timer.subcategory("prepare environment"): + if not args.skip_prepare_environment: + prepare_environment() if args.test_server: configure_for_tests() diff --git a/modules/api/api.py b/modules/api/api.py index 606db179d..844e31ee7 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,6 +4,8 @@ import os import time import datetime import uvicorn +import ipaddress +import requests import gradio as gr from threading import Lock from io import BytesIO @@ -15,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases -from modules.sd_vae import vae_dict +from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -56,7 +57,41 @@ def setUpscalers(req: dict): return reqDict +def verify_url(url): + """Returns True if the url refers to a global resource.""" + + import socket + from urllib.parse import urlparse + try: + parsed_url = urlparse(url) + domain_name = parsed_url.netloc + host = socket.gethostbyname_ex(domain_name) + for ip in host[2]: + ip_addr = ipaddress.ip_address(ip) + if not ip_addr.is_global: + return False + except Exception: + return False + + return True + + def decode_base64_to_image(encoding): + if encoding.startswith("http://") or encoding.startswith("https://"): + if not opts.api_enable_requests: + raise HTTPException(status_code=500, detail="Requests not allowed") + + if opts.api_forbid_local_requests and not verify_url(encoding): + raise HTTPException(status_code=500, detail="Request to local resource not allowed") + + headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + response = requests.get(encoding, timeout=30, headers=headers) + try: + image = Image.open(BytesIO(response.content)) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid image url") from e + if encoding.startswith("data:image/"): encoding = encoding.split(";")[1].split(",")[1] try: @@ -197,6 +232,7 @@ class Api: self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) @@ -207,6 +243,7 @@ class Api: self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) + self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem]) if shared.cmd_opts.api_server_stop: self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"]) @@ -329,6 +366,7 @@ class Api: with self.queue_lock: with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p: + p.is_api = True p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -343,6 +381,7 @@ class Api: processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -388,6 +427,7 @@ class Api: with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] + p.is_api = True p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples @@ -402,6 +442,7 @@ class Api: processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -433,9 +474,6 @@ class Api: return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) def pnginfoapi(self, req: models.PNGInfoRequest): - if(not req.image.strip()): - return models.PNGInfoResponse(info="") - image = decode_base64_to_image(req.image.strip()) if image is None: return models.PNGInfoResponse(info="") @@ -444,9 +482,10 @@ class Api: if geninfo is None: geninfo = "" - items = {**{'parameters': geninfo}, **items} + params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + script_callbacks.infotext_pasted_callback(geninfo, params) - return models.PNGInfoResponse(info=geninfo, items=items) + return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) def progressapi(self, req: models.ProgressRequest = Depends()): # copy from check_progress_call of ui.py @@ -530,7 +569,7 @@ class Api: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): - shared.opts.set(k, v) + shared.opts.set(k, v, is_api=True) shared.opts.save(shared.config_filename) return @@ -562,10 +601,12 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + import modules.sd_models as sd_models + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] def get_sd_vaes(self): - return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + import modules.sd_vae as sd_vae + return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -608,6 +649,10 @@ class Api: with self.queue_lock: shared.refresh_checkpoints() + def refresh_vae(self): + with self.queue_lock: + shared_items.refresh_vae_list() + def create_embedding(self, args: dict): try: shared.state.begin(job="create_embedding") @@ -724,6 +769,25 @@ class Api: cuda = {'error': f'{err}'} return models.MemoryResponse(ram=ram, cuda=cuda) + def get_extensions_list(self): + from modules import extensions + extensions.list_extensions() + ext_list = [] + for ext in extensions.extensions: + ext: extensions.Extension + ext.read_info_from_repo() + if ext.remote is not None: + ext_list.append({ + "name": ext.name, + "remote": ext.remote, + "branch": ext.branch, + "commit_hash":ext.commit_hash, + "commit_date":ext.commit_date, + "version":ext.version, + "enabled":ext.enabled + }) + return ext_list + def launch(self, server_name, port, root_path): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) diff --git a/modules/api/models.py b/modules/api/models.py index 800c9b93f..94eca97dc 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -50,10 +50,12 @@ class PydanticModelGenerator: additional_fields = None, ): def field_type_generator(k, v): - # field_type = str if not overrides.get(k) else overrides[k]["type"] - # print(k, v.annotation, v.default) field_type = v.annotation + if field_type == 'Image': + # images are sent as base64 strings via API + field_type = 'str' + return Optional[field_type] def merge_class_params(class_): @@ -63,7 +65,6 @@ class PydanticModelGenerator: parameters = {**parameters, **inspect.signature(classes.__init__).parameters} return parameters - self._model_name = model_name self._class_data = merge_class_params(class_instance) @@ -72,7 +73,7 @@ class PydanticModelGenerator: field=underscore(k), field_alias=k, field_type=field_type_generator(k, v), - field_value=v.default + field_value=None if isinstance(v.default, property) else v.default ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] @@ -177,7 +178,8 @@ class PNGInfoRequest(BaseModel): class PNGInfoResponse(BaseModel): info: str = Field(title="Image info", description="A string with the parameters used to generate the image") - items: dict = Field(title="Items", description="An object containing all the info the image had") + items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had") + parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields") class ProgressRequest(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") @@ -310,3 +312,12 @@ class ScriptInfo(BaseModel): is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script") is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script") args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments") + +class ExtensionItem(BaseModel): + name: str = Field(title="Name", description="Extension name") + remote: str = Field(title="Remote", description="Extension Repository URL") + branch: str = Field(title="Branch", description="Extension Repository Branch") + commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash") + version: str = Field(title="Version", description="Extension Version") + commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date") + enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled") diff --git a/modules/cache.py b/modules/cache.py index 71fe63021..ff26a2132 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,11 +1,12 @@ import json +import os import os.path import threading import time from modules.paths import data_path, script_path -cache_filename = os.path.join(data_path, "cache.json") +cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json")) cache_data = None cache_lock = threading.Lock() @@ -29,9 +30,12 @@ def dump_cache(): time.sleep(1) with cache_lock: - with open(cache_filename, "w", encoding="utf8") as file: + cache_filename_tmp = cache_filename + "-" + with open(cache_filename_tmp, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) + os.replace(cache_filename_tmp, cache_filename) + dump_cache_after = None dump_cache_thread = None diff --git a/modules/call_queue.py b/modules/call_queue.py index f2eb17d61..ddf0d5738 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,11 +1,10 @@ from functools import wraps import html -import threading import time -from modules import shared, progress, errors, devices +from modules import shared, progress, errors, devices, fifo_lock -queue_lock = threading.Lock() +queue_lock = fifo_lock.FIFOLock() def wrap_queued_call(func): diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e401f6413..f0f361bde 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -13,8 +13,10 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") +parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None) parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) @@ -33,9 +35,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") +parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") -parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") +parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") @@ -66,6 +69,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -78,7 +82,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") -parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it") +parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) @@ -110,3 +114,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/config_states.py b/modules/config_states.py index 6f1ab53fc..b766aef11 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -8,14 +8,12 @@ import time import tqdm from datetime import datetime -from collections import OrderedDict import git from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir - -all_config_states = OrderedDict() +all_config_states = {} def list_config_states(): @@ -28,10 +26,14 @@ def list_config_states(): for filename in os.listdir(config_states_dir): if filename.endswith(".json"): path = os.path.join(config_states_dir, filename) - with open(path, "r", encoding="utf-8") as f: - j = json.load(f) - j["filepath"] = path - config_states.append(j) + try: + with open(path, "r", encoding="utf-8") as f: + j = json.load(f) + assert "created_at" in j, '"created_at" does not exist' + j["filepath"] = path + config_states.append(j) + except Exception as e: + print(f'[ERROR]: Config states {path}, {e}') config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) diff --git a/modules/devices.py b/modules/devices.py index 57e51da30..c01f06024 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors +from modules import errors, shared if sys.platform == "darwin": from modules import mac_specific @@ -17,8 +17,6 @@ def has_mps() -> bool: def get_cuda_device_string(): - from modules import shared - if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -40,8 +38,6 @@ def get_optimal_device(): def get_device_for(task): - from modules import shared - if task in shared.cmd_opts.use_cpu: return cpu @@ -71,14 +67,17 @@ def enable_tf32(): torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -90,26 +89,10 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input -def randn(seed, shape): - from modules.shared import opts - - torch.manual_seed(seed) - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) - - -def randn_without_seed(shape): - from modules.shared import opts - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - return torch.randn(shape, device=device) +nv_rng = None def autocast(disable=False): - from modules import shared - if disable: return contextlib.nullcontext() @@ -128,8 +111,6 @@ class NansException(Exception): def test_for_nans(x, where): - from modules import shared - if shared.cmd_opts.disable_nan_check: return @@ -169,3 +150,4 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + diff --git a/modules/errors.py b/modules/errors.py index dffabe45c..8c339464d 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -84,3 +84,53 @@ def run(code, task): code() except Exception as e: display(task, e) + + +def check_versions(): + from packaging import version + from modules import shared + + import torch + import gradio + + expected_torch_version = "2.0.0" + expected_xformers_version = "0.0.20" + expected_gradio_version = "3.41.2" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you have the mismatched gradio version can be: + - you use --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs the incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/modules/extensions.py b/modules/extensions.py index 3ad5ed531..bf9a1878f 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os import threading -from modules import shared, errors, cache +from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True) def active(): - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": return [] - elif shared.opts.disable_all_extensions == "extra": + elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -90,8 +90,6 @@ class Extension: self.have_info_from_repo = True def list_files(self, subdir, extension): - from modules import scripts - dirpath = os.path.join(self.path, subdir) if not os.path.isdir(dirpath): return [] @@ -141,8 +139,12 @@ def list_extensions(): if not os.path.isdir(extensions_dir): return - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions: + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": print("*** \"Disable all extensions\" option was set, will not load any extensions ***") + elif shared.cmd_opts.disable_extra_extensions: + print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***") elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91b..b95336778 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,4 +1,7 @@ +import json +import os import re +import logging from collections import defaultdict from modules import errors @@ -84,27 +87,55 @@ class ExtraNetwork: raise NotImplementedError +def lookup_extra_networks(extra_network_data): + """returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks. + + Example input: + { + 'lora': [], + 'lyco': [], + 'hypernet': [] + } + + Example output: + + { + : [, ], + : [] + } + """ + + res = {} + + for extra_network_name, extra_network_args in list(extra_network_data.items()): + extra_network = extra_network_registry.get(extra_network_name, None) + alias = extra_network_aliases.get(extra_network_name, None) + + if alias is not None and extra_network is None: + extra_network = alias + + if extra_network is None: + logging.info(f"Skipping unknown extra network: {extra_network_name}") + continue + + res.setdefault(extra_network, []).extend(extra_network_args) + + return res + + def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" activated = [] - for extra_network_name, extra_network_args in extra_network_data.items(): - extra_network = extra_network_registry.get(extra_network_name, None) - - if extra_network is None: - extra_network = extra_network_aliases.get(extra_network_name, None) - - if extra_network is None: - print(f"Skipping unknown extra network: {extra_network_name}") - continue + for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items(): try: extra_network.activate(p, extra_network_args) activated.append(extra_network) except Exception as e: - errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): if extra_network in activated: @@ -123,19 +154,16 @@ def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks""" - for extra_network_name in extra_network_data: - extra_network = extra_network_registry.get(extra_network_name, None) - if extra_network is None: - continue + data = lookup_extra_networks(extra_network_data) + for extra_network in data: try: extra_network.deactivate(p) except Exception as e: - errors.display(e, f"deactivating extra network {extra_network_name}") + errors.display(e, f"deactivating extra network {extra_network.name}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in data: continue try: @@ -177,3 +205,20 @@ def parse_prompts(prompts): return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/extras.py b/modules/extras.py index e9c0263ec..2a310ae3f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import json import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,7 +72,20 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} + + for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]: + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") - metadata = None + metadata = {} + + if save_metadata and copy_metadata_fields: + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: + metadata.update(secondary_model_info.metadata) + if tertiary_model_info: + metadata.update(tertiary_model_info.metadata) if save_metadata: - metadata = {"format": "pt"} + try: + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + metadata["format"] = "pt" + + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) + safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None) else: torch.save(theta_0, output_modelname) diff --git a/modules/fifo_lock.py b/modules/fifo_lock.py new file mode 100644 index 000000000..c35b3ae25 --- /dev/null +++ b/modules/fifo_lock.py @@ -0,0 +1,37 @@ +import threading +import collections + + +# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a +class FIFOLock(object): + def __init__(self): + self._lock = threading.Lock() + self._inner_lock = threading.Lock() + self._pending_threads = collections.deque() + + def acquire(self, blocking=True): + with self._inner_lock: + lock_acquired = self._lock.acquire(False) + if lock_acquired: + return True + elif not blocking: + return False + + release_event = threading.Event() + self._pending_threads.append(release_event) + + release_event.wait() + return self._lock.acquire() + + def release(self): + with self._inner_lock: + if self._pending_threads: + release_event = self._pending_threads.popleft() + release_event.set() + + self._lock.release() + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9d..2ca160554 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks +from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' @@ -32,6 +32,7 @@ class ParamBinding: def reset(): paste_fields.clear() + registered_param_bindings.clear() def quote(text): @@ -198,7 +199,6 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - from modules import processing firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width @@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" @@ -304,32 +307,28 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Schedule rho" not in res: res["Schedule rho"] = 0 + if "VAE Encoder" not in res: + res["VAE Encoder"] = "Full" + + if "VAE Decoder" not in res: + res["VAE Decoder"] = "Full" + return res infotext_to_setting_name_mapping = [ - ('Clip skip', 'CLIP_stop_at_last_layers', ), + +] +"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead. +Example content: + +infotext_to_setting_name_mapping = [ ('Conditional mask weight', 'inpainting_mask_weight'), ('Model hash', 'sd_model_checkpoint'), ('ENSD', 'eta_noise_seed_delta'), ('Schedule type', 'k_sched_type'), - ('Schedule max sigma', 'sigma_max'), - ('Schedule min sigma', 'sigma_min'), - ('Schedule rho', 'rho'), - ('Noise multiplier', 'initial_noise_multiplier'), - ('Eta', 'eta_ancestral'), - ('Eta DDIM', 'eta_ddim'), - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ('Token merging ratio', 'token_merging_ratio'), - ('Token merging ratio hr', 'token_merging_ratio_hr'), - ('RNG', 'randn_source'), - ('NGMS', 's_min_uncond'), - ('Pad conds', 'pad_cond_uncond'), ] +""" def create_override_settings_dict(text_pairs): @@ -350,7 +349,8 @@ def create_override_settings_dict(text_pairs): params[k] = v.strip() - for param_name, setting_name in infotext_to_setting_name_mapping: + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: value = params.get(param_name, None) if value is None: @@ -399,10 +399,16 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, return res if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + def paste_settings(params): vals = {} - for param_name, setting_name in infotext_to_setting_name_mapping: + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in already_handled_fields: + continue + v = params.get(param_name, None) if v is None: continue diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py new file mode 100644 index 000000000..e6b6835ad --- /dev/null +++ b/modules/gradio_extensons.py @@ -0,0 +1,73 @@ +import gradio as gr + +from modules import scripts, ui_tempdir, patches + + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + config.pop('example_inputs', None) + + return config + + +def BlockContext_init(self, *args, **kwargs): + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + return res + + +def Blocks_get_config_file(self, *args, **kwargs): + config = original_Blocks_get_config_file(self, *args, **kwargs) + + for comp_config in config["components"]: + if "example_inputs" in comp_config: + comp_config["example_inputs"] = {"serialized": []} + + return config + + +original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) +original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) +original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) +original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) + + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21a..70f1cbd26 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images allows training previews to have infotext. Importing it at the top causes a circular import problem. - from modules import images + from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/images.py b/modules/images.py index 38aa933d6..eb6447338 100644 --- a/modules/images.py +++ b/modules/images.py @@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors from modules.paths_internal import roboto_ttf_file from modules.shared import opts -import modules.sd_vae as sd_vae - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -318,7 +316,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None): return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') @@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True): class FilenameGenerator: - def get_vae_filename(self): #get the name of the VAE file. - if sd_vae.loaded_vae_file is None: - return "NoneType" - file_name = os.path.basename(sd_vae.loaded_vae_file) - split_file_name = file_name.split('.') - if len(split_file_name) > 1 and split_file_name[0] == '': - return split_file_name[1] # if the first character of the filename is "." then [1] is obtained. - else: - return split_file_name[0] - replacements = { 'seed': lambda self: self.seed if self.seed is not None else '', 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], @@ -367,7 +355,9 @@ class FilenameGenerator: 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime