mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-29 19:05:05 +08:00
Merge branch 'dev' into patch-1
This commit is contained in:
commit
9e14cac318
@ -87,5 +87,11 @@ module.exports = {
|
||||
modalNextImage: "readonly",
|
||||
// token-counters.js
|
||||
setupTokenCounters: "readonly",
|
||||
// localStorage.js
|
||||
localSet: "readonly",
|
||||
localGet: "readonly",
|
||||
localRemove: "readonly",
|
||||
// resizeHandle.js
|
||||
setupResizeHandle: "writable"
|
||||
}
|
||||
};
|
||||
|
139
CHANGELOG.md
139
CHANGELOG.md
@ -1,3 +1,142 @@
|
||||
## 1.6.0
|
||||
|
||||
### Features:
|
||||
* refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)
|
||||
* add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards
|
||||
* add style editor dialog
|
||||
* hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))
|
||||
* option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))
|
||||
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
||||
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
||||
* makes all of them work with img2img
|
||||
* makes prompt composition posssible (AND)
|
||||
* makes them available for SDXL
|
||||
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
||||
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
||||
* textual inversion inference support for SDXL
|
||||
* extra networks UI: show metadata for SD checkpoints
|
||||
* checkpoint merger: add metadata support
|
||||
* prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))
|
||||
* VAE: allow selecting own VAE for each checkpoint (in user metadata editor)
|
||||
* VAE: add selected VAE to infotext
|
||||
* options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))
|
||||
* add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723))
|
||||
* change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it
|
||||
* show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707))
|
||||
* add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models
|
||||
* prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457))
|
||||
|
||||
### Minor:
|
||||
* img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))
|
||||
* postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))
|
||||
* XYZ: in the axis labels, remove pathnames from model filenames
|
||||
* XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))
|
||||
* XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))
|
||||
* add gradio version warning
|
||||
* sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))
|
||||
* use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))
|
||||
* move some settings to their own section: img2img, VAE
|
||||
* add checkbox to show/hide dirs for extra networks
|
||||
* Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))
|
||||
* gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))
|
||||
* sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))
|
||||
* update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))
|
||||
* option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))
|
||||
* enable cond cache by default
|
||||
* git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))
|
||||
* allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))
|
||||
* automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))
|
||||
* put commonly used samplers on top, make DPM++ 2M Karras the default choice
|
||||
* zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727))
|
||||
* option to cache Lora networks in memory
|
||||
* rework hires fix UI to use accordion
|
||||
* face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back
|
||||
* change quicksettings items to have variable width
|
||||
* Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))
|
||||
* Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
|
||||
* support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))
|
||||
* add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))
|
||||
* support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584))
|
||||
* make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634))
|
||||
* configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648))
|
||||
* make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645))
|
||||
* more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639))
|
||||
* make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635))
|
||||
* make progress bar work independently from live preview display which results in it being updated a lot more often
|
||||
* forbid Full live preview method for medvram and add a setting to undo the forbidding
|
||||
* make it possible to localize tooltips and placeholders
|
||||
|
||||
### Extensions and API:
|
||||
* gradio 3.41.2
|
||||
* also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
|
||||
* support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
|
||||
* properly clear the total console progressbar when using txt2img and img2img from API
|
||||
* add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))
|
||||
* shared.py and webui.py split into many files
|
||||
* add --loglevel commandline argument for logging
|
||||
* add a custom UI element that combines accordion and checkbox
|
||||
* avoid importing gradio in tests because it spams warnings
|
||||
* put infotext label for setting into OptionInfo definition rather than in a separate list
|
||||
* make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))
|
||||
* option to make scripts UI without gr.Group
|
||||
* add a way for scripts to register a callback for before/after just a single component's creation
|
||||
* use dataclass for StableDiffusionProcessing
|
||||
* store patches for Lora in a specialized module instead of inside torch
|
||||
* support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698))
|
||||
* add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616))
|
||||
* dump current stack traces when exiting with SIGINT
|
||||
* add type annotations for extra fields of shared.sd_model
|
||||
|
||||
### Bug Fixes:
|
||||
* Don't crash if out of local storage quota for javascriot localStorage
|
||||
* XYZ plot do not fail if an exception occurs
|
||||
* fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))
|
||||
* localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))
|
||||
* fix sdxl model invalid configuration after the hijack
|
||||
* correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))
|
||||
* open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))
|
||||
* prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))
|
||||
* add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))
|
||||
* fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))
|
||||
* fix options in main UI misbehaving when there's just one element
|
||||
* make it possible to use a sampler from infotext even if it's hidden in the dropdown
|
||||
* fix styles missing from the prompt in infotext when making a grid of batch of multiplie images
|
||||
* prevent bogus progress output in console when calculating hires fix dimensions
|
||||
* fix --use-textbox-seed
|
||||
* fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))
|
||||
* properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))
|
||||
* MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))
|
||||
* add second_order to samplers that mistakenly didn't have it
|
||||
* when refreshing cards in extra networks UI, do not discard user's custom resolution
|
||||
* fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))
|
||||
* fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588))
|
||||
* fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586))
|
||||
* fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596))
|
||||
* auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603))
|
||||
* fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607))
|
||||
* fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638))
|
||||
* fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633))
|
||||
* attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630))
|
||||
* implement missing undo hijack for SDXL
|
||||
* fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684))
|
||||
* fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689))
|
||||
* fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685))
|
||||
* fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667))
|
||||
* create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717))
|
||||
* prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version
|
||||
* set devices.dtype_unet correctly
|
||||
* run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||
* prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))
|
||||
* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
|
||||
* fix defaults settings page breaking when any of main UI tabs are hidden
|
||||
* fix incorrect save/display of new values in Defaults page in settings
|
||||
* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
|
||||
* fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||
* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||
* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||
* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
|
||||
|
||||
|
||||
## 1.5.2
|
||||
|
||||
### Bug Fixes:
|
||||
|
7
CITATION.cff
Normal file
7
CITATION.cff
Normal file
@ -0,0 +1,7 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- given-names: AUTOMATIC1111
|
||||
title: "Stable Diffusion Web UI"
|
||||
date-released: 2022-08-22
|
||||
url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
|
12
README.md
12
README.md
@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
||||
- Clip skip
|
||||
- Hypernetworks
|
||||
- Loras (same as Hypernetworks but more pretty)
|
||||
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
||||
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
||||
- Can select to load a different VAE from settings screen
|
||||
- Estimated completion time in progress bar
|
||||
- API
|
||||
@ -93,7 +93,10 @@ A browser interface based on Gradio library for Stable Diffusion.
|
||||
- Reorder elements in the UI from settings screen
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
|
||||
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
|
||||
|
||||
Alternatively, use online services (like Google Colab):
|
||||
|
||||
@ -115,7 +118,7 @@ Alternatively, use online services (like Google Colab):
|
||||
1. Install the dependencies:
|
||||
```bash
|
||||
# Debian-based:
|
||||
sudo apt install wget git python3 python3-venv
|
||||
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||
# Red Hat-based:
|
||||
sudo dnf install wget git python3
|
||||
# Arch-based:
|
||||
@ -123,7 +126,7 @@ sudo pacman -S wget git python3
|
||||
```
|
||||
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
||||
```bash
|
||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
||||
wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
|
||||
```
|
||||
3. Run `webui.sh`.
|
||||
4. Check `webui-user.sh` for options.
|
||||
@ -169,5 +172,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||
- LyCORIS - KohakuBlueleaf
|
||||
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||
- (You)
|
||||
|
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('lora')
|
||||
|
||||
self.errors = {}
|
||||
"""mapping of network names to the number of errors the network had during operation"""
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
self.errors.clear()
|
||||
|
||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
||||
if self.errors:
|
||||
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||
|
||||
self.errors.clear()
|
||||
|
31
extensions-builtin/Lora/lora_patches.py
Normal file
31
extensions-builtin/Lora/lora_patches.py
Normal file
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
import networks
|
||||
from modules import patches
|
||||
|
||||
|
||||
class LoraPatches:
|
||||
def __init__(self):
|
||||
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
|
||||
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
|
||||
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
|
||||
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
|
||||
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
|
||||
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
|
||||
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
|
||||
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
|
||||
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
|
||||
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
|
||||
|
||||
def undo(self):
|
||||
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
|
||||
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
|
||||
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
|
||||
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
|
||||
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
|
||||
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
|
||||
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
|
||||
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
|
||||
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
|
||||
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
|
||||
|
@ -133,7 +133,7 @@ class NetworkModule:
|
||||
|
||||
return 1.0
|
||||
|
||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
@ -145,7 +145,10 @@ class NetworkModule:
|
||||
if orig_weight.size().numel() == updown.size().numel():
|
||||
updown = updown.reshape(orig_weight.shape)
|
||||
|
||||
return updown * self.calc_scale() * self.multiplier()
|
||||
if ex_bias is not None:
|
||||
ex_bias = ex_bias * self.multiplier()
|
||||
|
||||
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
||||
|
||||
def calc_updown(self, target):
|
||||
raise NotImplementedError()
|
||||
|
@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule):
|
||||
super().__init__(net, weights)
|
||||
|
||||
self.weight = weights.w.get("diff")
|
||||
self.ex_bias = weights.w.get("diff_b")
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.weight.shape
|
||||
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
if self.ex_bias is not None:
|
||||
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||
|
28
extensions-builtin/Lora/network_norm.py
Normal file
28
extensions-builtin/Lora/network_norm.py
Normal file
@ -0,0 +1,28 @@
|
||||
import network
|
||||
|
||||
|
||||
class ModuleTypeNorm(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["w_norm", "b_norm"]):
|
||||
return NetworkModuleNorm(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class NetworkModuleNorm(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
self.w_norm = weights.w.get("w_norm")
|
||||
self.b_norm = weights.w.get("b_norm")
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.w_norm.shape
|
||||
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
||||
if self.b_norm is not None:
|
||||
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
@ -1,12 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import lora_patches
|
||||
import network
|
||||
import network_lora
|
||||
import network_hada
|
||||
import network_ia3
|
||||
import network_lokr
|
||||
import network_full
|
||||
import network_norm
|
||||
|
||||
import torch
|
||||
from typing import Union
|
||||
@ -19,6 +22,7 @@ module_types = [
|
||||
network_ia3.ModuleTypeIa3(),
|
||||
network_lokr.ModuleTypeLokr(),
|
||||
network_full.ModuleTypeFull(),
|
||||
network_norm.ModuleTypeNorm(),
|
||||
]
|
||||
|
||||
|
||||
@ -31,6 +35,8 @@ suffix_conversion = {
|
||||
"resnets": {
|
||||
"conv1": "in_layers_2",
|
||||
"conv2": "out_layers_3",
|
||||
"norm1": "in_layers_0",
|
||||
"norm2": "out_layers_0",
|
||||
"time_emb_proj": "emb_layers_1",
|
||||
"conv_shortcut": "skip_connection",
|
||||
}
|
||||
@ -190,11 +196,19 @@ def load_network(name, network_on_disk):
|
||||
net.modules[key] = net_module
|
||||
|
||||
if keys_failed_to_match:
|
||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
||||
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def purge_networks_from_memory():
|
||||
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
|
||||
name = next(iter(networks_in_memory))
|
||||
networks_in_memory.pop(name, None)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||
already_loaded = {}
|
||||
|
||||
@ -212,15 +226,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
|
||||
failed_to_load_networks = []
|
||||
|
||||
for i, name in enumerate(names):
|
||||
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
|
||||
net = already_loaded.get(name, None)
|
||||
|
||||
network_on_disk = networks_on_disk[i]
|
||||
|
||||
if network_on_disk is not None:
|
||||
if net is None:
|
||||
net = networks_in_memory.get(name)
|
||||
|
||||
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
||||
try:
|
||||
net = load_network(name, network_on_disk)
|
||||
|
||||
networks_in_memory.pop(name, None)
|
||||
networks_in_memory[name] = net
|
||||
except Exception as e:
|
||||
errors.display(e, f"loading network {network_on_disk.filename}")
|
||||
continue
|
||||
@ -231,7 +249,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
|
||||
if net is None:
|
||||
failed_to_load_networks.append(name)
|
||||
print(f"Couldn't find network with name {name}")
|
||||
logging.info(f"Couldn't find network with name {name}")
|
||||
continue
|
||||
|
||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||
@ -240,23 +258,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
loaded_networks.append(net)
|
||||
|
||||
if failed_to_load_networks:
|
||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||
|
||||
purge_networks_from_memory()
|
||||
|
||||
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
bias_backup = getattr(self, "network_bias_backup", None)
|
||||
|
||||
if weights_backup is None:
|
||||
if weights_backup is None and bias_backup is None:
|
||||
return
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.in_proj_weight.copy_(weights_backup[0])
|
||||
self.out_proj.weight.copy_(weights_backup[1])
|
||||
if weights_backup is not None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.in_proj_weight.copy_(weights_backup[0])
|
||||
self.out_proj.weight.copy_(weights_backup[1])
|
||||
else:
|
||||
self.weight.copy_(weights_backup)
|
||||
|
||||
if bias_backup is not None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.out_proj.bias.copy_(bias_backup)
|
||||
else:
|
||||
self.bias.copy_(bias_backup)
|
||||
else:
|
||||
self.weight.copy_(weights_backup)
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
self.out_proj.bias = None
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||
"""
|
||||
Applies the currently selected set of networks to the weights of torch layer self.
|
||||
If weights already have this particular set of networks applied, does nothing.
|
||||
@ -271,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
|
||||
|
||||
weights_backup = getattr(self, "network_weights_backup", None)
|
||||
if weights_backup is None:
|
||||
if weights_backup is None and wanted_names != ():
|
||||
if current_names != ():
|
||||
raise RuntimeError("no backup weights found and current weights are not unchanged")
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention):
|
||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||
else:
|
||||
@ -279,21 +315,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
|
||||
self.network_weights_backup = weights_backup
|
||||
|
||||
bias_backup = getattr(self, "network_bias_backup", None)
|
||||
if bias_backup is None:
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
|
||||
elif getattr(self, 'bias', None) is not None:
|
||||
bias_backup = self.bias.to(devices.cpu, copy=True)
|
||||
else:
|
||||
bias_backup = None
|
||||
self.network_bias_backup = bias_backup
|
||||
|
||||
if current_names != wanted_names:
|
||||
network_restore_weights_from_backup(self)
|
||||
|
||||
for net in loaded_networks:
|
||||
module = net.modules.get(network_layer_name, None)
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
with torch.no_grad():
|
||||
updown = module.calc_updown(self.weight)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown, ex_bias = module.calc_updown(self.weight)
|
||||
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
|
||||
self.weight += updown
|
||||
continue
|
||||
self.weight += updown
|
||||
if ex_bias is not None and hasattr(self, 'bias'):
|
||||
if self.bias is None:
|
||||
self.bias = torch.nn.Parameter(ex_bias)
|
||||
else:
|
||||
self.bias += ex_bias
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
continue
|
||||
|
||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||
@ -301,21 +357,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||
with torch.no_grad():
|
||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
|
||||
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
|
||||
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
||||
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += updown_out
|
||||
continue
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += updown_out
|
||||
if ex_bias is not None:
|
||||
if self.out_proj.bias is None:
|
||||
self.out_proj.bias = torch.nn.Parameter(ex_bias)
|
||||
else:
|
||||
self.out_proj.bias += ex_bias
|
||||
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
continue
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
self.network_current_names = wanted_names
|
||||
|
||||
@ -342,7 +410,7 @@ def network_forward(module, input, original_forward):
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
y = module.forward(y, input)
|
||||
y = module.forward(input, y)
|
||||
|
||||
return y
|
||||
|
||||
@ -354,44 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||
|
||||
def network_Linear_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, torch.nn.Linear_forward_before_network)
|
||||
return network_forward(self, input, originals.Linear_forward)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.Linear_forward_before_network(self, input)
|
||||
return originals.Linear_forward(self, input)
|
||||
|
||||
|
||||
def network_Linear_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
|
||||
return originals.Linear_load_state_dict(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_Conv2d_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
|
||||
return network_forward(self, input, originals.Conv2d_forward)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.Conv2d_forward_before_network(self, input)
|
||||
return originals.Conv2d_forward(self, input)
|
||||
|
||||
|
||||
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
||||
return originals.Conv2d_load_state_dict(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_GroupNorm_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, originals.GroupNorm_forward)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return originals.GroupNorm_forward(self, input)
|
||||
|
||||
|
||||
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_LayerNorm_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, originals.LayerNorm_forward)
|
||||
|
||||
network_apply_weights(self)
|
||||
|
||||
return originals.LayerNorm_forward(self, input)
|
||||
|
||||
|
||||
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||
network_apply_weights(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
|
||||
return originals.MultiheadAttention_forward(self, *args, **kwargs)
|
||||
|
||||
|
||||
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||
network_reset_cached_weight(self)
|
||||
|
||||
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
|
||||
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
|
||||
|
||||
|
||||
def list_available_networks():
|
||||
@ -459,9 +557,14 @@ def infotext_pasted(infotext, params):
|
||||
params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
|
||||
originals: lora_patches.LoraPatches = None
|
||||
|
||||
extra_network_lora = None
|
||||
|
||||
available_networks = {}
|
||||
available_network_aliases = {}
|
||||
loaded_networks = []
|
||||
networks_in_memory = {}
|
||||
available_network_hash_lookup = {}
|
||||
forbidden_network_aliases = {}
|
||||
|
||||
|
@ -1,57 +1,30 @@
|
||||
import re
|
||||
|
||||
import torch
|
||||
import gradio as gr
|
||||
from fastapi import FastAPI
|
||||
|
||||
import network
|
||||
import networks
|
||||
import lora # noqa:F401
|
||||
import lora_patches
|
||||
import extra_networks_lora
|
||||
import ui_extra_networks_lora
|
||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||
|
||||
|
||||
def unload():
|
||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
|
||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
|
||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
|
||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
|
||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
|
||||
networks.originals.undo()
|
||||
|
||||
|
||||
def before_ui():
|
||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||
|
||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
||||
extra_networks.register_extra_network(extra_network)
|
||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
||||
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
||||
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
|
||||
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
||||
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
|
||||
|
||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
||||
|
||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
|
||||
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
|
||||
|
||||
torch.nn.Linear.forward = networks.network_Linear_forward
|
||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
||||
networks.originals = lora_patches.LoraPatches()
|
||||
|
||||
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
||||
script_callbacks.on_script_unloaded(unload)
|
||||
@ -65,6 +38,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||
}))
|
||||
|
||||
|
||||
@ -121,3 +95,5 @@ def infotext_pasted(infotext, d):
|
||||
|
||||
|
||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||
|
||||
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||
|
@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
|
||||
|
||||
with gr.Column(scale=1, min_width=120):
|
||||
generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
|
||||
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
|
||||
|
||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||
|
||||
|
@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
item = {
|
||||
"name": name,
|
||||
"filename": lora_on_disk.filename,
|
||||
"shorthash": lora_on_disk.shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": lora_on_disk.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||
|
@ -12,8 +12,22 @@ onUiLoaded(async() => {
|
||||
"Sketch": elementIDs.sketch
|
||||
};
|
||||
|
||||
|
||||
// Helper functions
|
||||
// Get active tab
|
||||
|
||||
/**
|
||||
* Waits for an element to be present in the DOM.
|
||||
*/
|
||||
const waitForElement = (id) => new Promise(resolve => {
|
||||
const checkForElement = () => {
|
||||
const element = document.querySelector(id);
|
||||
if (element) return resolve(element);
|
||||
setTimeout(checkForElement, 100);
|
||||
};
|
||||
checkForElement();
|
||||
});
|
||||
|
||||
function getActiveTab(elements, all = false) {
|
||||
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||
|
||||
@ -34,7 +48,7 @@ onUiLoaded(async() => {
|
||||
|
||||
// Wait until opts loaded
|
||||
async function waitForOpts() {
|
||||
for (;;) {
|
||||
for (; ;) {
|
||||
if (window.opts && Object.keys(window.opts).length) {
|
||||
return window.opts;
|
||||
}
|
||||
@ -42,6 +56,11 @@ onUiLoaded(async() => {
|
||||
}
|
||||
}
|
||||
|
||||
// Detect whether the element has a horizontal scroll bar
|
||||
function hasHorizontalScrollbar(element) {
|
||||
return element.scrollWidth > element.clientWidth;
|
||||
}
|
||||
|
||||
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
||||
function isModifierKey(event, key) {
|
||||
switch (key) {
|
||||
@ -201,7 +220,8 @@ onUiLoaded(async() => {
|
||||
canvas_hotkey_overlap: "KeyO",
|
||||
canvas_disabled_functions: [],
|
||||
canvas_show_tooltip: true,
|
||||
canvas_blur_prompt: false
|
||||
canvas_auto_expand: true,
|
||||
canvas_blur_prompt: false,
|
||||
};
|
||||
|
||||
const functionMap = {
|
||||
@ -249,7 +269,7 @@ onUiLoaded(async() => {
|
||||
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
||||
}
|
||||
|
||||
function applyZoomAndPan(elemId) {
|
||||
function applyZoomAndPan(elemId, isExtension = true) {
|
||||
const targetElement = gradioApp().querySelector(elemId);
|
||||
|
||||
if (!targetElement) {
|
||||
@ -361,6 +381,12 @@ onUiLoaded(async() => {
|
||||
panY: 0
|
||||
};
|
||||
|
||||
if (isExtension) {
|
||||
targetElement.style.overflow = "hidden";
|
||||
}
|
||||
|
||||
targetElement.isZoomed = false;
|
||||
|
||||
fixCanvas();
|
||||
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
||||
|
||||
@ -371,8 +397,27 @@ onUiLoaded(async() => {
|
||||
toggleOverlap("off");
|
||||
fullScreenMode = false;
|
||||
|
||||
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
|
||||
if (closeBtn) {
|
||||
closeBtn.addEventListener("click", resetZoom);
|
||||
}
|
||||
|
||||
if (canvas && isExtension) {
|
||||
const parentElement = targetElement.closest('[id^="component-"]');
|
||||
if (
|
||||
canvas &&
|
||||
parseFloat(canvas.style.width) > parentElement.offsetWidth &&
|
||||
parseFloat(targetElement.style.width) > parentElement.offsetWidth
|
||||
) {
|
||||
fitToElement();
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (
|
||||
canvas &&
|
||||
!isExtension &&
|
||||
parseFloat(canvas.style.width) > 865 &&
|
||||
parseFloat(targetElement.style.width) > 865
|
||||
) {
|
||||
@ -381,9 +426,6 @@ onUiLoaded(async() => {
|
||||
}
|
||||
|
||||
targetElement.style.width = "";
|
||||
if (canvas) {
|
||||
targetElement.style.height = canvas.style.height;
|
||||
}
|
||||
}
|
||||
|
||||
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
||||
@ -439,7 +481,7 @@ onUiLoaded(async() => {
|
||||
|
||||
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
||||
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
||||
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
|
||||
newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15));
|
||||
|
||||
elemData[elemId].panX +=
|
||||
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
|
||||
@ -450,6 +492,10 @@ onUiLoaded(async() => {
|
||||
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
||||
|
||||
toggleOverlap("on");
|
||||
if (isExtension) {
|
||||
targetElement.style.overflow = "visible";
|
||||
}
|
||||
|
||||
return newZoomLevel;
|
||||
}
|
||||
|
||||
@ -472,10 +518,12 @@ onUiLoaded(async() => {
|
||||
fullScreenMode = false;
|
||||
elemData[elemId].zoomLevel = updateZoom(
|
||||
elemData[elemId].zoomLevel +
|
||||
(operation === "+" ? delta : -delta),
|
||||
(operation === "+" ? delta : -delta),
|
||||
zoomPosX - targetElement.getBoundingClientRect().left,
|
||||
zoomPosY - targetElement.getBoundingClientRect().top
|
||||
);
|
||||
|
||||
targetElement.isZoomed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -489,10 +537,19 @@ onUiLoaded(async() => {
|
||||
//Reset Zoom
|
||||
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||
|
||||
let parentElement;
|
||||
|
||||
if (isExtension) {
|
||||
parentElement = targetElement.closest('[id^="component-"]');
|
||||
} else {
|
||||
parentElement = targetElement.parentElement;
|
||||
}
|
||||
|
||||
|
||||
// Get element and screen dimensions
|
||||
const elementWidth = targetElement.offsetWidth;
|
||||
const elementHeight = targetElement.offsetHeight;
|
||||
const parentElement = targetElement.parentElement;
|
||||
|
||||
const screenWidth = parentElement.clientWidth;
|
||||
const screenHeight = parentElement.clientHeight;
|
||||
|
||||
@ -545,8 +602,12 @@ onUiLoaded(async() => {
|
||||
|
||||
if (!canvas) return;
|
||||
|
||||
if (canvas.offsetWidth > 862) {
|
||||
targetElement.style.width = canvas.offsetWidth + "px";
|
||||
if (canvas.offsetWidth > 862 || isExtension) {
|
||||
targetElement.style.width = (canvas.offsetWidth + 2) + "px";
|
||||
}
|
||||
|
||||
if (isExtension) {
|
||||
targetElement.style.overflow = "visible";
|
||||
}
|
||||
|
||||
if (fullScreenMode) {
|
||||
@ -648,8 +709,48 @@ onUiLoaded(async() => {
|
||||
mouseY = e.offsetY;
|
||||
}
|
||||
|
||||
// Simulation of the function to put a long image into the screen.
|
||||
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
|
||||
// We hide the image and show it to the user when it is ready.
|
||||
|
||||
targetElement.isExpanded = false;
|
||||
function autoExpand() {
|
||||
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
|
||||
if (canvas) {
|
||||
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
|
||||
targetElement.style.visibility = "hidden";
|
||||
setTimeout(() => {
|
||||
fitToScreen();
|
||||
resetZoom();
|
||||
targetElement.style.visibility = "visible";
|
||||
targetElement.isExpanded = true;
|
||||
}, 10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
targetElement.addEventListener("mousemove", getMousePosition);
|
||||
|
||||
//observers
|
||||
// Creating an observer with a callback function to handle DOM changes
|
||||
const observer = new MutationObserver((mutationsList, observer) => {
|
||||
for (let mutation of mutationsList) {
|
||||
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
|
||||
mutation.target.tagName.toLowerCase() === 'canvas') {
|
||||
targetElement.isExpanded = false;
|
||||
setTimeout(resetZoom, 10);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Apply auto expand if enabled
|
||||
if (hotkeysConfig.canvas_auto_expand) {
|
||||
targetElement.addEventListener("mousemove", autoExpand);
|
||||
// Set up an observer to track attribute changes
|
||||
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
|
||||
}
|
||||
|
||||
// Handle events only inside the targetElement
|
||||
let isKeyDownHandlerAttached = false;
|
||||
|
||||
@ -754,6 +855,11 @@ onUiLoaded(async() => {
|
||||
if (isMoving && elemId === activeElement) {
|
||||
updatePanPosition(e.movementX, e.movementY);
|
||||
targetElement.style.pointerEvents = "none";
|
||||
|
||||
if (isExtension) {
|
||||
targetElement.style.overflow = "visible";
|
||||
}
|
||||
|
||||
} else {
|
||||
targetElement.style.pointerEvents = "auto";
|
||||
}
|
||||
@ -764,13 +870,93 @@ onUiLoaded(async() => {
|
||||
isMoving = false;
|
||||
};
|
||||
|
||||
// Checks for extension
|
||||
function checkForOutBox() {
|
||||
const parentElement = targetElement.closest('[id^="component-"]');
|
||||
if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) {
|
||||
resetZoom();
|
||||
targetElement.isExpanded = true;
|
||||
}
|
||||
|
||||
if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) {
|
||||
resetZoom();
|
||||
}
|
||||
|
||||
if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) {
|
||||
resetZoom();
|
||||
}
|
||||
}
|
||||
|
||||
if (isExtension) {
|
||||
targetElement.addEventListener("mousemove", checkForOutBox);
|
||||
}
|
||||
|
||||
|
||||
window.addEventListener('resize', (e) => {
|
||||
resetZoom();
|
||||
|
||||
if (isExtension) {
|
||||
targetElement.isExpanded = false;
|
||||
targetElement.isZoomed = false;
|
||||
}
|
||||
});
|
||||
|
||||
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
||||
|
||||
|
||||
}
|
||||
|
||||
applyZoomAndPan(elementIDs.sketch);
|
||||
applyZoomAndPan(elementIDs.inpaint);
|
||||
applyZoomAndPan(elementIDs.inpaintSketch);
|
||||
applyZoomAndPan(elementIDs.sketch, false);
|
||||
applyZoomAndPan(elementIDs.inpaint, false);
|
||||
applyZoomAndPan(elementIDs.inpaintSketch, false);
|
||||
|
||||
// Make the function global so that other extensions can take advantage of this solution
|
||||
window.applyZoomAndPan = applyZoomAndPan;
|
||||
const applyZoomAndPanIntegration = async(id, elementIDs) => {
|
||||
const mainEl = document.querySelector(id);
|
||||
if (id.toLocaleLowerCase() === "none") {
|
||||
for (const elementID of elementIDs) {
|
||||
const el = await waitForElement(elementID);
|
||||
if (!el) break;
|
||||
applyZoomAndPan(elementID);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!mainEl) return;
|
||||
mainEl.addEventListener("click", async() => {
|
||||
for (const elementID of elementIDs) {
|
||||
const el = await waitForElement(elementID);
|
||||
if (!el) break;
|
||||
applyZoomAndPan(elementID);
|
||||
}
|
||||
}, {once: true});
|
||||
};
|
||||
|
||||
window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
|
||||
|
||||
window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
|
||||
|
||||
/*
|
||||
The function `applyZoomAndPanIntegration` takes two arguments:
|
||||
|
||||
1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
|
||||
If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
|
||||
|
||||
2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
|
||||
If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
|
||||
|
||||
Example usage:
|
||||
applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||
In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
|
||||
*/
|
||||
|
||||
// More examples
|
||||
// Add integration with ControlNet txt2img One TAB
|
||||
// applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||
|
||||
// Add integration with ControlNet txt2img Tabs
|
||||
// applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
|
||||
|
||||
// Add integration with Inpaint Anything
|
||||
// applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
|
||||
});
|
||||
|
@ -9,6 +9,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||
}))
|
||||
|
@ -61,3 +61,6 @@
|
||||
to {opacity: 1;}
|
||||
}
|
||||
|
||||
.styler {
|
||||
overflow:inherit !important;
|
||||
}
|
@ -1,5 +1,7 @@
|
||||
import math
|
||||
|
||||
import gradio as gr
|
||||
from modules import scripts, shared, ui_components, ui_settings
|
||||
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
|
||||
from modules.ui_components import FormColumn
|
||||
|
||||
|
||||
@ -19,18 +21,38 @@ class ExtraOptionsSection(scripts.Script):
|
||||
def ui(self, is_img2img):
|
||||
self.comps = []
|
||||
self.setting_names = []
|
||||
self.infotext_fields = []
|
||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||
|
||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
|
||||
for setting_name in shared.opts.extra_options:
|
||||
with FormColumn():
|
||||
comp = ui_settings.create_setting_component(setting_name)
|
||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
||||
|
||||
self.comps.append(comp)
|
||||
self.setting_names.append(setting_name)
|
||||
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||
|
||||
for row in range(row_count):
|
||||
with gr.Row():
|
||||
for col in range(shared.opts.extra_options_cols):
|
||||
index = row * shared.opts.extra_options_cols + col
|
||||
if index >= len(extra_options):
|
||||
break
|
||||
|
||||
setting_name = extra_options[index]
|
||||
|
||||
with FormColumn():
|
||||
comp = ui_settings.create_setting_component(setting_name)
|
||||
|
||||
self.comps.append(comp)
|
||||
self.setting_names.append(setting_name)
|
||||
|
||||
setting_infotext_name = mapping.get(setting_name)
|
||||
if setting_infotext_name is not None:
|
||||
self.infotext_fields.append((comp, setting_infotext_name))
|
||||
|
||||
def get_settings_values():
|
||||
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||
return res[0] if len(res) == 1 else res
|
||||
|
||||
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
||||
|
||||
@ -43,6 +65,10 @@ class ExtraOptionsSection(scripts.Script):
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
|
||||
"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||
"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
|
||||
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
|
||||
}))
|
||||
|
||||
|
||||
|
@ -20,7 +20,13 @@ function reportWindowSize() {
|
||||
var button = gradioApp().getElementById(tab + '_generate_box');
|
||||
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
||||
target.insertBefore(button, target.firstElementChild);
|
||||
|
||||
gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener("resize", reportWindowSize);
|
||||
|
||||
onUiLoaded(function() {
|
||||
reportWindowSize();
|
||||
});
|
||||
|
@ -1,20 +1,38 @@
|
||||
function toggleCss(key, css, enable) {
|
||||
var style = document.getElementById(key);
|
||||
if (enable && !style) {
|
||||
style = document.createElement('style');
|
||||
style.id = key;
|
||||
style.type = 'text/css';
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
if (style && !enable) {
|
||||
document.head.removeChild(style);
|
||||
}
|
||||
if (style) {
|
||||
style.innerHTML == '';
|
||||
style.appendChild(document.createTextNode(css));
|
||||
}
|
||||
}
|
||||
|
||||
function setupExtraNetworksForTab(tabname) {
|
||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||
|
||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
|
||||
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
||||
var search = searchDiv.querySelector('textarea');
|
||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||
|
||||
search.classList.add('search');
|
||||
sort.classList.add('sort');
|
||||
sortOrder.classList.add('sortorder');
|
||||
sort.dataset.sortkey = 'sortDefault';
|
||||
tabs.appendChild(search);
|
||||
tabs.appendChild(searchDiv);
|
||||
tabs.appendChild(sort);
|
||||
tabs.appendChild(sortOrder);
|
||||
tabs.appendChild(refresh);
|
||||
tabs.appendChild(showDirsDiv);
|
||||
|
||||
var applyFilter = function() {
|
||||
var searchTerm = search.value.toLowerCase();
|
||||
@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) {
|
||||
});
|
||||
|
||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||
|
||||
var showDirsUpdate = function() {
|
||||
var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
|
||||
toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
|
||||
localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
|
||||
};
|
||||
showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
|
||||
showDirs.addEventListener("change", showDirsUpdate);
|
||||
showDirsUpdate();
|
||||
}
|
||||
|
||||
function applyExtraNetworkFilter(tabname) {
|
||||
@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) {
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event) {
|
||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
|
||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
|
||||
var button = event.target;
|
||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||
|
||||
@ -222,6 +249,15 @@ function popup(contents) {
|
||||
globalPopup.style.display = "flex";
|
||||
}
|
||||
|
||||
var storedPopupIds = {};
|
||||
function popupId(id) {
|
||||
if (!storedPopupIds[id]) {
|
||||
storedPopupIds[id] = gradioApp().getElementById(id);
|
||||
}
|
||||
|
||||
popup(storedPopupIds[id]);
|
||||
}
|
||||
|
||||
function extraNetworksShowMetadata(text) {
|
||||
var elem = document.createElement('pre');
|
||||
elem.classList.add('popup-metadata');
|
||||
@ -305,7 +341,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
|
||||
newDiv.innerHTML = data.html;
|
||||
var newCard = newDiv.firstElementChild;
|
||||
|
||||
newCard.style = '';
|
||||
newCard.style.display = '';
|
||||
card.parentElement.insertBefore(newCard, card);
|
||||
card.parentElement.removeChild(card);
|
||||
}
|
||||
|
@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
|
||||
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
|
||||
}
|
||||
});
|
||||
|
||||
onUiLoaded(function() {
|
||||
for (var comp of window.gradio_config.components) {
|
||||
if (comp.props.webui_tooltip && comp.props.elem_id) {
|
||||
var elem = gradioApp().getElementById(comp.props.elem_id);
|
||||
if (elem) {
|
||||
elem.title = comp.props.webui_tooltip;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
|
||||
var event = isFirefox ? 'mousedown' : 'click';
|
||||
|
||||
e.addEventListener(event, function(evt) {
|
||||
if (evt.button == 1) {
|
||||
open(evt.target.src);
|
||||
evt.preventDefault();
|
||||
return;
|
||||
}
|
||||
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||
|
||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||
|
37
javascript/inputAccordion.js
Normal file
37
javascript/inputAccordion.js
Normal file
@ -0,0 +1,37 @@
|
||||
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
var elem = mutationRecord.target;
|
||||
var open = elem.classList.contains('open');
|
||||
|
||||
var accordion = elem.parentNode;
|
||||
accordion.classList.toggle('input-accordion-open', open);
|
||||
|
||||
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||
checkbox.checked = open;
|
||||
updateInput(checkbox);
|
||||
|
||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||
if (extra) {
|
||||
extra.style.display = open ? "" : "none";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
function inputAccordionChecked(id, checked) {
|
||||
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
||||
if (label.classList.contains('open') != checked) {
|
||||
label.click();
|
||||
}
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||
var labelWrap = accordion.querySelector('.label-wrap');
|
||||
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||
|
||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||
if (extra) {
|
||||
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||
}
|
||||
}
|
||||
});
|
26
javascript/localStorage.js
Normal file
26
javascript/localStorage.js
Normal file
@ -0,0 +1,26 @@
|
||||
|
||||
function localSet(k, v) {
|
||||
try {
|
||||
localStorage.setItem(k, v);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to save ${k} to localStorage: ${e}`);
|
||||
}
|
||||
}
|
||||
|
||||
function localGet(k, def) {
|
||||
try {
|
||||
return localStorage.getItem(k);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to load ${k} from localStorage: ${e}`);
|
||||
}
|
||||
|
||||
return def;
|
||||
}
|
||||
|
||||
function localRemove(k) {
|
||||
try {
|
||||
return localStorage.removeItem(k);
|
||||
} catch (e) {
|
||||
console.warn(`Failed to remove ${k} from localStorage: ${e}`);
|
||||
}
|
||||
}
|
@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
|
||||
train_hypernetwork: 'OPTION',
|
||||
txt2img_styles: 'OPTION',
|
||||
img2img_styles: 'OPTION',
|
||||
setting_random_artist_categories: 'SPAN',
|
||||
setting_face_restoration_model: 'SPAN',
|
||||
setting_realesrgan_enabled_models: 'SPAN',
|
||||
extras_upscaler_1: 'SPAN',
|
||||
extras_upscaler_2: 'SPAN',
|
||||
setting_random_artist_categories: 'OPTION',
|
||||
setting_face_restoration_model: 'OPTION',
|
||||
setting_realesrgan_enabled_models: 'OPTION',
|
||||
extras_upscaler_1: 'OPTION',
|
||||
extras_upscaler_2: 'OPTION',
|
||||
};
|
||||
|
||||
var re_num = /^[.\d]+$/;
|
||||
@ -107,12 +107,41 @@ function processNode(node) {
|
||||
});
|
||||
}
|
||||
|
||||
function localizeWholePage() {
|
||||
processNode(gradioApp());
|
||||
|
||||
function elem(comp) {
|
||||
var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id;
|
||||
return gradioApp().getElementById(elem_id);
|
||||
}
|
||||
|
||||
for (var comp of window.gradio_config.components) {
|
||||
if (comp.props.webui_tooltip) {
|
||||
let e = elem(comp);
|
||||
|
||||
let tl = e ? getTranslation(e.title) : undefined;
|
||||
if (tl !== undefined) {
|
||||
e.title = tl;
|
||||
}
|
||||
}
|
||||
if (comp.props.placeholder) {
|
||||
let e = elem(comp);
|
||||
let textbox = e ? e.querySelector('[placeholder]') : null;
|
||||
|
||||
let tl = textbox ? getTranslation(textbox.placeholder) : undefined;
|
||||
if (tl !== undefined) {
|
||||
textbox.placeholder = tl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function dumpTranslations() {
|
||||
if (!hasLocalization()) {
|
||||
// If we don't have any localization,
|
||||
// we will not have traversed the app to find
|
||||
// original_lines, so do that now.
|
||||
processNode(gradioApp());
|
||||
localizeWholePage();
|
||||
}
|
||||
var dumped = {};
|
||||
if (localization.rtl) {
|
||||
@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||
});
|
||||
});
|
||||
|
||||
processNode(gradioApp());
|
||||
localizeWholePage();
|
||||
|
||||
if (localization.rtl) { // if the language is from right to left,
|
||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||
|
@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
||||
var dateStart = new Date();
|
||||
var wasEverActive = false;
|
||||
var parentProgressbar = progressbarContainer.parentNode;
|
||||
var parentGallery = gallery ? gallery.parentNode : null;
|
||||
|
||||
var divProgress = document.createElement('div');
|
||||
divProgress.className = 'progressDiv';
|
||||
@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
||||
divProgress.appendChild(divInner);
|
||||
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
||||
|
||||
if (parentGallery) {
|
||||
var livePreview = document.createElement('div');
|
||||
livePreview.className = 'livePreview';
|
||||
parentGallery.insertBefore(livePreview, gallery);
|
||||
}
|
||||
var livePreview = null;
|
||||
|
||||
var removeProgressBar = function() {
|
||||
if (!divProgress) return;
|
||||
|
||||
setTitle("");
|
||||
parentProgressbar.removeChild(divProgress);
|
||||
if (parentGallery) parentGallery.removeChild(livePreview);
|
||||
if (gallery && livePreview) gallery.removeChild(livePreview);
|
||||
atEnd();
|
||||
|
||||
divProgress = null;
|
||||
};
|
||||
|
||||
var fun = function(id_task, id_live_preview) {
|
||||
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
||||
var funProgress = function(id_task) {
|
||||
request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) {
|
||||
if (res.completed) {
|
||||
removeProgressBar();
|
||||
return;
|
||||
}
|
||||
|
||||
var rect = progressbarContainer.getBoundingClientRect();
|
||||
|
||||
if (rect.width) {
|
||||
divProgress.style.width = rect.width + "px";
|
||||
}
|
||||
|
||||
let progressText = "";
|
||||
|
||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
||||
@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
||||
progressText += " ETA: " + formatTime(res.eta);
|
||||
}
|
||||
|
||||
|
||||
setTitle(progressText);
|
||||
|
||||
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
||||
@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
||||
return;
|
||||
}
|
||||
|
||||
if (onProgress) {
|
||||
onProgress(res);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
funProgress(id_task, res.id_live_preview);
|
||||
}, opts.live_preview_refresh_period || 500);
|
||||
}, function() {
|
||||
removeProgressBar();
|
||||
});
|
||||
};
|
||||
|
||||
var funLivePreview = function(id_task, id_live_preview) {
|
||||
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
||||
if (!divProgress) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (res.live_preview && gallery) {
|
||||
rect = gallery.getBoundingClientRect();
|
||||
if (rect.width) {
|
||||
livePreview.style.width = rect.width + "px";
|
||||
livePreview.style.height = rect.height + "px";
|
||||
}
|
||||
|
||||
var img = new Image();
|
||||
img.onload = function() {
|
||||
if (!livePreview) {
|
||||
livePreview = document.createElement('div');
|
||||
livePreview.className = 'livePreview';
|
||||
gallery.insertBefore(livePreview, gallery.firstElementChild);
|
||||
}
|
||||
|
||||
livePreview.appendChild(img);
|
||||
if (livePreview.childElementCount > 2) {
|
||||
livePreview.removeChild(livePreview.firstElementChild);
|
||||
@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
||||
img.src = res.live_preview;
|
||||
}
|
||||
|
||||
|
||||
if (onProgress) {
|
||||
onProgress(res);
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
fun(id_task, res.id_live_preview);
|
||||
funLivePreview(id_task, res.id_live_preview);
|
||||
}, opts.live_preview_refresh_period || 500);
|
||||
}, function() {
|
||||
removeProgressBar();
|
||||
});
|
||||
};
|
||||
|
||||
fun(id_task, 0);
|
||||
funProgress(id_task, 0);
|
||||
|
||||
if (gallery) {
|
||||
funLivePreview(id_task, 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
141
javascript/resizeHandle.js
Normal file
141
javascript/resizeHandle.js
Normal file
@ -0,0 +1,141 @@
|
||||
(function() {
|
||||
const GRADIO_MIN_WIDTH = 320;
|
||||
const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
|
||||
const PAD = 16;
|
||||
const DEBOUNCE_TIME = 100;
|
||||
|
||||
const R = {
|
||||
tracking: false,
|
||||
parent: null,
|
||||
parentWidth: null,
|
||||
leftCol: null,
|
||||
leftColStartWidth: null,
|
||||
screenX: null,
|
||||
};
|
||||
|
||||
let resizeTimer;
|
||||
let parents = [];
|
||||
|
||||
function setLeftColGridTemplate(el, width) {
|
||||
el.style.gridTemplateColumns = `${width}px 16px 1fr`;
|
||||
}
|
||||
|
||||
function displayResizeHandle(parent) {
|
||||
if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
|
||||
parent.style.display = 'flex';
|
||||
if (R.handle != null) {
|
||||
R.handle.style.opacity = '0';
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
parent.style.display = 'grid';
|
||||
if (R.handle != null) {
|
||||
R.handle.style.opacity = '100';
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
function afterResize(parent) {
|
||||
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
|
||||
const oldParentWidth = R.parentWidth;
|
||||
const newParentWidth = parent.offsetWidth;
|
||||
const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
|
||||
|
||||
const ratio = newParentWidth / oldParentWidth;
|
||||
|
||||
const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
|
||||
setLeftColGridTemplate(parent, newWidthL);
|
||||
|
||||
R.parentWidth = newParentWidth;
|
||||
}
|
||||
}
|
||||
|
||||
function setup(parent) {
|
||||
const leftCol = parent.firstElementChild;
|
||||
const rightCol = parent.lastElementChild;
|
||||
|
||||
parents.push(parent);
|
||||
|
||||
parent.style.display = 'grid';
|
||||
parent.style.gap = '0';
|
||||
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||
|
||||
const resizeHandle = document.createElement('div');
|
||||
resizeHandle.classList.add('resize-handle');
|
||||
parent.insertBefore(resizeHandle, rightCol);
|
||||
|
||||
resizeHandle.addEventListener('mousedown', (evt) => {
|
||||
if (evt.button !== 0) return;
|
||||
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
document.body.classList.add('resizing');
|
||||
|
||||
R.tracking = true;
|
||||
R.parent = parent;
|
||||
R.parentWidth = parent.offsetWidth;
|
||||
R.handle = resizeHandle;
|
||||
R.leftCol = leftCol;
|
||||
R.leftColStartWidth = leftCol.offsetWidth;
|
||||
R.screenX = evt.screenX;
|
||||
});
|
||||
|
||||
resizeHandle.addEventListener('dblclick', (evt) => {
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||
});
|
||||
|
||||
afterResize(parent);
|
||||
}
|
||||
|
||||
window.addEventListener('mousemove', (evt) => {
|
||||
if (evt.button !== 0) return;
|
||||
|
||||
if (R.tracking) {
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
const delta = R.screenX - evt.screenX;
|
||||
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
|
||||
setLeftColGridTemplate(R.parent, leftColWidth);
|
||||
}
|
||||
});
|
||||
|
||||
window.addEventListener('mouseup', (evt) => {
|
||||
if (evt.button !== 0) return;
|
||||
|
||||
if (R.tracking) {
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
R.tracking = false;
|
||||
|
||||
document.body.classList.remove('resizing');
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
clearTimeout(resizeTimer);
|
||||
|
||||
resizeTimer = setTimeout(function() {
|
||||
for (const parent of parents) {
|
||||
afterResize(parent);
|
||||
}
|
||||
}, DEBOUNCE_TIME);
|
||||
});
|
||||
|
||||
setupResizeHandle = setup;
|
||||
})();
|
||||
|
||||
onUiLoaded(function() {
|
||||
for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
|
||||
if (!elem.querySelector('.resize-handle')) {
|
||||
setupResizeHandle(elem);
|
||||
}
|
||||
}
|
||||
});
|
@ -19,28 +19,11 @@ function all_gallery_buttons() {
|
||||
}
|
||||
|
||||
function selected_gallery_button() {
|
||||
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
|
||||
var visibleCurrentButton = null;
|
||||
allCurrentButtons.forEach(function(elem) {
|
||||
if (elem.parentElement.offsetParent) {
|
||||
visibleCurrentButton = elem;
|
||||
}
|
||||
});
|
||||
return visibleCurrentButton;
|
||||
return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;
|
||||
}
|
||||
|
||||
function selected_gallery_index() {
|
||||
var buttons = all_gallery_buttons();
|
||||
var button = selected_gallery_button();
|
||||
|
||||
var result = -1;
|
||||
buttons.forEach(function(v, i) {
|
||||
if (v == button) {
|
||||
result = i;
|
||||
}
|
||||
});
|
||||
|
||||
return result;
|
||||
return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
|
||||
}
|
||||
|
||||
function extract_image_from_gallery(gallery) {
|
||||
@ -152,11 +135,11 @@ function submit() {
|
||||
showSubmitButtons('txt2img', false);
|
||||
|
||||
var id = randomId();
|
||||
localStorage.setItem("txt2img_task_id", id);
|
||||
localSet("txt2img_task_id", id);
|
||||
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||
showSubmitButtons('txt2img', true);
|
||||
localStorage.removeItem("txt2img_task_id");
|
||||
localRemove("txt2img_task_id");
|
||||
showRestoreProgressButton('txt2img', false);
|
||||
});
|
||||
|
||||
@ -171,11 +154,11 @@ function submit_img2img() {
|
||||
showSubmitButtons('img2img', false);
|
||||
|
||||
var id = randomId();
|
||||
localStorage.setItem("img2img_task_id", id);
|
||||
localSet("img2img_task_id", id);
|
||||
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||
showSubmitButtons('img2img', true);
|
||||
localStorage.removeItem("img2img_task_id");
|
||||
localRemove("img2img_task_id");
|
||||
showRestoreProgressButton('img2img', false);
|
||||
});
|
||||
|
||||
@ -189,9 +172,7 @@ function submit_img2img() {
|
||||
|
||||
function restoreProgressTxt2img() {
|
||||
showRestoreProgressButton("txt2img", false);
|
||||
var id = localStorage.getItem("txt2img_task_id");
|
||||
|
||||
id = localStorage.getItem("txt2img_task_id");
|
||||
var id = localGet("txt2img_task_id");
|
||||
|
||||
if (id) {
|
||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||
@ -205,7 +186,7 @@ function restoreProgressTxt2img() {
|
||||
function restoreProgressImg2img() {
|
||||
showRestoreProgressButton("img2img", false);
|
||||
|
||||
var id = localStorage.getItem("img2img_task_id");
|
||||
var id = localGet("img2img_task_id");
|
||||
|
||||
if (id) {
|
||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||
@ -218,8 +199,8 @@ function restoreProgressImg2img() {
|
||||
|
||||
|
||||
onUiLoaded(function() {
|
||||
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
|
||||
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
|
||||
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||
});
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from modules import launch_utils
|
||||
|
||||
|
||||
args = launch_utils.args
|
||||
python = launch_utils.python
|
||||
git = launch_utils.git
|
||||
@ -26,8 +25,11 @@ start = launch_utils.start
|
||||
|
||||
|
||||
def main():
|
||||
if not args.skip_prepare_environment:
|
||||
prepare_environment()
|
||||
launch_utils.startup_timer.record("initial startup")
|
||||
|
||||
with launch_utils.startup_timer.subcategory("prepare environment"):
|
||||
if not args.skip_prepare_environment:
|
||||
prepare_environment()
|
||||
|
||||
if args.test_server:
|
||||
configure_for_tests()
|
||||
|
@ -4,6 +4,8 @@ import os
|
||||
import time
|
||||
import datetime
|
||||
import uvicorn
|
||||
import ipaddress
|
||||
import requests
|
||||
import gradio as gr
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
@ -15,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||
from modules.sd_vae import vae_dict
|
||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@ -56,7 +57,41 @@ def setUpscalers(req: dict):
|
||||
return reqDict
|
||||
|
||||
|
||||
def verify_url(url):
|
||||
"""Returns True if the url refers to a global resource."""
|
||||
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
domain_name = parsed_url.netloc
|
||||
host = socket.gethostbyname_ex(domain_name)
|
||||
for ip in host[2]:
|
||||
ip_addr = ipaddress.ip_address(ip)
|
||||
if not ip_addr.is_global:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("http://") or encoding.startswith("https://"):
|
||||
if not opts.api_enable_requests:
|
||||
raise HTTPException(status_code=500, detail="Requests not allowed")
|
||||
|
||||
if opts.api_forbid_local_requests and not verify_url(encoding):
|
||||
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
|
||||
|
||||
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||
response = requests.get(encoding, timeout=30, headers=headers)
|
||||
try:
|
||||
image = Image.open(BytesIO(response.content))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
@ -197,6 +232,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||
@ -207,6 +243,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
|
||||
|
||||
if shared.cmd_opts.api_server_stop:
|
||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||
@ -329,6 +366,7 @@ class Api:
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.is_api = True
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
@ -343,6 +381,7 @@ class Api:
|
||||
processed = process_images(p)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@ -388,6 +427,7 @@ class Api:
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
p.is_api = True
|
||||
p.scripts = script_runner
|
||||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
@ -402,6 +442,7 @@ class Api:
|
||||
processed = process_images(p)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
|
||||
@ -433,9 +474,6 @@ class Api:
|
||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return models.PNGInfoResponse(info="")
|
||||
|
||||
image = decode_base64_to_image(req.image.strip())
|
||||
if image is None:
|
||||
return models.PNGInfoResponse(info="")
|
||||
@ -444,9 +482,10 @@ class Api:
|
||||
if geninfo is None:
|
||||
geninfo = ""
|
||||
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||
|
||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||
|
||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
@ -530,7 +569,7 @@ class Api:
|
||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||
|
||||
for k, v in req.items():
|
||||
shared.opts.set(k, v)
|
||||
shared.opts.set(k, v, is_api=True)
|
||||
|
||||
shared.opts.save(shared.config_filename)
|
||||
return
|
||||
@ -562,10 +601,12 @@ class Api:
|
||||
]
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||
import modules.sd_models as sd_models
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
||||
|
||||
def get_sd_vaes(self):
|
||||
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
|
||||
import modules.sd_vae as sd_vae
|
||||
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
@ -608,6 +649,10 @@ class Api:
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
def refresh_vae(self):
|
||||
with self.queue_lock:
|
||||
shared_items.refresh_vae_list()
|
||||
|
||||
def create_embedding(self, args: dict):
|
||||
try:
|
||||
shared.state.begin(job="create_embedding")
|
||||
@ -724,6 +769,25 @@ class Api:
|
||||
cuda = {'error': f'{err}'}
|
||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||
|
||||
def get_extensions_list(self):
|
||||
from modules import extensions
|
||||
extensions.list_extensions()
|
||||
ext_list = []
|
||||
for ext in extensions.extensions:
|
||||
ext: extensions.Extension
|
||||
ext.read_info_from_repo()
|
||||
if ext.remote is not None:
|
||||
ext_list.append({
|
||||
"name": ext.name,
|
||||
"remote": ext.remote,
|
||||
"branch": ext.branch,
|
||||
"commit_hash":ext.commit_hash,
|
||||
"commit_date":ext.commit_date,
|
||||
"version":ext.version,
|
||||
"enabled":ext.enabled
|
||||
})
|
||||
return ext_list
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||
|
@ -50,10 +50,12 @@ class PydanticModelGenerator:
|
||||
additional_fields = None,
|
||||
):
|
||||
def field_type_generator(k, v):
|
||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
||||
# print(k, v.annotation, v.default)
|
||||
field_type = v.annotation
|
||||
|
||||
if field_type == 'Image':
|
||||
# images are sent as base64 strings via API
|
||||
field_type = 'str'
|
||||
|
||||
return Optional[field_type]
|
||||
|
||||
def merge_class_params(class_):
|
||||
@ -63,7 +65,6 @@ class PydanticModelGenerator:
|
||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||
return parameters
|
||||
|
||||
|
||||
self._model_name = model_name
|
||||
self._class_data = merge_class_params(class_instance)
|
||||
|
||||
@ -72,7 +73,7 @@ class PydanticModelGenerator:
|
||||
field=underscore(k),
|
||||
field_alias=k,
|
||||
field_type=field_type_generator(k, v),
|
||||
field_value=v.default
|
||||
field_value=None if isinstance(v.default, property) else v.default
|
||||
)
|
||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||
]
|
||||
@ -177,7 +178,8 @@ class PNGInfoRequest(BaseModel):
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
||||
items: dict = Field(title="Items", description="An object containing all the info the image had")
|
||||
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
|
||||
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
@ -310,3 +312,12 @@ class ScriptInfo(BaseModel):
|
||||
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
|
||||
class ExtensionItem(BaseModel):
|
||||
name: str = Field(title="Name", description="Extension name")
|
||||
remote: str = Field(title="Remote", description="Extension Repository URL")
|
||||
branch: str = Field(title="Branch", description="Extension Repository Branch")
|
||||
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
||||
version: str = Field(title="Version", description="Extension Version")
|
||||
commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||
|
@ -1,11 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
import os.path
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules.paths import data_path, script_path
|
||||
|
||||
cache_filename = os.path.join(data_path, "cache.json")
|
||||
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
||||
cache_data = None
|
||||
cache_lock = threading.Lock()
|
||||
|
||||
@ -29,9 +30,12 @@ def dump_cache():
|
||||
time.sleep(1)
|
||||
|
||||
with cache_lock:
|
||||
with open(cache_filename, "w", encoding="utf8") as file:
|
||||
cache_filename_tmp = cache_filename + "-"
|
||||
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||
json.dump(cache_data, file, indent=4)
|
||||
|
||||
os.replace(cache_filename_tmp, cache_filename)
|
||||
|
||||
dump_cache_after = None
|
||||
dump_cache_thread = None
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
from functools import wraps
|
||||
import html
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules import shared, progress, errors, devices
|
||||
from modules import shared, progress, errors, devices, fifo_lock
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
queue_lock = fifo_lock.FIFOLock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
|
@ -13,8 +13,10 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
|
||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
@ -33,9 +35,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||
@ -66,6 +69,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@ -78,7 +82,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
@ -110,3 +114,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
|
@ -8,14 +8,12 @@ import time
|
||||
import tqdm
|
||||
|
||||
from datetime import datetime
|
||||
from collections import OrderedDict
|
||||
import git
|
||||
|
||||
from modules import shared, extensions, errors
|
||||
from modules.paths_internal import script_path, config_states_dir
|
||||
|
||||
|
||||
all_config_states = OrderedDict()
|
||||
all_config_states = {}
|
||||
|
||||
|
||||
def list_config_states():
|
||||
@ -28,10 +26,14 @@ def list_config_states():
|
||||
for filename in os.listdir(config_states_dir):
|
||||
if filename.endswith(".json"):
|
||||
path = os.path.join(config_states_dir, filename)
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
j = json.load(f)
|
||||
j["filepath"] = path
|
||||
config_states.append(j)
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
j = json.load(f)
|
||||
assert "created_at" in j, '"created_at" does not exist'
|
||||
j["filepath"] = path
|
||||
config_states.append(j)
|
||||
except Exception as e:
|
||||
print(f'[ERROR]: Config states {path}, {e}')
|
||||
|
||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||
|
||||
|
@ -3,7 +3,7 @@ import contextlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors
|
||||
from modules import errors, shared
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@ -17,8 +17,6 @@ def has_mps() -> bool:
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
|
||||
@ -40,8 +38,6 @@ def get_optimal_device():
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
from modules import shared
|
||||
|
||||
if task in shared.cmd_opts.use_cpu:
|
||||
return cpu
|
||||
|
||||
@ -71,14 +67,17 @@ def enable_tf32():
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
dtype_unet = torch.float16
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
device_esrgan: torch.device = None
|
||||
device_codeformer: torch.device = None
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype_vae: torch.dtype = torch.float16
|
||||
dtype_unet: torch.dtype = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
@ -90,26 +89,10 @@ def cond_cast_float(input):
|
||||
return input.float() if unet_needs_upcast else input
|
||||
|
||||
|
||||
def randn(seed, shape):
|
||||
from modules.shared import opts
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def randn_without_seed(shape):
|
||||
from modules.shared import opts
|
||||
|
||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||
return torch.randn(shape, device=cpu).to(device)
|
||||
return torch.randn(shape, device=device)
|
||||
nv_rng = None
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
from modules import shared
|
||||
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
@ -128,8 +111,6 @@ class NansException(Exception):
|
||||
|
||||
|
||||
def test_for_nans(x, where):
|
||||
from modules import shared
|
||||
|
||||
if shared.cmd_opts.disable_nan_check:
|
||||
return
|
||||
|
||||
@ -169,3 +150,4 @@ def first_time_calculation():
|
||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||
conv2d(x)
|
||||
|
||||
|
@ -84,3 +84,53 @@ def run(code, task):
|
||||
code()
|
||||
except Exception as e:
|
||||
display(task, e)
|
||||
|
||||
|
||||
def check_versions():
|
||||
from packaging import version
|
||||
from modules import shared
|
||||
|
||||
import torch
|
||||
import gradio
|
||||
|
||||
expected_torch_version = "2.0.0"
|
||||
expected_xformers_version = "0.0.20"
|
||||
expected_gradio_version = "3.41.2"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
print_error_explanation(f"""
|
||||
You are running torch {torch.__version__}.
|
||||
The program is tested to work with torch {expected_torch_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||
there are reports of issues with training tab on the latest version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if shared.xformers_available:
|
||||
import xformers
|
||||
|
||||
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||
print_error_explanation(f"""
|
||||
You are running xformers {xformers.__version__}.
|
||||
The program is tested to work with xformers {expected_xformers_version}.
|
||||
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
if gradio.__version__ != expected_gradio_version:
|
||||
print_error_explanation(f"""
|
||||
You are running gradio {gradio.__version__}.
|
||||
The program is designed to work with gradio {expected_gradio_version}.
|
||||
Using a different version of gradio is extremely likely to break the program.
|
||||
|
||||
Reasons why you have the mismatched gradio version can be:
|
||||
- you use --skip-install flag.
|
||||
- you use webui.py to start the program instead of launch.py.
|
||||
- an extension installs the incompatible gradio version.
|
||||
|
||||
Use --skip-version-check commandline argument to disable this check.
|
||||
""".strip())
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
from modules import shared, errors, cache
|
||||
from modules import shared, errors, cache, scripts
|
||||
from modules.gitpython_hack import Repo
|
||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||
|
||||
@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
|
||||
|
||||
|
||||
def active():
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
||||
return []
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
|
||||
return [x for x in extensions if x.enabled and x.is_builtin]
|
||||
else:
|
||||
return [x for x in extensions if x.enabled]
|
||||
@ -90,8 +90,6 @@ class Extension:
|
||||
self.have_info_from_repo = True
|
||||
|
||||
def list_files(self, subdir, extension):
|
||||
from modules import scripts
|
||||
|
||||
dirpath = os.path.join(self.path, subdir)
|
||||
if not os.path.isdir(dirpath):
|
||||
return []
|
||||
@ -141,8 +139,12 @@ def list_extensions():
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
if shared.opts.disable_all_extensions == "all":
|
||||
if shared.cmd_opts.disable_all_extensions:
|
||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "all":
|
||||
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
||||
elif shared.cmd_opts.disable_extra_extensions:
|
||||
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
|
||||
elif shared.opts.disable_all_extensions == "extra":
|
||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
from modules import errors
|
||||
@ -84,27 +87,55 @@ class ExtraNetwork:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def lookup_extra_networks(extra_network_data):
|
||||
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
|
||||
|
||||
Example input:
|
||||
{
|
||||
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
|
||||
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||
}
|
||||
|
||||
Example output:
|
||||
|
||||
{
|
||||
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||
}
|
||||
"""
|
||||
|
||||
res = {}
|
||||
|
||||
for extra_network_name, extra_network_args in list(extra_network_data.items()):
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
alias = extra_network_aliases.get(extra_network_name, None)
|
||||
|
||||
if alias is not None and extra_network is None:
|
||||
extra_network = alias
|
||||
|
||||
if extra_network is None:
|
||||
logging.info(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
|
||||
res.setdefault(extra_network, []).extend(extra_network_args)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def activate(p, extra_network_data):
|
||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||
activate for all remaining registered networks with an empty argument list"""
|
||||
|
||||
activated = []
|
||||
|
||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
|
||||
if extra_network is None:
|
||||
extra_network = extra_network_aliases.get(extra_network_name, None)
|
||||
|
||||
if extra_network is None:
|
||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
||||
continue
|
||||
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
|
||||
|
||||
try:
|
||||
extra_network.activate(p, extra_network_args)
|
||||
activated.append(extra_network)
|
||||
except Exception as e:
|
||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
||||
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
if extra_network in activated:
|
||||
@ -123,19 +154,16 @@ def deactivate(p, extra_network_data):
|
||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||
deactivate for all remaining registered networks"""
|
||||
|
||||
for extra_network_name in extra_network_data:
|
||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||
if extra_network is None:
|
||||
continue
|
||||
data = lookup_extra_networks(extra_network_data)
|
||||
|
||||
for extra_network in data:
|
||||
try:
|
||||
extra_network.deactivate(p)
|
||||
except Exception as e:
|
||||
errors.display(e, f"deactivating extra network {extra_network_name}")
|
||||
errors.display(e, f"deactivating extra network {extra_network.name}")
|
||||
|
||||
for extra_network_name, extra_network in extra_network_registry.items():
|
||||
args = extra_network_data.get(extra_network_name, None)
|
||||
if args is not None:
|
||||
if extra_network in data:
|
||||
continue
|
||||
|
||||
try:
|
||||
@ -177,3 +205,20 @@ def parse_prompts(prompts):
|
||||
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
basename, ext = os.path.splitext(filename)
|
||||
metadata_filename = basename + '.json'
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||
|
||||
return metadata
|
||||
|
@ -7,7 +7,7 @@ import json
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
||||
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
||||
from modules.ui_common import plaintext_to_html
|
||||
import gradio as gr
|
||||
import safetensors.torch
|
||||
@ -72,7 +72,20 @@ def to_half(tensor, enable):
|
||||
return tensor
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
||||
metadata = {}
|
||||
|
||||
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
||||
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
||||
if checkpoint_info is None:
|
||||
continue
|
||||
|
||||
metadata.update(checkpoint_info.metadata)
|
||||
|
||||
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
||||
shared.state.begin(job="model-merge")
|
||||
|
||||
def fail(message):
|
||||
@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
shared.state.textinfo = "Saving"
|
||||
print(f"Saving to {output_modelname}...")
|
||||
|
||||
metadata = None
|
||||
metadata = {}
|
||||
|
||||
if save_metadata and copy_metadata_fields:
|
||||
if primary_model_info:
|
||||
metadata.update(primary_model_info.metadata)
|
||||
if secondary_model_info:
|
||||
metadata.update(secondary_model_info.metadata)
|
||||
if tertiary_model_info:
|
||||
metadata.update(tertiary_model_info.metadata)
|
||||
|
||||
if save_metadata:
|
||||
metadata = {"format": "pt"}
|
||||
try:
|
||||
metadata.update(json.loads(metadata_json))
|
||||
except Exception as e:
|
||||
errors.display(e, "readin metadata from json")
|
||||
|
||||
metadata["format"] = "pt"
|
||||
|
||||
if save_metadata and add_merge_recipe:
|
||||
merge_recipe = {
|
||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||
"primary_model_hash": primary_model_info.sha256,
|
||||
@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
"is_inpainting": result_is_inpainting_model,
|
||||
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||
}
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
|
||||
sd_merge_models = {}
|
||||
|
||||
@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
||||
if tertiary_model_info:
|
||||
add_model_metadata(tertiary_model_info)
|
||||
|
||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||
|
||||
_, extension = os.path.splitext(output_modelname)
|
||||
if extension.lower() == ".safetensors":
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
|
||||
else:
|
||||
torch.save(theta_0, output_modelname)
|
||||
|
||||
|
37
modules/fifo_lock.py
Normal file
37
modules/fifo_lock.py
Normal file
@ -0,0 +1,37 @@
|
||||
import threading
|
||||
import collections
|
||||
|
||||
|
||||
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
||||
class FIFOLock(object):
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._inner_lock = threading.Lock()
|
||||
self._pending_threads = collections.deque()
|
||||
|
||||
def acquire(self, blocking=True):
|
||||
with self._inner_lock:
|
||||
lock_acquired = self._lock.acquire(False)
|
||||
if lock_acquired:
|
||||
return True
|
||||
elif not blocking:
|
||||
return False
|
||||
|
||||
release_event = threading.Event()
|
||||
self._pending_threads.append(release_event)
|
||||
|
||||
release_event.wait()
|
||||
return self._lock.acquire()
|
||||
|
||||
def release(self):
|
||||
with self._inner_lock:
|
||||
if self._pending_threads:
|
||||
release_event = self._pending_threads.popleft()
|
||||
release_event.set()
|
||||
|
||||
self._lock.release()
|
||||
|
||||
__enter__ = acquire
|
||||
|
||||
def __exit__(self, t, v, tb):
|
||||
self.release()
|
@ -6,7 +6,7 @@ import re
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||
@ -32,6 +32,7 @@ class ParamBinding:
|
||||
|
||||
def reset():
|
||||
paste_fields.clear()
|
||||
registered_param_bindings.clear()
|
||||
|
||||
|
||||
def quote(text):
|
||||
@ -198,7 +199,6 @@ def restore_old_hires_fix_params(res):
|
||||
height = int(res.get("Size-2", 512))
|
||||
|
||||
if firstpass_width == 0 or firstpass_height == 0:
|
||||
from modules import processing
|
||||
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
||||
|
||||
res['Size-1'] = firstpass_width
|
||||
@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Hires sampler" not in res:
|
||||
res["Hires sampler"] = "Use same sampler"
|
||||
|
||||
if "Hires checkpoint" not in res:
|
||||
res["Hires checkpoint"] = "Use same checkpoint"
|
||||
|
||||
if "Hires prompt" not in res:
|
||||
res["Hires prompt"] = ""
|
||||
|
||||
@ -304,32 +307,28 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Schedule rho" not in res:
|
||||
res["Schedule rho"] = 0
|
||||
|
||||
if "VAE Encoder" not in res:
|
||||
res["VAE Encoder"] = "Full"
|
||||
|
||||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
return res
|
||||
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
||||
|
||||
]
|
||||
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
|
||||
Example content:
|
||||
|
||||
infotext_to_setting_name_mapping = [
|
||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||
('Model hash', 'sd_model_checkpoint'),
|
||||
('ENSD', 'eta_noise_seed_delta'),
|
||||
('Schedule type', 'k_sched_type'),
|
||||
('Schedule max sigma', 'sigma_max'),
|
||||
('Schedule min sigma', 'sigma_min'),
|
||||
('Schedule rho', 'rho'),
|
||||
('Noise multiplier', 'initial_noise_multiplier'),
|
||||
('Eta', 'eta_ancestral'),
|
||||
('Eta DDIM', 'eta_ddim'),
|
||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
||||
('UniPC variant', 'uni_pc_variant'),
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
('UniPC order', 'uni_pc_order'),
|
||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||
('Token merging ratio', 'token_merging_ratio'),
|
||||
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||
('RNG', 'randn_source'),
|
||||
('NGMS', 's_min_uncond'),
|
||||
('Pad conds', 'pad_cond_uncond'),
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
def create_override_settings_dict(text_pairs):
|
||||
@ -350,7 +349,8 @@ def create_override_settings_dict(text_pairs):
|
||||
|
||||
params[k] = v.strip()
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
value = params.get(param_name, None)
|
||||
|
||||
if value is None:
|
||||
@ -399,10 +399,16 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
return res
|
||||
|
||||
if override_settings_component is not None:
|
||||
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
|
||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in already_handled_fields:
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
73
modules/gradio_extensons.py
Normal file
73
modules/gradio_extensons.py
Normal file
@ -0,0 +1,73 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui_tempdir, patches
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
self.webui_tooltip = kwargs.pop('tooltip', None)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.before_component(self, **kwargs)
|
||||
|
||||
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts.scripts_current is not None:
|
||||
scripts.scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def Block_get_config(self):
|
||||
config = original_Block_get_config(self)
|
||||
|
||||
webui_tooltip = getattr(self, 'webui_tooltip', None)
|
||||
if webui_tooltip:
|
||||
config["webui_tooltip"] = webui_tooltip
|
||||
|
||||
config.pop('example_inputs', None)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def Blocks_get_config_file(self, *args, **kwargs):
|
||||
config = original_Blocks_get_config_file(self, *args, **kwargs)
|
||||
|
||||
for comp_config in config["components"]:
|
||||
if "example_inputs" in comp_config:
|
||||
comp_config["example_inputs"] = {"serialized": []}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
||||
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||
|
||||
|
||||
ui_tempdir.install_ui_tempdir_override()
|
@ -10,7 +10,7 @@ import torch
|
||||
import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import default
|
||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||
from modules.textual_inversion import textual_inversion, logging
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
|
||||
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
from modules import images, processing
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
|
||||
from modules.paths_internal import roboto_ttf_file
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.sd_vae as sd_vae
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
|
||||
@ -318,7 +316,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
return res
|
||||
|
||||
|
||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
|
||||
|
||||
|
||||
class FilenameGenerator:
|
||||
def get_vae_filename(self): #get the name of the VAE file.
|
||||
if sd_vae.loaded_vae_file is None:
|
||||
return "NoneType"
|
||||
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||
split_file_name = file_name.split('.')
|
||||
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
||||
else:
|
||||
return split_file_name[0]
|
||||
|
||||
replacements = {
|
||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
||||
@ -367,7 +355,9 @@ class FilenameGenerator:
|
||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
||||
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
|
||||
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
|
||||
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
|
||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||
@ -380,7 +370,8 @@ class FilenameGenerator:
|
||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||
'user': lambda self: self.p.user,
|
||||
'vae_filename': lambda self: self.get_vae_filename(),
|
||||
'none': lambda self: '', # Overrides the default so you can get just the sequence number
|
||||
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
|
||||
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
|
||||
}
|
||||
default_time_format = '%Y%m%d%H%M%S'
|
||||
|
||||
@ -391,6 +382,22 @@ class FilenameGenerator:
|
||||
self.image = image
|
||||
self.zip = zip
|
||||
|
||||
def get_vae_filename(self):
|
||||
"""Get the name of the VAE file."""
|
||||
|
||||
import modules.sd_vae as sd_vae
|
||||
|
||||
if sd_vae.loaded_vae_file is None:
|
||||
return "NoneType"
|
||||
|
||||
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||
split_file_name = file_name.split('.')
|
||||
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
||||
else:
|
||||
return split_file_name[0]
|
||||
|
||||
|
||||
def hasprompt(self, *args):
|
||||
lower = self.prompt.lower()
|
||||
if self.p is None or self.prompt is None:
|
||||
@ -444,6 +451,14 @@ class FilenameGenerator:
|
||||
|
||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||
|
||||
def image_hash(self, *args):
|
||||
length = int(args[0]) if (args and args[0] != "") else None
|
||||
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
|
||||
|
||||
def string_hash(self, text, *args):
|
||||
length = int(args[0]) if (args and args[0] != "") else 8
|
||||
return hashlib.sha256(text.encode()).hexdigest()[0:length]
|
||||
|
||||
def apply(self, x):
|
||||
res = ''
|
||||
|
||||
@ -585,6 +600,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
"""
|
||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||
|
||||
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
|
||||
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
|
||||
print('Image dimensions too large; saving as PNG')
|
||||
extension = ".png"
|
||||
|
||||
if save_to_dirs is None:
|
||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||
|
||||
|
@ -3,14 +3,14 @@ from contextlib import closing
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_samplers, images as imgutil
|
||||
from modules import images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.images import save_image
|
||||
from modules.sd_models import get_closet_checkpoint_match
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
@ -18,9 +18,10 @@ import modules.scripts
|
||||
|
||||
|
||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||
output_dir = output_dir.strip()
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
@ -32,11 +33,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
save_normally = output_dir == ''
|
||||
|
||||
p.do_not_save_grid = True
|
||||
p.do_not_save_samples = not save_normally
|
||||
|
||||
state.job_count = len(images) * p.n_iter
|
||||
|
||||
# extract "default" params to use in case getting png info fails
|
||||
@ -46,7 +42,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
cfg_scale = p.cfg_scale
|
||||
sampler_name = p.sampler_name
|
||||
steps = p.steps
|
||||
|
||||
override_settings = p.override_settings
|
||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||
for i, image in enumerate(images):
|
||||
state.job = f"{i+1} out of {len(images)}"
|
||||
if state.skipped:
|
||||
@ -109,42 +106,40 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||
|
||||
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
|
||||
if model_info is not None:
|
||||
p.override_settings['sd_model_checkpoint'] = model_info.name
|
||||
elif sd_model_checkpoint_override:
|
||||
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
|
||||
else:
|
||||
p.override_settings.pop("sd_model_checkpoint", None)
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
if proc is None:
|
||||
proc = process_images(p)
|
||||
|
||||
for n, processed_image in enumerate(proc.images):
|
||||
filename = image_path.stem
|
||||
infotext = proc.infotext(p, n)
|
||||
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
||||
|
||||
if n > 0:
|
||||
filename += f"-{n}"
|
||||
|
||||
if not save_normally:
|
||||
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
||||
if processed_image.mode == 'RGBA':
|
||||
processed_image = processed_image.convert("RGB")
|
||||
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
if p.n_iter > 1 or p.batch_size > 1:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||
else:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||
process_images(p)
|
||||
|
||||
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||
override_settings = create_override_settings_dict(override_settings_texts)
|
||||
|
||||
is_batch = mode == 5
|
||||
|
||||
if mode == 0: # img2img
|
||||
image = init_img.convert("RGB")
|
||||
image = init_img
|
||||
mask = None
|
||||
elif mode == 1: # img2img sketch
|
||||
image = sketch.convert("RGB")
|
||||
image = sketch
|
||||
mask = None
|
||||
elif mode == 2: # inpaint
|
||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
||||
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
||||
image = image.convert("RGB")
|
||||
mask = processing.create_binary_mask(mask)
|
||||
elif mode == 3: # inpaint sketch
|
||||
image = inpaint_color_sketch
|
||||
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||
@ -153,7 +148,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||
image = image.convert("RGB")
|
||||
elif mode == 4: # inpaint upload mask
|
||||
image = init_img_inpaint
|
||||
mask = init_mask_inpaint
|
||||
@ -180,21 +174,13 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
styles=prompt_styles,
|
||||
seed=seed,
|
||||
subseed=subseed,
|
||||
subseed_strength=subseed_strength,
|
||||
seed_resize_from_h=seed_resize_from_h,
|
||||
seed_resize_from_w=seed_resize_from_w,
|
||||
seed_enable_extras=seed_enable_extras,
|
||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
||||
sampler_name=sampler_name,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
restore_faces=restore_faces,
|
||||
tiling=tiling,
|
||||
init_images=[image],
|
||||
mask=mask,
|
||||
mask_blur=mask_blur,
|
||||
|
168
modules/initialize.py
Normal file
168
modules/initialize.py
Normal file
@ -0,0 +1,168 @@
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from threading import Thread
|
||||
|
||||
from modules.timer import startup_timer
|
||||
|
||||
|
||||
def imports():
|
||||
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||
|
||||
import torch # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
import pytorch_lightning # noqa: F401
|
||||
startup_timer.record("import torch")
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||
startup_timer.record("setup paths")
|
||||
|
||||
import ldm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import ldm")
|
||||
|
||||
import sgm.modules.encoders.modules # noqa: F401
|
||||
startup_timer.record("import sgm")
|
||||
|
||||
from modules import shared_init
|
||||
shared_init.initialize()
|
||||
startup_timer.record("initialize shared")
|
||||
|
||||
from modules import processing, gradio_extensons, ui # noqa: F401
|
||||
startup_timer.record("other imports")
|
||||
|
||||
|
||||
def check_versions():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if not cmd_opts.skip_version_check:
|
||||
from modules import errors
|
||||
errors.check_versions()
|
||||
|
||||
|
||||
def initialize():
|
||||
from modules import initialize_util
|
||||
initialize_util.fix_torch_version()
|
||||
initialize_util.fix_asyncio_event_loop_policy()
|
||||
initialize_util.validate_tls_options()
|
||||
initialize_util.configure_sigint_handler()
|
||||
initialize_util.configure_opts_onchange()
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.cleanup_models()
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.setup_model()
|
||||
startup_timer.record("setup SD model")
|
||||
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
from modules import codeformer_model
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
|
||||
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
|
||||
startup_timer.record("setup codeformer")
|
||||
|
||||
from modules import gfpgan_model
|
||||
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
|
||||
startup_timer.record("setup gfpgan")
|
||||
|
||||
initialize_rest(reload_script_modules=False)
|
||||
|
||||
|
||||
def initialize_rest(*, reload_script_modules=False):
|
||||
"""
|
||||
Called both from initialize() and when reloading the webui.
|
||||
"""
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
from modules import sd_samplers
|
||||
sd_samplers.set_samplers()
|
||||
startup_timer.record("set samplers")
|
||||
|
||||
from modules import extensions
|
||||
extensions.list_extensions()
|
||||
startup_timer.record("list extensions")
|
||||
|
||||
from modules import initialize_util
|
||||
initialize_util.restore_config_state_file()
|
||||
startup_timer.record("restore config state file")
|
||||
|
||||
from modules import shared, upscaler, scripts
|
||||
if cmd_opts.ui_debug_mode:
|
||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||
scripts.load_scripts()
|
||||
return
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.list_models()
|
||||
startup_timer.record("list SD models")
|
||||
|
||||
from modules import localization
|
||||
localization.list_localizations(cmd_opts.localizations_dir)
|
||||
startup_timer.record("list localizations")
|
||||
|
||||
with startup_timer.subcategory("load scripts"):
|
||||
scripts.load_scripts()
|
||||
|
||||
if reload_script_modules:
|
||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||
importlib.reload(module)
|
||||
startup_timer.record("reload script modules")
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.load_upscalers()
|
||||
startup_timer.record("load upscalers")
|
||||
|
||||
from modules import sd_vae
|
||||
sd_vae.refresh_vae_list()
|
||||
startup_timer.record("refresh VAE")
|
||||
|
||||
from modules import textual_inversion
|
||||
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||
startup_timer.record("refresh textual inversion templates")
|
||||
|
||||
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
||||
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
||||
sd_hijack.list_optimizers()
|
||||
startup_timer.record("scripts list_optimizers")
|
||||
|
||||
from modules import sd_unet
|
||||
sd_unet.list_unets()
|
||||
startup_timer.record("scripts list_unets")
|
||||
|
||||
def load_model():
|
||||
"""
|
||||
Accesses shared.sd_model property to load model.
|
||||
After it's available, if it has been loaded before this access by some extension,
|
||||
its optimization may be None because the list of optimizaers has neet been filled
|
||||
by that time, so we apply optimization again.
|
||||
"""
|
||||
|
||||
shared.sd_model # noqa: B018
|
||||
|
||||
if sd_hijack.current_optimizer is None:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
from modules import devices
|
||||
devices.first_time_calculation()
|
||||
|
||||
Thread(target=load_model).start()
|
||||
|
||||
from modules import shared_items
|
||||
shared_items.reload_hypernetworks()
|
||||
startup_timer.record("reload hypernetworks")
|
||||
|
||||
from modules import ui_extra_networks
|
||||
ui_extra_networks.initialize()
|
||||
ui_extra_networks.register_default_pages()
|
||||
|
||||
from modules import extra_networks
|
||||
extra_networks.initialize()
|
||||
extra_networks.register_default_extra_networks()
|
||||
startup_timer.record("initialize extra networks")
|
202
modules/initialize_util.py
Normal file
202
modules/initialize_util.py
Normal file
@ -0,0 +1,202 @@
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import re
|
||||
|
||||
from modules.timer import startup_timer
|
||||
|
||||
|
||||
def gradio_server_name():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if cmd_opts.server_name:
|
||||
return cmd_opts.server_name
|
||||
else:
|
||||
return "0.0.0.0" if cmd_opts.listen else None
|
||||
|
||||
|
||||
def fix_torch_version():
|
||||
import torch
|
||||
|
||||
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||
torch.__long_version__ = torch.__version__
|
||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||
|
||||
|
||||
def fix_asyncio_event_loop_policy():
|
||||
"""
|
||||
The default `asyncio` event loop policy only automatically creates
|
||||
event loops in the main threads. Other threads must create event
|
||||
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||
loops to be created automatically on any thread, matching the
|
||||
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||
# interface for composing policies so pick the right base.
|
||||
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||
else:
|
||||
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||
|
||||
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||
"""Event loop policy that allows loop creation on any thread.
|
||||
Usage::
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
"""
|
||||
|
||||
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
return super().get_event_loop()
|
||||
except (RuntimeError, AssertionError):
|
||||
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||
# and changed to a RuntimeError in 3.4.3.
|
||||
# "There is no current event loop in thread %r"
|
||||
loop = self.new_event_loop()
|
||||
self.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||
|
||||
|
||||
def restore_config_state_file():
|
||||
from modules import shared, config_states
|
||||
|
||||
config_state_file = shared.opts.restore_config_state_file
|
||||
if config_state_file == "":
|
||||
return
|
||||
|
||||
shared.opts.restore_config_state_file = ""
|
||||
shared.opts.save(shared.config_filename)
|
||||
|
||||
if os.path.isfile(config_state_file):
|
||||
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||
config_state = json.load(f)
|
||||
config_states.restore_extension_config(config_state)
|
||||
startup_timer.record("restore extension config")
|
||||
elif config_state_file:
|
||||
print(f"!!! Config state backup not found: {config_state_file}")
|
||||
|
||||
|
||||
def validate_tls_options():
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
|
||||
return
|
||||
|
||||
try:
|
||||
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||
print("Invalid path to TLS keyfile given")
|
||||
if not os.path.exists(cmd_opts.tls_certfile):
|
||||
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||
except TypeError:
|
||||
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||
print("TLS setup invalid, running webui without TLS")
|
||||
else:
|
||||
print("Running with TLS")
|
||||
startup_timer.record("TLS")
|
||||
|
||||
|
||||
def get_gradio_auth_creds():
|
||||
"""
|
||||
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
||||
an iterable of (username, password) tuples.
|
||||
"""
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
def process_credential_line(s):
|
||||
s = s.strip()
|
||||
if not s:
|
||||
return None
|
||||
return tuple(s.split(':', 1))
|
||||
|
||||
if cmd_opts.gradio_auth:
|
||||
for cred in cmd_opts.gradio_auth.split(','):
|
||||
cred = process_credential_line(cred)
|
||||
if cred:
|
||||
yield cred
|
||||
|
||||
if cmd_opts.gradio_auth_path:
|
||||
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||
for line in file.readlines():
|
||||
for cred in line.strip().split(','):
|
||||
cred = process_credential_line(cred)
|
||||
if cred:
|
||||
yield cred
|
||||
|
||||
|
||||
def dumpstacks():
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||
code = []
|
||||
for threadId, stack in sys._current_frames().items():
|
||||
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
||||
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||
if line:
|
||||
code.append(" " + line.strip())
|
||||
|
||||
print("\n".join(code))
|
||||
|
||||
|
||||
def configure_sigint_handler():
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
def sigint_handler(sig, frame):
|
||||
print(f'Interrupted with signal {sig} in {frame}')
|
||||
|
||||
dumpstacks()
|
||||
|
||||
os._exit(0)
|
||||
|
||||
if not os.environ.get("COVERAGE_RUN"):
|
||||
# Don't install the immediate-quit handler when running under coverage,
|
||||
# as then the coverage report won't be generated.
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
|
||||
|
||||
def configure_opts_onchange():
|
||||
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
||||
from modules.call_queue import wrap_queued_call
|
||||
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
|
||||
def setup_middleware(app):
|
||||
from starlette.middleware.gzip import GZipMiddleware
|
||||
|
||||
app.middleware_stack = None # reset current middleware to allow modifying user provided list
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
configure_cors_middleware(app)
|
||||
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||
|
||||
|
||||
def configure_cors_middleware(app):
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
cors_options = {
|
||||
"allow_methods": ["*"],
|
||||
"allow_headers": ["*"],
|
||||
"allow_credentials": True,
|
||||
}
|
||||
if cmd_opts.cors_allow_origins:
|
||||
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
|
||||
if cmd_opts.cors_allow_origins_regex:
|
||||
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
|
||||
app.add_middleware(CORSMiddleware, **cors_options)
|
||||
|
@ -186,9 +186,8 @@ class InterrogateModels:
|
||||
res = ""
|
||||
shared.state.begin(job="interrogate")
|
||||
try:
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
lowvram.send_everything_to_cpu()
|
||||
devices.torch_gc()
|
||||
|
||||
self.load()
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import importlib.util
|
||||
import platform
|
||||
@ -10,11 +12,11 @@ from functools import lru_cache
|
||||
|
||||
from modules import cmd_args, errors
|
||||
from modules.paths_internal import script_path, extensions_dir
|
||||
from modules import timer
|
||||
|
||||
timer.startup_timer.record("start")
|
||||
from modules.timer import startup_timer
|
||||
from modules import logging_config
|
||||
|
||||
args, _ = cmd_args.parser.parse_known_args()
|
||||
logging_config.setup_logging(args.loglevel)
|
||||
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
@ -141,6 +143,25 @@ def check_run_python(code: str) -> bool:
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def git_fix_workspace(dir, name):
|
||||
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||
return
|
||||
|
||||
|
||||
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||
try:
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
except RuntimeError:
|
||||
if not autofix:
|
||||
raise
|
||||
|
||||
print(f"{errdesc}, attempting autofix...")
|
||||
git_fix_workspace(dir, name)
|
||||
|
||||
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||
|
||||
|
||||
def git_clone(url, dir, name, commithash=None):
|
||||
# TODO clone into temporary dir and move if successful
|
||||
|
||||
@ -148,15 +169,24 @@ def git_clone(url, dir, name, commithash=None):
|
||||
if commithash is None:
|
||||
return
|
||||
|
||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||
if current_hash == commithash:
|
||||
return
|
||||
|
||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||
|
||||
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||
|
||||
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
try:
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if commithash is not None:
|
||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||
@ -216,7 +246,7 @@ def list_extensions(settings_file):
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
|
||||
if disable_all_extensions != 'none':
|
||||
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||
return []
|
||||
|
||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||
@ -226,8 +256,15 @@ def run_extensions_installers(settings_file):
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
||||
with startup_timer.subcategory("run extensions installers"):
|
||||
for dirname_extension in list_extensions(settings_file):
|
||||
logging.debug(f"Installing {dirname_extension}")
|
||||
|
||||
path = os.path.join(extensions_dir, dirname_extension)
|
||||
|
||||
if os.path.isdir(path):
|
||||
run_extension_installer(path)
|
||||
startup_timer.record(dirname_extension)
|
||||
|
||||
|
||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||
@ -274,7 +311,6 @@ def prepare_environment():
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
@ -285,13 +321,13 @@ def prepare_environment():
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||
except OSError:
|
||||
@ -300,8 +336,11 @@ def prepare_environment():
|
||||
if not args.skip_python_version_check:
|
||||
check_python_version()
|
||||
|
||||
startup_timer.record("checks")
|
||||
|
||||
commit = commit_hash()
|
||||
tag = git_tag()
|
||||
startup_timer.record("git version info")
|
||||
|
||||
print(f"Python {sys.version}")
|
||||
print(f"Version: {tag}")
|
||||
@ -309,21 +348,22 @@ def prepare_environment():
|
||||
|
||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||
startup_timer.record("install torch")
|
||||
|
||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||
raise RuntimeError(
|
||||
'Torch is not able to use GPU; '
|
||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||
)
|
||||
|
||||
if not is_installed("gfpgan"):
|
||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||
startup_timer.record("torch GPU test")
|
||||
|
||||
if not is_installed("clip"):
|
||||
run_pip(f"install {clip_package}", "clip")
|
||||
startup_timer.record("install clip")
|
||||
|
||||
if not is_installed("open_clip"):
|
||||
run_pip(f"install {openclip_package}", "open_clip")
|
||||
startup_timer.record("install open_clip")
|
||||
|
||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||
if platform.system() == "Windows":
|
||||
@ -337,8 +377,11 @@ def prepare_environment():
|
||||
elif platform.system() == "Linux":
|
||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||
|
||||
startup_timer.record("install xformers")
|
||||
|
||||
if not is_installed("ngrok") and args.ngrok:
|
||||
run_pip("install ngrok", "ngrok")
|
||||
startup_timer.record("install ngrok")
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
@ -348,22 +391,28 @@ def prepare_environment():
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||
startup_timer.record("install CodeFormer requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
if not requirements_met(requirements_file):
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
if args.update_check:
|
||||
version_check(commit)
|
||||
startup_timer.record("check version")
|
||||
|
||||
if args.update_all_extensions:
|
||||
git_pull_recursive(extensions_dir)
|
||||
startup_timer.record("update extensions")
|
||||
|
||||
if "--exit" in sys.argv:
|
||||
print("Exiting because of --exit argument")
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from modules import errors
|
||||
from modules import errors, scripts
|
||||
|
||||
localizations = {}
|
||||
|
||||
@ -16,7 +16,6 @@ def list_localizations(dirname):
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
|
||||
from modules import scripts
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
|
16
modules/logging_config.py
Normal file
16
modules/logging_config.py
Normal file
@ -0,0 +1,16 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
|
||||
def setup_logging(loglevel):
|
||||
if loglevel is None:
|
||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
|
||||
if loglevel:
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from modules import devices
|
||||
from modules import devices, shared
|
||||
|
||||
module_in_gpu = None
|
||||
cpu = torch.device("cpu")
|
||||
@ -14,7 +14,24 @@ def send_everything_to_cpu():
|
||||
module_in_gpu = None
|
||||
|
||||
|
||||
def is_needed(sd_model):
|
||||
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
||||
|
||||
|
||||
def apply(sd_model):
|
||||
enable = is_needed(sd_model)
|
||||
shared.parallel_processing_allowed = not enable
|
||||
|
||||
if enable:
|
||||
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
||||
else:
|
||||
sd_model.lowvram = False
|
||||
|
||||
|
||||
def setup_for_low_vram(sd_model, use_medvram):
|
||||
if getattr(sd_model, 'lowvram', False):
|
||||
return
|
||||
|
||||
sd_model.lowvram = True
|
||||
|
||||
parents = {}
|
||||
@ -127,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
|
||||
|
||||
def is_enabled(sd_model):
|
||||
return getattr(sd_model, 'lowvram', False)
|
||||
return sd_model.lowvram
|
||||
|
@ -4,6 +4,7 @@ import torch
|
||||
import platform
|
||||
from modules.sd_hijack_utils import CondFunc
|
||||
from packaging import version
|
||||
from modules import shared
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -30,8 +31,7 @@ has_mps = check_for_mps()
|
||||
|
||||
def torch_mps_gc() -> None:
|
||||
try:
|
||||
from modules.shared import state
|
||||
if state.current_latent is not None:
|
||||
if shared.state.current_latent is not None:
|
||||
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||
return
|
||||
from torch.mps import empty_cache
|
||||
@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
||||
|
||||
|
||||
if has_mps:
|
||||
# MPS fix for randn in torchsde
|
||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
||||
|
||||
if platform.mac_ver()[0].startswith("13.2."):
|
||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||
|
245
modules/options.py
Normal file
245
modules/options.py
Normal file
@ -0,0 +1,245 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
self.section = section
|
||||
self.refresh = refresh
|
||||
self.do_not_save = False
|
||||
|
||||
self.comment_before = comment_before
|
||||
"""HTML text that will be added after label in UI"""
|
||||
|
||||
self.comment_after = comment_after
|
||||
"""HTML text that will be added before label in UI"""
|
||||
|
||||
self.infotext = infotext
|
||||
|
||||
self.restrict_api = restrict_api
|
||||
"""If True, the setting will not be accessible via API"""
|
||||
|
||||
def link(self, label, url):
|
||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def js(self, label, js_func):
|
||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def info(self, info):
|
||||
self.comment_after += f"<span class='info'>({info})</span>"
|
||||
return self
|
||||
|
||||
def html(self, html):
|
||||
self.comment_after += html
|
||||
return self
|
||||
|
||||
def needs_restart(self):
|
||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||
return self
|
||||
|
||||
def needs_reload_ui(self):
|
||||
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
||||
return self
|
||||
|
||||
|
||||
class OptionHTML(OptionInfo):
|
||||
def __init__(self, text):
|
||||
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
||||
|
||||
self.do_not_save = True
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for v in options_dict.values():
|
||||
v.section = section_identifier
|
||||
|
||||
return options_dict
|
||||
|
||||
|
||||
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
|
||||
|
||||
|
||||
class Options:
|
||||
typemap = {int: float}
|
||||
|
||||
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||
self.data_labels = data_labels
|
||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||
self.restricted_opts = restricted_opts
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in options_builtin_fields:
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
if self.data is not None:
|
||||
if key in self.data or key in self.data_labels:
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
info = self.data_labels.get(key, None)
|
||||
if info.do_not_save:
|
||||
return
|
||||
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item in options_builtin_fields:
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
if self.data is not None:
|
||||
if item in self.data:
|
||||
return self.data[item]
|
||||
|
||||
if item in self.data_labels:
|
||||
return self.data_labels[item].default
|
||||
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
def set(self, key, value, is_api=False, run_callbacks=True):
|
||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||
|
||||
oldval = self.data.get(key, None)
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
option = self.data_labels[key]
|
||||
if option.do_not_save:
|
||||
return False
|
||||
|
||||
if is_api and option.restrict_api:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if run_callbacks and option.onchange is not None:
|
||||
try:
|
||||
option.onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_default(self, key):
|
||||
"""returns the default value for the key"""
|
||||
|
||||
data_label = self.data_labels.get(key)
|
||||
if data_label is None:
|
||||
return None
|
||||
|
||||
return data_label.default
|
||||
|
||||
def save(self, filename):
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
json.dump(self.data, file, indent=4)
|
||||
|
||||
def same_type(self, x, y):
|
||||
if x is None or y is None:
|
||||
return True
|
||||
|
||||
type_x = self.typemap.get(type(x), type(x))
|
||||
type_y = self.typemap.get(type(y), type(y))
|
||||
|
||||
return type_x == type_y
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
# 1.6.0 VAE defaults
|
||||
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||
|
||||
# 1.1.1 quicksettings list migration
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
|
||||
# 1.4.0 ui_reorder
|
||||
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||
|
||||
bad_settings = 0
|
||||
for k, v in self.data.items():
|
||||
info = self.data_labels.get(k, None)
|
||||
if info is not None and not self.same_type(info.default, v):
|
||||
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
||||
bad_settings += 1
|
||||
|
||||
if bad_settings > 0:
|
||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||
|
||||
def onchange(self, key, func, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
if call:
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
|
||||
def reorder(self):
|
||||
"""reorder settings so that all items related to section always go together"""
|
||||
|
||||
section_ids = {}
|
||||
settings_items = self.data_labels.items()
|
||||
for _, item in settings_items:
|
||||
if item.section not in section_ids:
|
||||
section_ids[item.section] = len(section_ids)
|
||||
|
||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
default_value = self.data_labels[key].default
|
||||
if default_value is None:
|
||||
default_value = getattr(self, key, None)
|
||||
if default_value is None:
|
||||
return None
|
||||
|
||||
expected_type = type(default_value)
|
||||
if expected_type == bool and value == "False":
|
||||
value = False
|
||||
else:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
64
modules/patches.py
Normal file
64
modules/patches.py
Normal file
@ -0,0 +1,64 @@
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def patch(key, obj, field, replacement):
|
||||
"""Replaces a function in a module or a class.
|
||||
|
||||
Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
|
||||
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
|
||||
|
||||
Arguments:
|
||||
key: identifying information for who is doing the replacement. You can use __name__.
|
||||
obj: the module or the class
|
||||
field: name of the function as a string
|
||||
replacement: the new function
|
||||
|
||||
Returns:
|
||||
the original function
|
||||
"""
|
||||
|
||||
patch_key = (obj, field)
|
||||
if patch_key in originals[key]:
|
||||
raise RuntimeError(f"patch for {field} is already applied")
|
||||
|
||||
original_func = getattr(obj, field)
|
||||
originals[key][patch_key] = original_func
|
||||
|
||||
setattr(obj, field, replacement)
|
||||
|
||||
return original_func
|
||||
|
||||
|
||||
def undo(key, obj, field):
|
||||
"""Undoes the peplacement by the patch().
|
||||
|
||||
If the function is not replaced, raises an exception.
|
||||
|
||||
Arguments:
|
||||
key: identifying information for who is doing the replacement. You can use __name__.
|
||||
obj: the module or the class
|
||||
field: name of the function as a string
|
||||
|
||||
Returns:
|
||||
Always None
|
||||
"""
|
||||
|
||||
patch_key = (obj, field)
|
||||
|
||||
if patch_key not in originals[key]:
|
||||
raise RuntimeError(f"there is no patch for {field} to undo")
|
||||
|
||||
original_func = originals[key].pop(patch_key)
|
||||
setattr(obj, field, original_func)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def original(key, obj, field):
|
||||
"""Returns the original function for the patch created by the patch() function"""
|
||||
patch_key = (obj, field)
|
||||
|
||||
return originals[key].get(patch_key, None)
|
||||
|
||||
|
||||
originals = defaultdict(dict)
|
@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
|
||||
shared.state.begin(job="extras")
|
||||
|
||||
image_data = []
|
||||
image_names = []
|
||||
outputs = []
|
||||
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
fn = ''
|
||||
else:
|
||||
image = Image.open(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.orig_name)[0]
|
||||
image_data.append(image)
|
||||
image_names.append(fn)
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
def get_images(extras_mode, image, image_folder, input_dir):
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
fn = ''
|
||||
else:
|
||||
image = Image.open(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.orig_name)[0]
|
||||
yield image, fn
|
||||
elif extras_mode == 2:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||
assert input_dir, 'input directory not selected'
|
||||
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for filename in image_list:
|
||||
try:
|
||||
image = Image.open(filename)
|
||||
except Exception:
|
||||
continue
|
||||
image_data.append(image)
|
||||
image_names.append(filename)
|
||||
else:
|
||||
assert image, 'image not selected'
|
||||
|
||||
image_data.append(image)
|
||||
image_names.append(None)
|
||||
image_list = shared.listfiles(input_dir)
|
||||
for filename in image_list:
|
||||
try:
|
||||
image = Image.open(filename)
|
||||
except Exception:
|
||||
continue
|
||||
yield image, filename
|
||||
else:
|
||||
assert image, 'image not selected'
|
||||
yield image, None
|
||||
|
||||
if extras_mode == 2 and output_dir != '':
|
||||
outpath = output_dir
|
||||
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
|
||||
infotext = ''
|
||||
|
||||
for image, name in zip(image_data, image_names):
|
||||
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
||||
image_data: Image.Image
|
||||
|
||||
shared.state.textinfo = name
|
||||
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||
if parameters:
|
||||
existing_pnginfo["parameters"] = parameters
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
||||
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||
|
||||
scripts.scripts_postproc.run(pp, args)
|
||||
|
||||
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
if extras_mode != 2 or show_extras_results:
|
||||
outputs.append(pp.image)
|
||||
|
||||
image_data.close()
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||
|
File diff suppressed because it is too large
Load Diff
49
modules/processing_scripts/refiner.py
Normal file
49
modules/processing_scripts/refiner.py
Normal file
@ -0,0 +1,49 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_models
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
|
||||
class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||
section = "accordions"
|
||||
create_group = False
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def title(self):
|
||||
return "Refiner"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||
with gr.Row():
|
||||
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
|
||||
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||
|
||||
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||
|
||||
def lookup_checkpoint(title):
|
||||
info = sd_models.get_closet_checkpoint_match(title)
|
||||
return None if info is None else info.title
|
||||
|
||||
self.infotext_fields = [
|
||||
(enable_refiner, lambda d: 'Refiner' in d),
|
||||
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||
(refiner_switch_at, 'Refiner switch at'),
|
||||
]
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
|
||||
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||
|
||||
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||
p.refiner_checkpoint = None
|
||||
p.refiner_switch_at = None
|
||||
else:
|
||||
p.refiner_checkpoint = refiner_checkpoint
|
||||
p.refiner_switch_at = refiner_switch_at
|
111
modules/processing_scripts/seed.py
Normal file
111
modules/processing_scripts/seed.py
Normal file
@ -0,0 +1,111 @@
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui, errors
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui_components import ToolButton
|
||||
|
||||
|
||||
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
section = "seed"
|
||||
create_group = False
|
||||
|
||||
def __init__(self):
|
||||
self.seed = None
|
||||
self.reuse_seed = None
|
||||
self.reuse_subseed = None
|
||||
|
||||
def title(self):
|
||||
return "Seed"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with gr.Row(elem_id=self.elem_id("seed_row")):
|
||||
if cmd_opts.use_textbox_seed:
|
||||
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
|
||||
else:
|
||||
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||
|
||||
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
||||
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
||||
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||
|
||||
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
|
||||
with gr.Row(elem_id=self.elem_id("subseed_row")):
|
||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
|
||||
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
|
||||
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
|
||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
|
||||
|
||||
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
|
||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
|
||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
|
||||
|
||||
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||
|
||||
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||
|
||||
self.infotext_fields = [
|
||||
(self.seed, "Seed"),
|
||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
(subseed, "Variation seed"),
|
||||
(subseed_strength, "Variation seed strength"),
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
]
|
||||
|
||||
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
|
||||
|
||||
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
|
||||
|
||||
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
|
||||
p.seed = seed
|
||||
|
||||
if seed_checkbox and subseed_strength > 0:
|
||||
p.subseed = subseed
|
||||
p.subseed_strength = subseed_strength
|
||||
|
||||
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
|
||||
p.seed_resize_from_w = seed_resize_from_w
|
||||
p.seed_resize_from_h = seed_resize_from_h
|
||||
|
||||
|
||||
|
||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
||||
|
||||
def copy_seed(gen_info_string: str, index):
|
||||
res = -1
|
||||
|
||||
try:
|
||||
gen_info = json.loads(gen_info_string)
|
||||
index -= gen_info.get('index_of_first_image', 0)
|
||||
|
||||
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||
else:
|
||||
all_seeds = gen_info.get('all_seeds', [-1])
|
||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
if gen_info_string:
|
||||
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||
|
||||
return [res, gr.update()]
|
||||
|
||||
reuse_seed.click(
|
||||
fn=copy_seed,
|
||||
_js="(x, y) => [x, selected_gallery_index()]",
|
||||
show_progress=False,
|
||||
inputs=[generation_info, seed],
|
||||
outputs=[seed, seed]
|
||||
)
|
@ -48,6 +48,7 @@ def add_task_to_queue(id_job):
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
||||
live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
|
||||
|
||||
|
||||
class ProgressResponse(BaseModel):
|
||||
@ -71,7 +72,12 @@ def progressapi(req: ProgressRequest):
|
||||
completed = req.id_task in finished_tasks
|
||||
|
||||
if not active:
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
|
||||
textinfo = "Waiting..."
|
||||
if queued:
|
||||
sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
|
||||
queue_index = sorted_queued.index(req.id_task)
|
||||
textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
|
||||
|
||||
progress = 0
|
||||
|
||||
@ -89,31 +95,30 @@ def progressapi(req: ProgressRequest):
|
||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||
|
||||
live_preview = None
|
||||
id_live_preview = req.id_live_preview
|
||||
shared.state.set_current_image()
|
||||
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
|
||||
if opts.live_previews_image_format == "png":
|
||||
# using optimize for large images takes an enormous amount of time
|
||||
if max(*image.size) <= 256:
|
||||
save_kwargs = {"optimize": True}
|
||||
if opts.live_previews_enable and req.live_preview:
|
||||
shared.state.set_current_image()
|
||||
if shared.state.id_live_preview != req.id_live_preview:
|
||||
image = shared.state.current_image
|
||||
if image is not None:
|
||||
buffered = io.BytesIO()
|
||||
|
||||
if opts.live_previews_image_format == "png":
|
||||
# using optimize for large images takes an enormous amount of time
|
||||
if max(*image.size) <= 256:
|
||||
save_kwargs = {"optimize": True}
|
||||
else:
|
||||
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||
|
||||
else:
|
||||
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||
save_kwargs = {}
|
||||
|
||||
else:
|
||||
save_kwargs = {}
|
||||
|
||||
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
else:
|
||||
live_preview = None
|
||||
else:
|
||||
live_preview = None
|
||||
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||
id_live_preview = shared.state.id_live_preview
|
||||
|
||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||
|
||||
|
@ -19,14 +19,14 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
||||
!emphasized: "(" prompt ")"
|
||||
| "(" prompt ":" prompt ")"
|
||||
| "[" prompt "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||
alternate: "[" prompt ("|" prompt)+ "]"
|
||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||
WHITESPACE: /\s+/
|
||||
plain: /([^\\\[\]():|]|\\.)+/
|
||||
%import common.SIGNED_NUMBER -> NUMBER
|
||||
""")
|
||||
|
||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
|
||||
"""
|
||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||
>>> g("test")
|
||||
@ -53,18 +53,43 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||
>>> g("[a|(b:1.1)]")
|
||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||
>>> g("[fe|]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
>>> g("[fe|||]male")
|
||||
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
|
||||
>>> g("a [b:.5] c")
|
||||
[[10, 'a b c']]
|
||||
>>> g("a [b:1.5] c")
|
||||
[[5, 'a c'], [10, 'a b c']]
|
||||
"""
|
||||
|
||||
if hires_steps is None or use_old_scheduling:
|
||||
int_offset = 0
|
||||
flt_offset = 0
|
||||
steps = base_steps
|
||||
else:
|
||||
int_offset = base_steps
|
||||
flt_offset = 1.0
|
||||
steps = hires_steps
|
||||
|
||||
def collect_steps(steps, tree):
|
||||
res = [steps]
|
||||
|
||||
class CollectSteps(lark.Visitor):
|
||||
def scheduled(self, tree):
|
||||
tree.children[-1] = float(tree.children[-1])
|
||||
if tree.children[-1] < 1:
|
||||
tree.children[-1] *= steps
|
||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||
res.append(tree.children[-1])
|
||||
s = tree.children[-2]
|
||||
v = float(s)
|
||||
if use_old_scheduling:
|
||||
v = v*steps if v<1 else v
|
||||
else:
|
||||
if "." in s:
|
||||
v = (v - flt_offset) * steps
|
||||
else:
|
||||
v = (v - int_offset)
|
||||
tree.children[-2] = min(steps, int(v))
|
||||
if tree.children[-2] >= 1:
|
||||
res.append(tree.children[-2])
|
||||
|
||||
def alternate(self, tree):
|
||||
res.extend(range(1, steps+1))
|
||||
@ -75,13 +100,14 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
def at_step(step, tree):
|
||||
class AtStep(lark.Transformer):
|
||||
def scheduled(self, args):
|
||||
before, after, _, when = args
|
||||
before, after, _, when, _ = args
|
||||
yield before or () if step <= when else after
|
||||
def alternate(self, args):
|
||||
yield next(args[(step - 1)%len(args)])
|
||||
args = ["" if not arg else arg for arg in args]
|
||||
yield args[(step - 1) % len(args)]
|
||||
def start(self, args):
|
||||
def flatten(x):
|
||||
if type(x) == str:
|
||||
if isinstance(x, str):
|
||||
yield x
|
||||
else:
|
||||
for gen in x:
|
||||
@ -129,7 +155,7 @@ class SdConditioning(list):
|
||||
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
@ -149,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
||||
"""
|
||||
res = []
|
||||
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||
cache = {}
|
||||
|
||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||
@ -224,7 +250,7 @@ class MulticondLearnedConditioning:
|
||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
|
||||
@ -233,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
||||
|
||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||
|
||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
|
||||
|
||||
res = []
|
||||
for indexes in res_indexes:
|
||||
@ -333,7 +359,7 @@ re_attention = re.compile(r"""
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
:\s*([+-]?[.\d]+)\s*\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
|
@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||
|
170
modules/rng.py
Normal file
170
modules/rng.py
Normal file
@ -0,0 +1,170 @@
|
||||
import torch
|
||||
|
||||
from modules import devices, rng_philox, shared
|
||||
|
||||
|
||||
def randn(seed, shape, generator=None):
|
||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||
|
||||
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
||||
|
||||
manual_seed(seed)
|
||||
|
||||
if shared.opts.randn_source == "NV":
|
||||
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||
|
||||
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||
|
||||
return torch.randn(shape, device=devices.device, generator=generator)
|
||||
|
||||
|
||||
def randn_local(seed, shape):
|
||||
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||
|
||||
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
||||
|
||||
if shared.opts.randn_source == "NV":
|
||||
rng = rng_philox.Generator(seed)
|
||||
return torch.asarray(rng.randn(shape), device=devices.device)
|
||||
|
||||
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
||||
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
|
||||
|
||||
|
||||
def randn_like(x):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
if shared.opts.randn_source == "NV":
|
||||
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
||||
|
||||
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
def randn_without_seed(shape, generator=None):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
if shared.opts.randn_source == "NV":
|
||||
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||
|
||||
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||
|
||||
return torch.randn(shape, device=devices.device, generator=generator)
|
||||
|
||||
|
||||
def manual_seed(seed):
|
||||
"""Set up a global random number generator using the specified seed."""
|
||||
|
||||
if shared.opts.randn_source == "NV":
|
||||
global nv_rng
|
||||
nv_rng = rng_philox.Generator(seed)
|
||||
return
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def create_generator(seed):
|
||||
if shared.opts.randn_source == "NV":
|
||||
return rng_philox.Generator(seed)
|
||||
|
||||
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||
generator = torch.Generator(device).manual_seed(int(seed))
|
||||
return generator
|
||||
|
||||
|
||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||
def slerp(val, low, high):
|
||||
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
||||
dot = (low_norm*high_norm).sum(1)
|
||||
|
||||
if dot.mean() > 0.9995:
|
||||
return low * val + high * (1 - val)
|
||||
|
||||
omega = torch.acos(dot)
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
||||
return res
|
||||
|
||||
|
||||
class ImageRNG:
|
||||
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||
self.shape = tuple(map(int, shape))
|
||||
self.seeds = seeds
|
||||
self.subseeds = subseeds
|
||||
self.subseed_strength = subseed_strength
|
||||
self.seed_resize_from_h = seed_resize_from_h
|
||||
self.seed_resize_from_w = seed_resize_from_w
|
||||
|
||||
self.generators = [create_generator(seed) for seed in seeds]
|
||||
|
||||
self.is_first = True
|
||||
|
||||
def first(self):
|
||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
||||
|
||||
xs = []
|
||||
|
||||
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
|
||||
subnoise = None
|
||||
if self.subseeds is not None and self.subseed_strength != 0:
|
||||
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
|
||||
subnoise = randn(subseed, noise_shape)
|
||||
|
||||
if noise_shape != self.shape:
|
||||
noise = randn(seed, noise_shape)
|
||||
else:
|
||||
noise = randn(seed, self.shape, generator=generator)
|
||||
|
||||
if subnoise is not None:
|
||||
noise = slerp(self.subseed_strength, noise, subnoise)
|
||||
|
||||
if noise_shape != self.shape:
|
||||
x = randn(seed, self.shape, generator=generator)
|
||||
dx = (self.shape[2] - noise_shape[2]) // 2
|
||||
dy = (self.shape[1] - noise_shape[1]) // 2
|
||||
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
||||
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
||||
tx = 0 if dx < 0 else dx
|
||||
ty = 0 if dy < 0 else dy
|
||||
dx = max(-dx, 0)
|
||||
dy = max(-dy, 0)
|
||||
|
||||
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
|
||||
noise = x
|
||||
|
||||
xs.append(noise)
|
||||
|
||||
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
|
||||
if eta_noise_seed_delta:
|
||||
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
|
||||
|
||||
return torch.stack(xs).to(shared.device)
|
||||
|
||||
def next(self):
|
||||
if self.is_first:
|
||||
self.is_first = False
|
||||
return self.first()
|
||||
|
||||
xs = []
|
||||
for generator in self.generators:
|
||||
x = randn_without_seed(self.shape, generator=generator)
|
||||
xs.append(x)
|
||||
|
||||
return torch.stack(xs).to(shared.device)
|
||||
|
||||
|
||||
devices.randn = randn
|
||||
devices.randn_local = randn_local
|
||||
devices.randn_like = randn_like
|
||||
devices.randn_without_seed = randn_without_seed
|
||||
devices.manual_seed = manual_seed
|
102
modules/rng_philox.py
Normal file
102
modules/rng_philox.py
Normal file
@ -0,0 +1,102 @@
|
||||
"""RNG imitiating torch cuda randn on CPU. You are welcome.
|
||||
|
||||
Usage:
|
||||
|
||||
```
|
||||
g = Generator(seed=0)
|
||||
print(g.randn(shape=(3, 4)))
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
|
||||
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
|
||||
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
|
||||
```
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
philox_m = [0xD2511F53, 0xCD9E8D57]
|
||||
philox_w = [0x9E3779B9, 0xBB67AE85]
|
||||
|
||||
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
|
||||
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
|
||||
|
||||
|
||||
def uint32(x):
|
||||
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
|
||||
return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
|
||||
|
||||
|
||||
def philox4_round(counter, key):
|
||||
"""A single round of the Philox 4x32 random number generator."""
|
||||
|
||||
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
|
||||
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
|
||||
|
||||
counter[0] = v2[1] ^ counter[1] ^ key[0]
|
||||
counter[1] = v2[0]
|
||||
counter[2] = v1[1] ^ counter[3] ^ key[1]
|
||||
counter[3] = v1[0]
|
||||
|
||||
|
||||
def philox4_32(counter, key, rounds=10):
|
||||
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
|
||||
|
||||
Parameters:
|
||||
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
|
||||
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
|
||||
rounds (int): The number of rounds to perform.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
|
||||
"""
|
||||
|
||||
for _ in range(rounds - 1):
|
||||
philox4_round(counter, key)
|
||||
|
||||
key[0] = key[0] + philox_w[0]
|
||||
key[1] = key[1] + philox_w[1]
|
||||
|
||||
philox4_round(counter, key)
|
||||
return counter
|
||||
|
||||
|
||||
def box_muller(x, y):
|
||||
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
|
||||
u = x * two_pow32_inv + two_pow32_inv / 2
|
||||
v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
|
||||
|
||||
s = np.sqrt(-2.0 * np.log(u))
|
||||
|
||||
r1 = s * np.sin(v)
|
||||
return r1.astype(np.float32)
|
||||
|
||||
|
||||
class Generator:
|
||||
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
|
||||
|
||||
def __init__(self, seed):
|
||||
self.seed = seed
|
||||
self.offset = 0
|
||||
|
||||
def randn(self, shape):
|
||||
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
|
||||
|
||||
n = 1
|
||||
for x in shape:
|
||||
n *= x
|
||||
|
||||
counter = np.zeros((4, n), dtype=np.uint32)
|
||||
counter[0] = self.offset
|
||||
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
|
||||
self.offset += 1
|
||||
|
||||
key = np.empty(n, dtype=np.uint64)
|
||||
key.fill(self.seed)
|
||||
key = uint32(key)
|
||||
|
||||
g = philox4_32(counter, key)
|
||||
|
||||
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
|
@ -28,6 +28,15 @@ class ImageSaveParams:
|
||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||
|
||||
|
||||
class ExtraNoiseParams:
|
||||
def __init__(self, noise, x):
|
||||
self.noise = noise
|
||||
"""Random noise generated by the seed"""
|
||||
|
||||
self.x = x
|
||||
"""Latent image representation of the image"""
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||
self.x = x
|
||||
@ -100,6 +109,7 @@ callback_map = dict(
|
||||
callbacks_ui_settings=[],
|
||||
callbacks_before_image_saved=[],
|
||||
callbacks_image_saved=[],
|
||||
callbacks_extra_noise=[],
|
||||
callbacks_cfg_denoiser=[],
|
||||
callbacks_cfg_denoised=[],
|
||||
callbacks_cfg_after_cfg=[],
|
||||
@ -189,6 +199,14 @@ def image_saved_callback(params: ImageSaveParams):
|
||||
report_exception(c, 'image_saved_callback')
|
||||
|
||||
|
||||
def extra_noise_callback(params: ExtraNoiseParams):
|
||||
for c in callback_map['callbacks_extra_noise']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'callbacks_extra_noise')
|
||||
|
||||
|
||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||
for c in callback_map['callbacks_cfg_denoiser']:
|
||||
try:
|
||||
@ -367,6 +385,14 @@ def on_image_saved(callback):
|
||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||
|
||||
|
||||
def on_extra_noise(callback):
|
||||
"""register a function to be called before adding extra noise in img2img or hires fix;
|
||||
The callback is called with one argument:
|
||||
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
||||
"""
|
||||
add_callback(callback_map['callbacks_extra_noise'], callback)
|
||||
|
||||
|
||||
def on_cfg_denoiser(callback):
|
||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||
The callback is called with one argument:
|
||||
|
@ -3,6 +3,7 @@ import re
|
||||
import sys
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
|
||||
self.images = images
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnComponent:
|
||||
component: gr.blocks.Block
|
||||
|
||||
|
||||
class Script:
|
||||
name = None
|
||||
"""script's internal name derived from title"""
|
||||
@ -35,9 +41,13 @@ class Script:
|
||||
|
||||
is_txt2img = False
|
||||
is_img2img = False
|
||||
tabname = None
|
||||
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
"""A gr.Group component that has all script's UI inside it."""
|
||||
|
||||
create_group = True
|
||||
"""If False, for alwayson scripts, a group component will not be created."""
|
||||
|
||||
infotext_fields = None
|
||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||
@ -52,6 +62,15 @@ class Script:
|
||||
api_info = None
|
||||
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||
|
||||
on_before_component_elem_id = None
|
||||
"""list of callbacks to be called before a component with an elem_id is created"""
|
||||
|
||||
on_after_component_elem_id = None
|
||||
"""list of callbacks to be called after a component with an elem_id is created"""
|
||||
|
||||
setup_for_ui_only = False
|
||||
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
|
||||
@ -90,9 +109,16 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def setup(self, p, *args):
|
||||
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||
args contains all values returned by components from ui().
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def before_process(self, p, *args):
|
||||
"""
|
||||
This function is called very early before processing begins for AlwaysVisible scripts.
|
||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||
You can modify the processing object (p) here, inject hooks, etc.
|
||||
args contains all values returned by components from ui()
|
||||
"""
|
||||
@ -212,6 +238,29 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def on_before_component(self, callback, *, elem_id):
|
||||
"""
|
||||
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
|
||||
|
||||
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
|
||||
|
||||
This function is an alternative to before_component in that it also cllows to run before a component is created, but
|
||||
it doesn't require to be called for every created component - just for the one you need.
|
||||
"""
|
||||
if self.on_before_component_elem_id is None:
|
||||
self.on_before_component_elem_id = []
|
||||
|
||||
self.on_before_component_elem_id.append((elem_id, callback))
|
||||
|
||||
def on_after_component(self, callback, *, elem_id):
|
||||
"""
|
||||
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
|
||||
"""
|
||||
if self.on_after_component_elem_id is None:
|
||||
self.on_after_component_elem_id = []
|
||||
|
||||
self.on_after_component_elem_id.append((elem_id, callback))
|
||||
|
||||
def describe(self):
|
||||
"""unused"""
|
||||
return ""
|
||||
@ -220,7 +269,7 @@ class Script:
|
||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||
|
||||
need_tabname = self.show(True) == self.show(False)
|
||||
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
||||
tabkind = 'img2img' if self.is_img2img else 'txt2img'
|
||||
tabname = f"{tabkind}_" if need_tabname else ""
|
||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||
|
||||
@ -232,6 +281,19 @@ class Script:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ScriptBuiltinUI(Script):
|
||||
setup_for_ui_only = True
|
||||
|
||||
def elem_id(self, item_id):
|
||||
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
|
||||
|
||||
need_tabname = self.show(True) == self.show(False)
|
||||
tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
|
||||
|
||||
return f'{tabname}{item_id}'
|
||||
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
@ -250,7 +312,7 @@ postprocessing_scripts_data = []
|
||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||
scripts_list = []
|
||||
|
||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||
@ -258,8 +320,9 @@ def list_scripts(scriptdirname, extension):
|
||||
for filename in sorted(os.listdir(basedir)):
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
if include_extensions:
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
@ -288,7 +351,7 @@ def load_scripts():
|
||||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py")
|
||||
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
@ -349,10 +412,17 @@ class ScriptRunner:
|
||||
self.selectable_scripts = []
|
||||
self.alwayson_scripts = []
|
||||
self.titles = []
|
||||
self.title_map = {}
|
||||
self.infotext_fields = []
|
||||
self.paste_field_names = []
|
||||
self.inputs = [None]
|
||||
|
||||
self.on_before_component_elem_id = {}
|
||||
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
||||
|
||||
self.on_after_component_elem_id = {}
|
||||
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
|
||||
|
||||
def initialize_scripts(self, is_img2img):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
@ -367,6 +437,7 @@ class ScriptRunner:
|
||||
script.filename = script_data.path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
script.tabname = "img2img" if is_img2img else "txt2img"
|
||||
|
||||
visibility = script.show(script.is_img2img)
|
||||
|
||||
@ -379,6 +450,28 @@ class ScriptRunner:
|
||||
self.scripts.append(script)
|
||||
self.selectable_scripts.append(script)
|
||||
|
||||
self.apply_on_before_component_callbacks()
|
||||
|
||||
def apply_on_before_component_callbacks(self):
|
||||
for script in self.scripts:
|
||||
on_before = script.on_before_component_elem_id or []
|
||||
on_after = script.on_after_component_elem_id or []
|
||||
|
||||
for elem_id, callback in on_before:
|
||||
if elem_id not in self.on_before_component_elem_id:
|
||||
self.on_before_component_elem_id[elem_id] = []
|
||||
|
||||
self.on_before_component_elem_id[elem_id].append((callback, script))
|
||||
|
||||
for elem_id, callback in on_after:
|
||||
if elem_id not in self.on_after_component_elem_id:
|
||||
self.on_after_component_elem_id[elem_id] = []
|
||||
|
||||
self.on_after_component_elem_id[elem_id].append((callback, script))
|
||||
|
||||
on_before.clear()
|
||||
on_after.clear()
|
||||
|
||||
def create_script_ui(self, script):
|
||||
import modules.api.models as api_models
|
||||
|
||||
@ -429,15 +522,20 @@ class ScriptRunner:
|
||||
if script.alwayson and script.section != section:
|
||||
continue
|
||||
|
||||
with gr.Group(visible=script.alwayson) as group:
|
||||
self.create_script_ui(script)
|
||||
if script.create_group:
|
||||
with gr.Group(visible=script.alwayson) as group:
|
||||
self.create_script_ui(script)
|
||||
|
||||
script.group = group
|
||||
script.group = group
|
||||
else:
|
||||
self.create_script_ui(script)
|
||||
|
||||
def prepare_ui(self):
|
||||
self.inputs = [None]
|
||||
|
||||
def setup_ui(self):
|
||||
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
|
||||
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
|
||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||
|
||||
self.setup_ui_for_section(None)
|
||||
@ -484,6 +582,8 @@ class ScriptRunner:
|
||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
||||
|
||||
self.apply_on_before_component_callbacks()
|
||||
|
||||
return self.inputs
|
||||
|
||||
def run(self, p, *args):
|
||||
@ -577,6 +677,12 @@ class ScriptRunner:
|
||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||
try:
|
||||
callback(OnComponent(component=component))
|
||||
except Exception:
|
||||
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
||||
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.before_component(component, **kwargs)
|
||||
@ -584,12 +690,21 @@ class ScriptRunner:
|
||||
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
||||
|
||||
def after_component(self, component, **kwargs):
|
||||
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
|
||||
try:
|
||||
callback(OnComponent(component=component))
|
||||
except Exception:
|
||||
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
||||
|
||||
for script in self.scripts:
|
||||
try:
|
||||
script.after_component(component, **kwargs)
|
||||
except Exception:
|
||||
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
||||
|
||||
def script(self, title):
|
||||
return self.title_map.get(title.lower())
|
||||
|
||||
def reload_sources(self, cache):
|
||||
for si, script in list(enumerate(self.scripts)):
|
||||
args_from = script.args_from
|
||||
@ -608,7 +723,6 @@ class ScriptRunner:
|
||||
self.scripts[si].args_from = args_from
|
||||
self.scripts[si].args_to = args_to
|
||||
|
||||
|
||||
def before_hr(self, p):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
@ -617,6 +731,17 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||
|
||||
def setup_scrips(self, p, *, is_ui=True):
|
||||
for script in self.alwayson_scripts:
|
||||
if not is_ui and script.setup_for_ui_only:
|
||||
continue
|
||||
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.setup(p, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
@ -631,49 +756,3 @@ def reload_script_body_only():
|
||||
|
||||
|
||||
reload_scripts = load_scripts # compatibility alias
|
||||
|
||||
|
||||
def add_classes_to_gradio_component(comp):
|
||||
"""
|
||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||
"""
|
||||
|
||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||
|
||||
if getattr(comp, 'multiselect', False):
|
||||
comp.elem_classes.append('multiselect')
|
||||
|
||||
|
||||
|
||||
def IOComponent_init(self, *args, **kwargs):
|
||||
if scripts_current is not None:
|
||||
scripts_current.before_component(self, **kwargs)
|
||||
|
||||
script_callbacks.before_component_callback(self, **kwargs)
|
||||
|
||||
res = original_IOComponent_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
script_callbacks.after_component_callback(self, **kwargs)
|
||||
|
||||
if scripts_current is not None:
|
||||
scripts_current.after_component(self, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
||||
gr.components.IOComponent.__init__ = IOComponent_init
|
||||
|
||||
|
||||
def BlockContext_init(self, *args, **kwargs):
|
||||
res = original_BlockContext_init(self, *args, **kwargs)
|
||||
|
||||
add_classes_to_gradio_component(self)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
||||
|
@ -3,8 +3,31 @@ import open_clip
|
||||
import torch
|
||||
import transformers.utils.hub
|
||||
|
||||
from modules import shared
|
||||
|
||||
class DisableInitialization:
|
||||
|
||||
class ReplaceHelper:
|
||||
def __init__(self):
|
||||
self.replaced = []
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
original = getattr(obj, field, None)
|
||||
if original is None:
|
||||
return None
|
||||
|
||||
self.replaced.append((obj, field, original))
|
||||
setattr(obj, field, func)
|
||||
|
||||
return original
|
||||
|
||||
def restore(self):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
|
||||
class DisableInitialization(ReplaceHelper):
|
||||
"""
|
||||
When an object of this class enters a `with` block, it starts:
|
||||
- preventing torch's layer initialization functions from working
|
||||
@ -21,7 +44,7 @@ class DisableInitialization:
|
||||
"""
|
||||
|
||||
def __init__(self, disable_clip=True):
|
||||
self.replaced = []
|
||||
super().__init__()
|
||||
self.disable_clip = disable_clip
|
||||
|
||||
def replace(self, obj, field, func):
|
||||
@ -86,8 +109,124 @@ class DisableInitialization:
|
||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for obj, field, original in self.replaced:
|
||||
setattr(obj, field, original)
|
||||
self.restore()
|
||||
|
||||
self.replaced.clear()
|
||||
|
||||
class InitializeOnMeta(ReplaceHelper):
|
||||
"""
|
||||
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||
which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||
|
||||
Usage:
|
||||
```
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
```
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
return
|
||||
|
||||
def set_device(x):
|
||||
x["device"] = "meta"
|
||||
return x
|
||||
|
||||
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.restore()
|
||||
|
||||
|
||||
class LoadStateDictOnMeta(ReplaceHelper):
|
||||
"""
|
||||
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||
Meant to be used together with InitializeOnMeta above.
|
||||
|
||||
Usage:
|
||||
```
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||
super().__init__()
|
||||
self.state_dict = state_dict
|
||||
self.device = device
|
||||
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||
self.default_dtype = self.weight_dtype_conversion.get('')
|
||||
|
||||
def get_weight_dtype(self, key):
|
||||
key_first_term, _ = key.split('.', 1)
|
||||
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||
|
||||
def __enter__(self):
|
||||
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||
return
|
||||
|
||||
sd = self.state_dict
|
||||
device = self.device
|
||||
|
||||
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||
used_param_keys = []
|
||||
|
||||
for name, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
|
||||
key = prefix + name
|
||||
sd_param = sd.pop(key, None)
|
||||
if sd_param is not None:
|
||||
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||
used_param_keys.append(key)
|
||||
|
||||
if param.is_meta:
|
||||
dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||
|
||||
for name in module._buffers:
|
||||
key = prefix + name
|
||||
|
||||
sd_param = sd.pop(key, None)
|
||||
if sd_param is not None:
|
||||
state_dict[key] = sd_param
|
||||
used_param_keys.append(key)
|
||||
|
||||
original(module, state_dict, prefix, *args, **kwargs)
|
||||
|
||||
for key in used_param_keys:
|
||||
state_dict.pop(key, None)
|
||||
|
||||
def load_state_dict(original, module, state_dict, strict=True):
|
||||
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||
|
||||
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||
|
||||
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||
the function and does not call the original) the state dict will just fail to load because weights
|
||||
would be on the meta device.
|
||||
"""
|
||||
|
||||
if state_dict == sd:
|
||||
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||
|
||||
original(module, state_dict, strict=strict)
|
||||
|
||||
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.restore()
|
||||
|
@ -2,7 +2,6 @@ import torch
|
||||
from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
@ -30,8 +29,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
ldm.modules.attention.print = shared.ldm_print
|
||||
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
ldm.util.print = shared.ldm_print
|
||||
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
@ -164,12 +165,13 @@ class StableDiffusionModelHijack:
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
|
||||
def __init__(self):
|
||||
import modules.textual_inversion.textual_inversion
|
||||
|
||||
self.extra_generation_params = {}
|
||||
self.comments = []
|
||||
|
||||
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def apply_optimizations(self, option=None):
|
||||
@ -197,7 +199,7 @@ class StableDiffusionModelHijack:
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
@ -243,7 +245,21 @@ class StableDiffusionModelHijack:
|
||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
for i in range(len(conditioner.embedders)):
|
||||
embedder = conditioner.embedders[i]
|
||||
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
|
||||
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
|
||||
conditioner.embedders[i] = embedder.wrapped
|
||||
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
|
||||
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
|
||||
conditioner.embedders[i] = embedder.wrapped
|
||||
|
||||
if hasattr(m, 'cond_stage_model'):
|
||||
delattr(m, 'cond_stage_model')
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
@ -292,10 +308,11 @@ class StableDiffusionModelHijack:
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
@ -309,7 +326,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
emb = devices.cond_cast_unet(vec)
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
position += 1
|
||||
continue
|
||||
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
emb_len = int(embedding.vectors)
|
||||
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||
next_chunk()
|
||||
|
||||
@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
hashes.append(f"{name}: {shorthash}")
|
||||
|
||||
if hashes:
|
||||
if self.hijack.extra_generation_params.get("TI hashes"):
|
||||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
|
@ -1,97 +0,0 @@
|
||||
import torch
|
||||
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from ldm.models.diffusion.ddim import noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = {}
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import math
|
||||
import psutil
|
||||
import platform
|
||||
|
||||
import torch
|
||||
from torch import einsum
|
||||
@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
class SdOptimizationSubQuad(SdOptimization):
|
||||
name = "sub-quadratic"
|
||||
cmd_opt = "opt_sub_quad_attention"
|
||||
priority = 10
|
||||
|
||||
@property
|
||||
def priority(self):
|
||||
return 1000 if shared.device.type == 'mps' else 10
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
||||
|
||||
@property
|
||||
def priority(self):
|
||||
return 1000 if not torch.cuda.is_available() else 10
|
||||
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
if chunk_threshold is None:
|
||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
||||
if q.device.type == 'mps':
|
||||
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||
else:
|
||||
chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||
elif chunk_threshold == 0:
|
||||
chunk_threshold_bytes = None
|
||||
else:
|
||||
|
@ -14,8 +14,7 @@ import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
|
||||
@ -33,6 +32,8 @@ class CheckpointInfo:
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
@ -43,6 +44,19 @@ class CheckpointInfo:
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
def read_metadata():
|
||||
metadata = read_metadata_from_safetensors(filename)
|
||||
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||
|
||||
return metadata
|
||||
|
||||
self.metadata = {}
|
||||
if self.is_safetensors:
|
||||
try:
|
||||
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading metadata for {filename}")
|
||||
|
||||
self.name = name
|
||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||
@ -52,17 +66,11 @@ class CheckpointInfo:
|
||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||
|
||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||
|
||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||
|
||||
self.metadata = {}
|
||||
|
||||
_, ext = os.path.splitext(self.filename)
|
||||
if ext.lower() == ".safetensors":
|
||||
try:
|
||||
self.metadata = read_metadata_from_safetensors(filename)
|
||||
except Exception as e:
|
||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
|
||||
if self.shorthash:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||
|
||||
def register(self):
|
||||
checkpoints_list[self.title] = self
|
||||
@ -74,13 +82,18 @@ class CheckpointInfo:
|
||||
if self.sha256 is None:
|
||||
return
|
||||
|
||||
self.shorthash = self.sha256[0:10]
|
||||
shorthash = self.sha256[0:10]
|
||||
if self.shorthash == self.sha256[0:10]:
|
||||
return self.shorthash
|
||||
|
||||
self.shorthash = shorthash
|
||||
|
||||
if self.shorthash not in self.ids:
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||
|
||||
checkpoints_list.pop(self.title)
|
||||
checkpoints_list.pop(self.title, None)
|
||||
self.title = f'{self.name} [{self.shorthash}]'
|
||||
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
||||
self.register()
|
||||
|
||||
return self.shorthash
|
||||
@ -101,14 +114,8 @@ def setup_model():
|
||||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
def convert(name):
|
||||
return int(name) if name.isdigit() else name.lower()
|
||||
|
||||
def alphanumeric_key(key):
|
||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
||||
|
||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
def checkpoint_tiles(use_short=False):
|
||||
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
|
||||
|
||||
|
||||
def list_models():
|
||||
@ -131,12 +138,18 @@ def list_models():
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
for filename in sorted(model_list, key=str.lower):
|
||||
for filename in model_list:
|
||||
checkpoint_info = CheckpointInfo(filename)
|
||||
checkpoint_info.register()
|
||||
|
||||
|
||||
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
||||
|
||||
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
if not search_string:
|
||||
return None
|
||||
|
||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
@ -145,6 +158,11 @@ def get_closet_checkpoint_match(search_string):
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
||||
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
||||
if found:
|
||||
return found[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -280,11 +298,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
return res
|
||||
|
||||
|
||||
class SkipWritingToConfig:
|
||||
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||
|
||||
skip = False
|
||||
previous = None
|
||||
|
||||
def __enter__(self):
|
||||
self.previous = SkipWritingToConfig.skip
|
||||
SkipWritingToConfig.skip = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
@ -297,18 +331,23 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
checkpoints_loaded[checkpoint_info] = state_dict
|
||||
|
||||
del state_dict
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
if shared.cmd_opts.no_half:
|
||||
model.float()
|
||||
devices.dtype_unet = torch.float32
|
||||
timer.record("apply float()")
|
||||
else:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
@ -324,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
@ -346,7 +385,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
@ -423,6 +462,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
class SdModelData:
|
||||
def __init__(self):
|
||||
self.sd_model = None
|
||||
self.loaded_sd_models = []
|
||||
self.was_loaded_at_least_once = False
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@ -437,6 +477,7 @@ class SdModelData:
|
||||
|
||||
try:
|
||||
load_model()
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||
print("", file=sys.stderr)
|
||||
@ -445,14 +486,30 @@ class SdModelData:
|
||||
|
||||
return self.sd_model
|
||||
|
||||
def set_sd_model(self, v):
|
||||
def set_sd_model(self, v, already_loaded=False):
|
||||
self.sd_model = v
|
||||
if already_loaded:
|
||||
sd_vae.base_vae = getattr(v, "base_vae", None)
|
||||
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
||||
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
||||
|
||||
try:
|
||||
self.loaded_sd_models.remove(v)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if v is not None:
|
||||
self.loaded_sd_models.insert(0, v)
|
||||
|
||||
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img()
|
||||
extra_networks.activate(p, {})
|
||||
|
||||
if hasattr(sd_model, 'conditioner'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
return d['crossattn']
|
||||
@ -460,20 +517,46 @@ def get_empty_cond(sd_model):
|
||||
return sd_model.cond_stage_model([""])
|
||||
|
||||
|
||||
def send_model_to_cpu(m):
|
||||
if m.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
m.to(devices.cpu)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def model_target_device(m):
|
||||
if lowvram.is_needed(m):
|
||||
return devices.cpu
|
||||
else:
|
||||
return devices.device
|
||||
|
||||
|
||||
def send_model_to_device(m):
|
||||
lowvram.apply(m)
|
||||
|
||||
if not m.lowvram:
|
||||
m.to(shared.device)
|
||||
|
||||
|
||||
def send_model_to_trash(m):
|
||||
m.to(device="meta")
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
from modules import sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
send_model_to_trash(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
timer = Timer()
|
||||
timer.record("unload existing model")
|
||||
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
@ -495,25 +578,35 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
except Exception:
|
||||
pass
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
except Exception as e:
|
||||
errors.display(e, "creating model quickly", full_traceback=True)
|
||||
|
||||
if sd_model is None:
|
||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
with sd_disable_initialization.InitializeOnMeta():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
timer.record("create model")
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
if shared.cmd_opts.no_half:
|
||||
weight_dtype_conversion = None
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
weight_dtype_conversion = {
|
||||
'first_stage_model': None,
|
||||
'': torch.float16,
|
||||
}
|
||||
|
||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
@ -521,7 +614,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
model_data.sd_model = sd_model
|
||||
model_data.set_sd_model(sd_model)
|
||||
model_data.was_loaded_at_least_once = True
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
@ -542,10 +635,70 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
return sd_model
|
||||
|
||||
|
||||
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
"""
|
||||
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||
If no such model exists, returns None.
|
||||
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||
"""
|
||||
|
||||
already_loaded = None
|
||||
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||
loaded_model = model_data.loaded_sd_models[i]
|
||||
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
already_loaded = loaded_model
|
||||
continue
|
||||
|
||||
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||
model_data.loaded_sd_models.pop()
|
||||
send_model_to_trash(loaded_model)
|
||||
timer.record("send model to trash")
|
||||
|
||||
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||
send_model_to_cpu(sd_model)
|
||||
timer.record("send model to cpu")
|
||||
|
||||
if already_loaded is not None:
|
||||
send_model_to_device(already_loaded)
|
||||
timer.record("send model to device")
|
||||
|
||||
model_data.set_sd_model(already_loaded, already_loaded=True)
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
|
||||
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
|
||||
|
||||
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||
sd_vae.reload_vae_weights(already_loaded)
|
||||
return model_data.sd_model
|
||||
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||
|
||||
model_data.sd_model = None
|
||||
load_model(checkpoint_info)
|
||||
return model_data.sd_model
|
||||
elif len(model_data.loaded_sd_models) > 0:
|
||||
sd_model = model_data.loaded_sd_models.pop()
|
||||
model_data.sd_model = sd_model
|
||||
|
||||
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
|
||||
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
|
||||
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
|
||||
|
||||
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||
return sd_model
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
if not sd_model:
|
||||
sd_model = model_data.sd_model
|
||||
|
||||
@ -554,19 +707,17 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
sd_unet.apply_unet("None")
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
send_model_to_cpu(sd_model)
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
@ -574,7 +725,9 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
if sd_model is not None:
|
||||
send_model_to_trash(sd_model)
|
||||
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||
return model_data.sd_model
|
||||
|
||||
@ -591,17 +744,19 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
sd_unet.apply_unet()
|
||||
|
||||
return sd_model
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
from modules import devices, sd_hijack
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared, paths, sd_disable_initialization
|
||||
from modules import shared, paths, sd_disable_initialization, devices
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
from modules import devices
|
||||
|
||||
device = devices.cpu
|
||||
|
||||
|
31
modules/sd_models_types.py
Normal file
31
modules/sd_models_types.py
Normal file
@ -0,0 +1,31 @@
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from modules.sd_models import CheckpointInfo
|
||||
|
||||
|
||||
class WebuiSdModel(LatentDiffusion):
|
||||
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||
|
||||
lowvram: bool
|
||||
"""True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
|
||||
|
||||
sd_model_hash: str
|
||||
"""short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
|
||||
|
||||
sd_model_checkpoint: str
|
||||
"""path to the file on disk that model weights were obtained from"""
|
||||
|
||||
sd_checkpoint_info: 'CheckpointInfo'
|
||||
"""structure with additional information about the file with model's weights"""
|
||||
|
||||
is_sdxl: bool
|
||||
"""True if the model's architecture is SDXL"""
|
||||
|
||||
is_sd2: bool
|
||||
"""True if the model's architecture is SD 2.x"""
|
||||
|
||||
is_sd1: bool
|
||||
"""True if the model's architecture is SD 1.x"""
|
@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
|
||||
return torch.cat(res, dim=1)
|
||||
|
||||
|
||||
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||
return embedder.tokenize(texts)
|
||||
|
||||
raise AssertionError('no tokenizer available')
|
||||
|
||||
|
||||
|
||||
def process_texts(self, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
return embedder.process_texts(texts)
|
||||
@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
|
||||
|
||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
@ -89,10 +98,10 @@ def extend_sdxl(model):
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
||||
sgm.modules.encoders.modules.print = lambda *args: None
|
||||
sgm.modules.attention.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
|
@ -1,17 +1,18 @@
|
||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||
|
||||
all_samplers = [
|
||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||
*sd_samplers_compvis.samplers_data_compvis,
|
||||
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
samplers_hidden = {}
|
||||
|
||||
|
||||
def find_sampler_config(name):
|
||||
@ -38,13 +39,11 @@ def create_sampler(name, model):
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
global samplers, samplers_for_img2img, samplers_hidden
|
||||
|
||||
hidden = set(shared.opts.hide_samplers)
|
||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
samplers_hidden = set(shared.opts.hide_samplers)
|
||||
samplers = all_samplers
|
||||
samplers_for_img2img = all_samplers
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
@ -53,4 +52,8 @@ def set_samplers():
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
|
||||
def visible_sampler_names():
|
||||
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
230
modules/sd_samplers_cfg_denoiser.py
Normal file
230
modules/sd_samplers_cfg_denoiser.py
Normal file
@ -0,0 +1,230 @@
|
||||
import torch
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler):
|
||||
super().__init__()
|
||||
self.model_wrap = None
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.steps = None
|
||||
"""number of steps as specified by user in UI"""
|
||||
|
||||
self.total_steps = None
|
||||
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.sampler = sampler
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
self.mask_before_denoising = False
|
||||
|
||||
@property
|
||||
def inner_model(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def update_inner_model(self):
|
||||
self.model_wrap = None
|
||||
|
||||
c, uc = self.p.get_conds()
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if sd_samplers_common.apply_refiner(self):
|
||||
cond = self.sampler.sampler_extra_args['cond']
|
||||
uncond = self.sampler.sampler_extra_args['uncond']
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
if self.mask_before_denoising and self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.opts.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if not self.mask_before_denoising and self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
preview = self.sampler.last_latent
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||
else:
|
||||
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
@ -1,13 +1,22 @@
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
||||
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
import k_diffusion.sampling
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
||||
class SamplerData(SamplerDataTuple):
|
||||
def total_steps(self, steps):
|
||||
if self.options.get("second_order", False):
|
||||
steps = steps * 2
|
||||
|
||||
return steps
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
@ -25,19 +34,34 @@ def setup_img2img_steps(p, steps=None):
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
"""Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
|
||||
|
||||
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
|
||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||
|
||||
from modules import lowvram
|
||||
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
|
||||
approximation = 1
|
||||
|
||||
if approximation == 2:
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
|
||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||
elif approximation == 1:
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
|
||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
elif approximation == 3:
|
||||
x_sample = sample * 1.5
|
||||
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
else:
|
||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||
|
||||
return x_sample
|
||||
|
||||
|
||||
def single_sample_to_image(sample, approximation=None):
|
||||
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||
|
||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||
@ -46,6 +70,12 @@ def single_sample_to_image(sample, approximation=None):
|
||||
return Image.fromarray(x_sample)
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
x = x.to(devices.dtype_vae)
|
||||
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
|
||||
return samples_to_images_tensor(x, approx_index, model)
|
||||
|
||||
|
||||
def sample_to_image(samples, index=0, approximation=None):
|
||||
return single_sample_to_image(samples[index], approximation)
|
||||
|
||||
@ -54,6 +84,32 @@ def samples_to_image_grid(samples, approximation=None):
|
||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def images_tensor_to_samples(image, approximation=None, model=None):
|
||||
'''image[0, 1] -> latent'''
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
|
||||
|
||||
if approximation == 3:
|
||||
image = image.to(devices.device, devices.dtype)
|
||||
x_latent = sd_vae_taesd.encoder_model()(image)
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image * 2 - 1
|
||||
if len(image) > 1:
|
||||
x_latent = torch.stack([
|
||||
model.get_first_stage_encoding(
|
||||
model.encode_first_stage(torch.unsqueeze(img, 0))
|
||||
)[0]
|
||||
for img in image
|
||||
])
|
||||
else:
|
||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
|
||||
return x_latent
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
state.current_latent = decoded
|
||||
|
||||
@ -85,11 +141,186 @@ class InterruptedException(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
if opts.randn_source == "CPU":
|
||||
def replace_torchsde_browinan():
|
||||
import torchsde._brownian.brownian_interval
|
||||
|
||||
def torchsde_randn(size, dtype, device, seed):
|
||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||
return devices.randn_local(seed, size).to(device=device, dtype=dtype)
|
||||
|
||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||
|
||||
|
||||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(cfg_denoiser):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||
return False
|
||||
|
||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
return False
|
||||
|
||||
if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
|
||||
return False
|
||||
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
"""This is here to replace torch.randn_like of k-diffusion.
|
||||
|
||||
k-diffusion has random_sampler argument for most samplers, but not for all, so
|
||||
this is needed to properly replace every use of torch.randn_like.
|
||||
|
||||
We need to replace to make images generated in batches to be same as images generated individually."""
|
||||
|
||||
def __init__(self, p):
|
||||
self.rng = p.rng
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
return self.rng.next()
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, funcname):
|
||||
self.funcname = funcname
|
||||
self.func = funcname
|
||||
self.extra_params = []
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config: SamplerData = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.eta_option_field = 'eta_ancestral'
|
||||
self.eta_infotext_field = 'Eta'
|
||||
self.eta_default = 1.0
|
||||
|
||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||
|
||||
self.p = None
|
||||
self.model_wrap_cfg = None
|
||||
self.sampler_extra_args = None
|
||||
self.options = {}
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.model_wrap_cfg.steps = steps
|
||||
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p) -> dict:
|
||||
self.p = p
|
||||
self.model_wrap_cfg.p = p
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(p)
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != self.eta_default:
|
||||
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
raise NotImplementedError()
|
||||
|
@ -1,224 +0,0 @@
|
||||
import math
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules.shared import state
|
||||
from modules import sd_samplers_common, prompt_parser, shared
|
||||
import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
||||
]
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||
self.orig_p_sample_ddim = None
|
||||
if self.is_plms:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
||||
elif self.is_ddim:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None
|
||||
self.last_latent = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||
|
||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||
|
||||
return res
|
||||
|
||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||
image_conditioning = None
|
||||
uc_image_conditioning = None
|
||||
if isinstance(cond, dict):
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
image_conditioning = cond["c_adm"]
|
||||
uc_image_conditioning = unconditional_conditioning["c_adm"]
|
||||
else:
|
||||
image_conditioning = cond["c_concat"][0]
|
||||
cond = cond["c_crossattn"][0]
|
||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||
cond = tensor
|
||||
|
||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
||||
# not 100% correct but should work well enough
|
||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
||||
last_vector = unconditional_conditioning[:, -1:]
|
||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x = img_orig * self.mask + self.nmask * x
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
return x, ts, cond, unconditional_conditioning
|
||||
|
||||
def update_step(self, last_latent):
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
||||
else:
|
||||
self.last_latent = last_latent
|
||||
|
||||
sd_samplers_common.store_latent(self.last_latent)
|
||||
|
||||
self.step += 1
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def after_sample(self, x, ts, cond, uncond, res):
|
||||
if not self.is_unipc:
|
||||
self.update_step(res[1])
|
||||
|
||||
return x, ts, cond, uncond, res
|
||||
|
||||
def unipc_after_update(self, x, model_x):
|
||||
self.update_step(x)
|
||||
|
||||
def initialize(self, p):
|
||||
if self.is_ddim:
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
else:
|
||||
self.eta = 0.0
|
||||
|
||||
if self.eta != 0.0:
|
||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||
|
||||
if self.is_unipc:
|
||||
keys = [
|
||||
('UniPC variant', 'uni_pc_variant'),
|
||||
('UniPC skip type', 'uni_pc_skip_type'),
|
||||
('UniPC order', 'uni_pc_order'),
|
||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||
]
|
||||
|
||||
for name, key in keys:
|
||||
v = getattr(shared.opts, key)
|
||||
if v != shared.opts.get_default(key):
|
||||
p.extra_generation_params[name] = v
|
||||
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
if self.is_unipc:
|
||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
|
||||
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
|
||||
num_steps = shared.opts.uni_pc_order
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == math.floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
||||
return num_steps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
self.initialize(p)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.init_latent = x
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
|
||||
else:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
self.initialize(p)
|
||||
|
||||
self.init_latent = None
|
||||
self.last_latent = x
|
||||
self.step = 0
|
||||
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
if self.conditioning_key == "crossattn-adm":
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
|
||||
else:
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
return samples_ddim
|
74
modules/sd_samplers_extra.py
Normal file
74
modules/sd_samplers_extra.py
Normal file
@ -0,0 +1,74 @@
|
||||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
|
||||
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
|
||||
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
|
||||
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
step_id = 0
|
||||
from k_diffusion.sampling import to_d, get_sigmas_karras
|
||||
|
||||
def heun_step(x, old_sigma, new_sigma, second_order=True):
|
||||
nonlocal step_id
|
||||
denoised = model(x, old_sigma * s_in, **extra_args)
|
||||
d = to_d(x, old_sigma, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
|
||||
dt = new_sigma - old_sigma
|
||||
if new_sigma == 0 or not second_order:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, new_sigma, denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
step_id += 1
|
||||
return x
|
||||
|
||||
steps = sigmas.shape[0] - 1
|
||||
if restart_list is None:
|
||||
if steps >= 20:
|
||||
restart_steps = 9
|
||||
restart_times = 1
|
||||
if steps >= 36:
|
||||
restart_steps = steps // 4
|
||||
restart_times = 2
|
||||
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
|
||||
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
|
||||
else:
|
||||
restart_list = {}
|
||||
|
||||
restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
|
||||
|
||||
step_list = []
|
||||
for i in range(len(sigmas) - 1):
|
||||
step_list.append((sigmas[i], sigmas[i + 1]))
|
||||
if i + 1 in restart_list:
|
||||
restart_steps, restart_times, restart_max = restart_list[i + 1]
|
||||
min_idx = i + 1
|
||||
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
|
||||
if max_idx < min_idx:
|
||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||
while restart_times > 0:
|
||||
restart_times -= 1
|
||||
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
||||
|
||||
last_sigma = None
|
||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||
if last_sigma is None:
|
||||
last_sigma = old_sigma
|
||||
elif last_sigma < old_sigma:
|
||||
x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
|
||||
x = heun_step(x, old_sigma, new_sigma)
|
||||
last_sigma = new_sigma
|
||||
|
||||
return x
|
@ -1,47 +1,60 @@
|
||||
from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
from modules import prompt_parser, devices, sd_samplers_common
|
||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts, state
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
||||
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
||||
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
||||
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_dpm_fast': ['s_noise'],
|
||||
'sample_dpm_2_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||
'sample_dpmpp_sde': ['s_noise'],
|
||||
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||
}
|
||||
|
||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||
@ -53,289 +66,27 @@ k_diffusion_scheduler = {
|
||||
}
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
return self.model_wrap
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model, options=None):
|
||||
super().__init__(funcname)
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||
# so is_edit_model is set to False to support AND composition.
|
||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||
|
||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||
cfg_denoiser_callback(denoiser_params)
|
||||
x_in = denoiser_params.x
|
||||
image_cond_in = denoiser_params.image_cond
|
||||
sigma_in = denoiser_params.sigma
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
self.padded_cond_uncond = False
|
||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
denoised = after_cfg_callback_params.x
|
||||
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||
else:
|
||||
return torch.randn_like(x)
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
|
||||
self.conditioning_key = sd_model.model.conditioning_key
|
||||
self.options = options or {}
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
latent = d["denoised"]
|
||||
if opts.live_preview_content == "Combined":
|
||||
sd_samplers_common.store_latent(latent)
|
||||
self.last_latent = latent
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except sd_samplers_common.InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params["Eta"] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
return extra_params_kwargs
|
||||
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||
|
||||
def get_sigmas(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
@ -376,6 +127,9 @@ class KDiffusionSampler:
|
||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
|
||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
@ -384,24 +138,21 @@ class KDiffusionSampler:
|
||||
|
||||
return sigmas
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
sigmas = self.get_sigmas(p, steps)
|
||||
|
||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||
|
||||
xi = x + noise * sigma_sched[0]
|
||||
|
||||
if opts.img2img_extra_noise > 0:
|
||||
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||
extra_noise_params = ExtraNoiseParams(noise, x)
|
||||
extra_noise_callback(extra_noise_params)
|
||||
noise = extra_noise_params.noise
|
||||
xi += noise * opts.img2img_extra_noise
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
@ -421,9 +172,12 @@ class KDiffusionSampler:
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
if self.config.options.get('solver_type', None) == 'heun':
|
||||
extra_params_kwargs['solver_type'] = 'heun'
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
extra_args = {
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
@ -431,7 +185,7 @@ class KDiffusionSampler:
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
@ -448,29 +202,37 @@ class KDiffusionSampler:
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
|
||||
if 'sigma_min' in parameters:
|
||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||
if 'n' in parameters:
|
||||
extra_params_kwargs['n'] = steps
|
||||
else:
|
||||
|
||||
if 'sigmas' in parameters:
|
||||
extra_params_kwargs['sigmas'] = sigmas
|
||||
|
||||
if self.config.options.get('brownian_noise', False):
|
||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||
|
||||
if self.config.options.get('solver_type', None) == 'heun':
|
||||
extra_params_kwargs['solver_type'] = 'heun'
|
||||
|
||||
self.last_latent = x
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
|
167
modules/sd_samplers_timesteps.py
Normal file
167
modules/sd_samplers_timesteps.py
Normal file
@ -0,0 +1,167 @@
|
||||
import torch
|
||||
import inspect
|
||||
import sys
|
||||
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
|
||||
samplers_timesteps = [
|
||||
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||
]
|
||||
|
||||
|
||||
samplers_data_timesteps = [
|
||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_timesteps
|
||||
]
|
||||
|
||||
|
||||
class CompVisTimestepsDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
return self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
|
||||
|
||||
class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.inner_model = model
|
||||
|
||||
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||
|
||||
def forward(self, input, timesteps, **kwargs):
|
||||
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
|
||||
return e_t
|
||||
|
||||
|
||||
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||
|
||||
def __init__(self, sampler):
|
||||
super().__init__(sampler)
|
||||
|
||||
self.alphas = shared.sd_model.alphas_cumprod
|
||||
self.mask_before_denoising = True
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
ts = sigma.to(dtype=int)
|
||||
|
||||
a_t = self.alphas[ts][:, None, None, None]
|
||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||
|
||||
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||
|
||||
return pred_x0
|
||||
|
||||
@property
|
||||
def inner_model(self):
|
||||
if self.model_wrap is None:
|
||||
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||
self.model_wrap = denoiser(shared.sd_model)
|
||||
|
||||
return self.model_wrap
|
||||
|
||||
|
||||
class CompVisSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
super().__init__(funcname)
|
||||
|
||||
self.eta_option_field = 'eta_ddim'
|
||||
self.eta_infotext_field = 'Eta DDIM'
|
||||
self.eta_default = 0.0
|
||||
|
||||
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||
|
||||
def get_timesteps(self, p, steps):
|
||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||
discard_next_to_last_sigma = True
|
||||
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||
|
||||
steps += 1 if discard_next_to_last_sigma else 0
|
||||
|
||||
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
|
||||
|
||||
return timesteps
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
timesteps_sched = timesteps[:t_enc]
|
||||
|
||||
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||
|
||||
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||
|
||||
if opts.img2img_extra_noise > 0:
|
||||
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||
extra_noise_params = ExtraNoiseParams(noise, x)
|
||||
extra_noise_callback(extra_noise_params)
|
||||
noise = extra_noise_params.noise
|
||||
xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps_sched
|
||||
if 'is_img2img' in parameters:
|
||||
extra_params_kwargs['is_img2img'] = True
|
||||
|
||||
self.model_wrap_cfg.init_latent = x
|
||||
self.last_latent = x
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps = steps or p.steps
|
||||
timesteps = self.get_timesteps(p, steps)
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
parameters = inspect.signature(self.func).parameters
|
||||
|
||||
if 'timesteps' in parameters:
|
||||
extra_params_kwargs['timesteps'] = timesteps
|
||||
|
||||
self.last_latent = x
|
||||
self.sampler_extra_args = {
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
if self.model_wrap_cfg.padded_cond_uncond:
|
||||
p.extra_generation_params["Pad conds"] = True
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
|
||||
VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
|
137
modules/sd_samplers_timesteps_impl.py
Normal file
137
modules/sd_samplers_timesteps_impl.py
Normal file
@ -0,0 +1,137 @@
|
||||
import torch
|
||||
import tqdm
|
||||
import k_diffusion.sampling
|
||||
import numpy as np
|
||||
|
||||
from modules import shared
|
||||
from modules.models.diffusion.uni_pc import uni_pc
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones((x.shape[0]))
|
||||
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
|
||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||
|
||||
a_t = alphas[index].item() * s_x
|
||||
a_prev = alphas_prev[index].item() * s_x
|
||||
sigma_t = sigmas[index].item() * s_x
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
alphas = alphas_cumprod[timesteps]
|
||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||
old_eps = []
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = alphas[index].item() * s_x
|
||||
a_prev = alphas_prev[index].item() * s_x
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev).sqrt() * e_t
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||
return x_prev, pred_x0
|
||||
|
||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||
index = len(timesteps) - 1 - i
|
||||
ts = timesteps[index].item() * s_in
|
||||
t_next = timesteps[max(index - 1, 0)].item() * s_in
|
||||
|
||||
e_t = model(x, ts, **extra_args)
|
||||
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = model(x_prev, t_next, **extra_args)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
else:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
|
||||
x = x_prev
|
||||
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UniPCCFG(uni_pc.UniPC):
|
||||
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
|
||||
super().__init__(None, *args, **kwargs)
|
||||
|
||||
def after_update(x, model_x):
|
||||
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
|
||||
self.index += 1
|
||||
|
||||
self.cfg_model = cfg_model
|
||||
self.extra_args = extra_args
|
||||
self.callback = callback
|
||||
self.index = 0
|
||||
self.after_update = after_update
|
||||
|
||||
def get_model_input_time(self, t_continuous):
|
||||
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
|
||||
|
||||
def model(self, x, t):
|
||||
t_input = self.get_model_input_time(t)
|
||||
|
||||
res = self.cfg_model(x, t_input, **self.extra_args)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||
|
||||
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
|
||||
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||
|
||||
return x
|
@ -47,7 +47,7 @@ def apply_unet(option=None):
|
||||
if current_unet_option is None:
|
||||
current_unet = None
|
||||
|
||||
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
||||
if not shared.sd_model.lowvram:
|
||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||
|
||||
return
|
||||
|
@ -1,6 +1,9 @@
|
||||
import os
|
||||
import collections
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||
from dataclasses import dataclass
|
||||
|
||||
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
|
||||
|
||||
import glob
|
||||
from copy import deepcopy
|
||||
|
||||
@ -16,6 +19,23 @@ checkpoint_info = None
|
||||
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
|
||||
def get_loaded_vae_name():
|
||||
if loaded_vae_file is None:
|
||||
return None
|
||||
|
||||
return os.path.basename(loaded_vae_file)
|
||||
|
||||
|
||||
def get_loaded_vae_hash():
|
||||
if loaded_vae_file is None:
|
||||
return None
|
||||
|
||||
sha256 = hashes.sha256(loaded_vae_file, 'vae')
|
||||
|
||||
return sha256[0:10] if sha256 else None
|
||||
|
||||
|
||||
def get_base_vae(model):
|
||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||
return base_vae
|
||||
@ -83,6 +103,8 @@ def refresh_vae_list():
|
||||
name = get_filename(filepath)
|
||||
vae_dict[name] = filepath
|
||||
|
||||
vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))))
|
||||
|
||||
|
||||
def find_vae_near_checkpoint(checkpoint_file):
|
||||
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
||||
@ -93,27 +115,74 @@ def find_vae_near_checkpoint(checkpoint_file):
|
||||
return None
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file):
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
||||
@dataclass
|
||||
class VaeResolution:
|
||||
vae: str = None
|
||||
source: str = None
|
||||
resolved: bool = True
|
||||
|
||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
def tuple(self):
|
||||
return self.vae, self.source
|
||||
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
||||
return vae_near_checkpoint, 'found near the checkpoint'
|
||||
|
||||
def is_automatic():
|
||||
return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||
|
||||
|
||||
def resolve_vae_from_setting() -> VaeResolution:
|
||||
if shared.opts.sd_vae == "None":
|
||||
return None, None
|
||||
return VaeResolution()
|
||||
|
||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||
if vae_from_options is not None:
|
||||
return vae_from_options, 'specified in settings'
|
||||
return VaeResolution(vae_from_options, 'specified in settings')
|
||||
|
||||
if not is_automatic:
|
||||
if not is_automatic():
|
||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||
|
||||
return None, None
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
|
||||
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
|
||||
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||
vae_metadata = metadata.get("vae", None)
|
||||
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||
if vae_metadata == "None":
|
||||
return VaeResolution()
|
||||
|
||||
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||
if vae_from_metadata is not None:
|
||||
return VaeResolution(vae_from_metadata, "from user metadata")
|
||||
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
|
||||
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
|
||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||
if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()):
|
||||
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
|
||||
|
||||
return VaeResolution(resolved=False)
|
||||
|
||||
|
||||
def resolve_vae(checkpoint_file) -> VaeResolution:
|
||||
if shared.cmd_opts.vae_path is not None:
|
||||
return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
|
||||
|
||||
if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
|
||||
return resolve_vae_from_setting()
|
||||
|
||||
res = resolve_vae_from_user_metadata(checkpoint_file)
|
||||
if res.resolved:
|
||||
return res
|
||||
|
||||
res = resolve_vae_near_checkpoint(checkpoint_file)
|
||||
if res.resolved:
|
||||
return res
|
||||
|
||||
res = resolve_vae_from_setting()
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def load_vae_dict(filename, map_location):
|
||||
@ -123,7 +192,7 @@ def load_vae_dict(filename, map_location):
|
||||
|
||||
|
||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
global vae_dict, loaded_vae_file
|
||||
global vae_dict, base_vae, loaded_vae_file
|
||||
# save_settings = False
|
||||
|
||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||
@ -161,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||
restore_base_vae(model)
|
||||
|
||||
loaded_vae_file = vae_file
|
||||
model.base_vae = base_vae
|
||||
model.loaded_vae_file = loaded_vae_file
|
||||
|
||||
|
||||
# don't call this from outside
|
||||
@ -178,8 +249,6 @@ unspecified = object()
|
||||
|
||||
|
||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
from modules import lowvram, devices, sd_hijack
|
||||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
@ -187,14 +256,14 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
|
||||
if vae_file == unspecified:
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
||||
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
|
||||
else:
|
||||
vae_source = "from function argument"
|
||||
|
||||
if loaded_vae_file == vae_file:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
if sd_model.lowvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
@ -206,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print("VAE weights loaded.")
|
||||
|
@ -81,6 +81,6 @@ def cheap_approximation(sample):
|
||||
|
||||
coefs = torch.tensor(coeffs).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
|
||||
|
||||
return x_sample
|
||||
|
@ -44,7 +44,17 @@ def decoder():
|
||||
)
|
||||
|
||||
|
||||
class TAESD(nn.Module):
|
||||
def encoder():
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 4),
|
||||
)
|
||||
|
||||
|
||||
class TAESDDecoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
@ -55,21 +65,28 @@ class TAESD(nn.Module):
|
||||
self.decoder.load_state_dict(
|
||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
class TAESDEncoder(nn.Module):
|
||||
latent_magnitude = 3
|
||||
latent_shift = 0.5
|
||||
|
||||
def __init__(self, encoder_path="taesd_encoder.pth"):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.encoder = encoder()
|
||||
self.encoder.load_state_dict(
|
||||
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading TAESD decoder to: {model_path}')
|
||||
print(f'Downloading TAESD model to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
def decoder_model():
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
@ -78,7 +95,7 @@ def model():
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
loaded_model = TAESD(model_path)
|
||||
loaded_model = TAESDDecoder(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
@ -86,3 +103,22 @@ def model():
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return loaded_model.decoder
|
||||
|
||||
|
||||
def encoder_model():
|
||||
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
loaded_model = TAESDEncoder(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return loaded_model.encoder
|
||||
|
@ -1,771 +1,51 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import launch
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from typing import Optional
|
||||
from modules import util
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
cmd_opts = shared_cmd_options.cmd_opts
|
||||
parser = shared_cmd_options.parser
|
||||
|
||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||
parallel_processing_allowed = True
|
||||
styles_filename = cmd_opts.styles_file
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
demo = None
|
||||
|
||||
parser = cmd_args.parser
|
||||
device = None
|
||||
|
||||
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||
weight_load_location = None
|
||||
|
||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
cmd_opts = parser.parse_args()
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
"outdir_samples",
|
||||
"outdir_txt2img_samples",
|
||||
"outdir_img2img_samples",
|
||||
"outdir_extras_samples",
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
"outdir_init_images"
|
||||
}
|
||||
|
||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||
gradio_hf_hub_themes = [
|
||||
"gradio/glass",
|
||||
"gradio/monochrome",
|
||||
"gradio/seafoam",
|
||||
"gradio/soft",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"gradio/dracula_test",
|
||||
"abidlabs/dracula_test",
|
||||
"abidlabs/pakistan",
|
||||
"dawood/microsoft_windows",
|
||||
"ysharma/steampunk"
|
||||
]
|
||||
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||
|
||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
device = devices.device
|
||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||
xformers_available = False
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = {}
|
||||
|
||||
loaded_hypernetworks = []
|
||||
|
||||
state = None
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
global hypernetworks
|
||||
prompt_styles = None
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
|
||||
|
||||
class State:
|
||||
skipped = False
|
||||
interrupted = False
|
||||
job = ""
|
||||
job_no = 0
|
||||
job_count = 0
|
||||
processing_has_refined_job_count = False
|
||||
job_timestamp = '0'
|
||||
sampling_step = 0
|
||||
sampling_steps = 0
|
||||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
server_start = None
|
||||
_server_command_signal = threading.Event()
|
||||
_server_command: Optional[str] = None
|
||||
|
||||
@property
|
||||
def need_restart(self) -> bool:
|
||||
# Compatibility getter for need_restart.
|
||||
return self.server_command == "restart"
|
||||
|
||||
@need_restart.setter
|
||||
def need_restart(self, value: bool) -> None:
|
||||
# Compatibility setter for need_restart.
|
||||
if value:
|
||||
self.server_command = "restart"
|
||||
|
||||
@property
|
||||
def server_command(self):
|
||||
return self._server_command
|
||||
|
||||
@server_command.setter
|
||||
def server_command(self, value: Optional[str]) -> None:
|
||||
"""
|
||||
Set the server command to `value` and signal that it's been set.
|
||||
"""
|
||||
self._server_command = value
|
||||
self._server_command_signal.set()
|
||||
|
||||
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||
"""
|
||||
Wait for server command to get set; return and clear the value and signal.
|
||||
"""
|
||||
if self._server_command_signal.wait(timeout):
|
||||
self._server_command_signal.clear()
|
||||
req = self._server_command
|
||||
self._server_command = None
|
||||
return req
|
||||
return None
|
||||
|
||||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
log.info("Received restart request")
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
log.info("Received skip request")
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def nextjob(self):
|
||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
||||
def end(self):
|
||||
duration = time.time() - self.time_start
|
||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def set_current_image(self):
|
||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
||||
if not parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
|
||||
|
||||
|
||||
state = State()
|
||||
state.server_start = time.time()
|
||||
|
||||
styles_filename = cmd_opts.styles_file
|
||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
||||
|
||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||
interrogator = None
|
||||
|
||||
face_restorers = []
|
||||
|
||||
options_templates = None
|
||||
opts = None
|
||||
restricted_opts = None
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
self.section = section
|
||||
self.refresh = refresh
|
||||
|
||||
self.comment_before = comment_before
|
||||
"""HTML text that will be added after label in UI"""
|
||||
|
||||
self.comment_after = comment_after
|
||||
"""HTML text that will be added before label in UI"""
|
||||
|
||||
def link(self, label, url):
|
||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def js(self, label, js_func):
|
||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||
return self
|
||||
|
||||
def info(self, info):
|
||||
self.comment_after += f"<span class='info'>({info})</span>"
|
||||
return self
|
||||
|
||||
def html(self, html):
|
||||
self.comment_after += html
|
||||
return self
|
||||
|
||||
def needs_restart(self):
|
||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||
return self
|
||||
|
||||
|
||||
|
||||
|
||||
def options_section(section_identifier, options_dict):
|
||||
for v in options_dict.values():
|
||||
v.section = section_identifier
|
||||
|
||||
return options_dict
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.list_models()
|
||||
|
||||
|
||||
def list_samplers():
|
||||
import modules.sd_samplers
|
||||
return modules.sd_samplers.all_samplers
|
||||
|
||||
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
tab_names = []
|
||||
|
||||
options_templates = {}
|
||||
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||
|
||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||
|
||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
||||
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
||||
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
||||
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
||||
</ul>"""),
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update()
|
||||
|
||||
|
||||
class Options:
|
||||
data = None
|
||||
data_labels = options_templates
|
||||
typemap = {int: float}
|
||||
|
||||
def __init__(self):
|
||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if self.data is not None:
|
||||
if key in self.data or key in self.data_labels:
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
info = opts.data_labels.get(key, None)
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
|
||||
return super(Options, self).__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if self.data is not None:
|
||||
if item in self.data:
|
||||
return self.data[item]
|
||||
|
||||
if item in self.data_labels:
|
||||
return self.data_labels[item].default
|
||||
|
||||
return super(Options, self).__getattribute__(item)
|
||||
|
||||
def set(self, key, value):
|
||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||
|
||||
oldval = self.data.get(key, None)
|
||||
if oldval == value:
|
||||
return False
|
||||
|
||||
try:
|
||||
setattr(self, key, value)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
if self.data_labels[key].onchange is not None:
|
||||
try:
|
||||
self.data_labels[key].onchange()
|
||||
except Exception as e:
|
||||
errors.display(e, f"changing setting {key} to {value}")
|
||||
setattr(self, key, oldval)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_default(self, key):
|
||||
"""returns the default value for the key"""
|
||||
|
||||
data_label = self.data_labels.get(key)
|
||||
if data_label is None:
|
||||
return None
|
||||
|
||||
return data_label.default
|
||||
|
||||
def save(self, filename):
|
||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||
|
||||
with open(filename, "w", encoding="utf8") as file:
|
||||
json.dump(self.data, file, indent=4)
|
||||
|
||||
def same_type(self, x, y):
|
||||
if x is None or y is None:
|
||||
return True
|
||||
|
||||
type_x = self.typemap.get(type(x), type(x))
|
||||
type_y = self.typemap.get(type(y), type(y))
|
||||
|
||||
return type_x == type_y
|
||||
|
||||
def load(self, filename):
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
# 1.1.1 quicksettings list migration
|
||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||
|
||||
# 1.4.0 ui_reorder
|
||||
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||
|
||||
bad_settings = 0
|
||||
for k, v in self.data.items():
|
||||
info = self.data_labels.get(k, None)
|
||||
if info is not None and not self.same_type(info.default, v):
|
||||
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
||||
bad_settings += 1
|
||||
|
||||
if bad_settings > 0:
|
||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||
|
||||
def onchange(self, key, func, call=True):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
if call:
|
||||
func()
|
||||
|
||||
def dumpjson(self):
|
||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||
return json.dumps(d)
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
|
||||
def reorder(self):
|
||||
"""reorder settings so that all items related to section always go together"""
|
||||
|
||||
section_ids = {}
|
||||
settings_items = self.data_labels.items()
|
||||
for _, item in settings_items:
|
||||
if item.section not in section_ids:
|
||||
section_ids[item.section] = len(section_ids)
|
||||
|
||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||
|
||||
def cast_value(self, key, value):
|
||||
"""casts an arbitrary to the same type as this setting's value with key
|
||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||
"""
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
default_value = self.data_labels[key].default
|
||||
if default_value is None:
|
||||
default_value = getattr(self, key, None)
|
||||
if default_value is None:
|
||||
return None
|
||||
|
||||
expected_type = type(default_value)
|
||||
if expected_type == bool and value == "False":
|
||||
value = False
|
||||
else:
|
||||
value = expected_type(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
|
||||
class Shared(sys.modules[__name__].__class__):
|
||||
"""
|
||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||
at program startup.
|
||||
"""
|
||||
|
||||
sd_model_val = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
import modules.sd_models
|
||||
|
||||
return modules.sd_models.model_data.get_sd_model()
|
||||
|
||||
@sd_model.setter
|
||||
def sd_model(self, value):
|
||||
import modules.sd_models
|
||||
|
||||
modules.sd_models.model_data.set_sd_model(value)
|
||||
|
||||
|
||||
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
||||
sys.modules[__name__].__class__ = Shared
|
||||
sd_model: sd_models_types.WebuiSdModel = None
|
||||
|
||||
settings_components = None
|
||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||
|
||||
tab_names = []
|
||||
|
||||
latent_upscale_default_mode = "Latent"
|
||||
latent_upscale_modes = {
|
||||
"Latent": {"mode": "bilinear", "antialias": False},
|
||||
@ -784,108 +64,24 @@ progress_print_out = sys.stdout
|
||||
|
||||
gradio_theme = gr.themes.Base()
|
||||
|
||||
total_tqdm = None
|
||||
|
||||
def reload_gradio_theme(theme_name=None):
|
||||
global gradio_theme
|
||||
if not theme_name:
|
||||
theme_name = opts.gradio_theme
|
||||
mem_mon = None
|
||||
|
||||
default_theme_args = dict(
|
||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||
)
|
||||
options_section = options.options_section
|
||||
OptionInfo = options.OptionInfo
|
||||
OptionHTML = options.OptionHTML
|
||||
|
||||
if theme_name == "Default":
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
else:
|
||||
try:
|
||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
except Exception as e:
|
||||
errors.display(e, "changing gradio theme")
|
||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
natural_sort_key = util.natural_sort_key
|
||||
listfiles = util.listfiles
|
||||
html_path = util.html_path
|
||||
html = util.html
|
||||
walk_files = util.walk_files
|
||||
ldm_print = util.ldm_print
|
||||
|
||||
reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
|
||||
|
||||
|
||||
class TotalTQDM:
|
||||
def __init__(self):
|
||||
self._tqdm = None
|
||||
|
||||
def reset(self):
|
||||
self._tqdm = tqdm.tqdm(
|
||||
desc="Total progress",
|
||||
total=state.job_count * state.sampling_steps,
|
||||
position=1,
|
||||
file=progress_print_out
|
||||
)
|
||||
|
||||
def update(self):
|
||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||
return
|
||||
if self._tqdm is None:
|
||||
self.reset()
|
||||
self._tqdm.update()
|
||||
|
||||
def updateTotal(self, new_total):
|
||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||
return
|
||||
if self._tqdm is None:
|
||||
self.reset()
|
||||
self._tqdm.total = new_total
|
||||
|
||||
def clear(self):
|
||||
if self._tqdm is not None:
|
||||
self._tqdm.refresh()
|
||||
self._tqdm.close()
|
||||
self._tqdm = None
|
||||
|
||||
|
||||
total_tqdm = TotalTQDM()
|
||||
|
||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
||||
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
||||
|
||||
|
||||
def listfiles(dirname):
|
||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
||||
return [file for file in filenames if os.path.isfile(file)]
|
||||
|
||||
|
||||
def html_path(filename):
|
||||
return os.path.join(script_path, "html", filename)
|
||||
|
||||
|
||||
def html(filename):
|
||||
path = html_path(filename)
|
||||
|
||||
if os.path.exists(path):
|
||||
with open(path, encoding="utf8") as file:
|
||||
return file.read()
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def walk_files(path, allowed_extensions=None):
|
||||
if not os.path.exists(path):
|
||||
return
|
||||
|
||||
if allowed_extensions is not None:
|
||||
allowed_extensions = set(allowed_extensions)
|
||||
|
||||
items = list(os.walk(path, followlinks=True))
|
||||
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
||||
|
||||
for root, _, files in items:
|
||||
for filename in sorted(files, key=natural_sort_key):
|
||||
if allowed_extensions is not None:
|
||||
_, ext = os.path.splitext(filename)
|
||||
if ext not in allowed_extensions:
|
||||
continue
|
||||
|
||||
if not opts.list_hidden_files and ("/." in root or "\\." in root):
|
||||
continue
|
||||
|
||||
yield os.path.join(root, filename)
|
||||
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
|
||||
refresh_checkpoints = shared_items.refresh_checkpoints
|
||||
list_samplers = shared_items.list_samplers
|
||||
reload_hypernetworks = shared_items.reload_hypernetworks
|
||||
|
18
modules/shared_cmd_options.py
Normal file
18
modules/shared_cmd_options.py
Normal file
@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
import launch
|
||||
from modules import cmd_args, script_loading
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
|
||||
parser = cmd_args.parser
|
||||
|
||||
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||
|
||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
cmd_opts = parser.parse_args()
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
|
||||
cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access
|
67
modules/shared_gradio_themes.py
Normal file
67
modules/shared_gradio_themes.py
Normal file
@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors, shared
|
||||
from modules.paths_internal import script_path
|
||||
|
||||
|
||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||
gradio_hf_hub_themes = [
|
||||
"gradio/base",
|
||||
"gradio/glass",
|
||||
"gradio/monochrome",
|
||||
"gradio/seafoam",
|
||||
"gradio/soft",
|
||||
"gradio/dracula_test",
|
||||
"abidlabs/dracula_test",
|
||||
"abidlabs/Lime",
|
||||
"abidlabs/pakistan",
|
||||
"Ama434/neutral-barlow",
|
||||
"dawood/microsoft_windows",
|
||||
"finlaymacklon/smooth_slate",
|
||||
"Franklisi/darkmode",
|
||||
"freddyaboulton/dracula_revamped",
|
||||
"freddyaboulton/test-blue",
|
||||
"gstaff/xkcd",
|
||||
"Insuz/Mocha",
|
||||
"Insuz/SimpleIndigo",
|
||||
"JohnSmith9982/small_and_pretty",
|
||||
"nota-ai/theme",
|
||||
"nuttea/Softblue",
|
||||
"ParityError/Anime",
|
||||
"reilnuud/polite",
|
||||
"remilia/Ghostly",
|
||||
"rottenlittlecreature/Moon_Goblin",
|
||||
"step-3-profit/Midnight-Deep",
|
||||
"Taithrah/Minimal",
|
||||
"ysharma/huggingface",
|
||||
"ysharma/steampunk",
|
||||
"NoCrypt/miku"
|
||||
]
|
||||
|
||||
|
||||
def reload_gradio_theme(theme_name=None):
|
||||
if not theme_name:
|
||||
theme_name = shared.opts.gradio_theme
|
||||
|
||||
default_theme_args = dict(
|
||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||
)
|
||||
|
||||
if theme_name == "Default":
|
||||
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
||||
else:
|
||||
try:
|
||||
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
|
||||
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
|
||||
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
|
||||
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
|
||||
else:
|
||||
os.makedirs(theme_cache_dir, exist_ok=True)
|
||||
shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||
shared.gradio_theme.dump(theme_cache_path)
|
||||
except Exception as e:
|
||||
errors.display(e, "changing gradio theme")
|
||||
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
49
modules/shared_init.py
Normal file
49
modules/shared_init.py
Normal file
@ -0,0 +1,49 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
|
||||
def initialize():
|
||||
"""Initializes fields inside the shared module in a controlled manner.
|
||||
|
||||
Should be called early because some other modules you can import mingt need these fields to be already set.
|
||||
"""
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
|
||||
from modules import options, shared_options
|
||||
shared.options_templates = shared_options.options_templates
|
||||
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
|
||||
shared.restricted_opts = shared_options.restricted_opts
|
||||
if os.path.exists(shared.config_filename):
|
||||
shared.opts.load(shared.config_filename)
|
||||
|
||||
from modules import devices
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||
|
||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||
|
||||
shared.device = devices.device
|
||||
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
|
||||
from modules import shared_state
|
||||
shared.state = shared_state.State()
|
||||
|
||||
from modules import styles
|
||||
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
|
||||
|
||||
from modules import interrogate
|
||||
shared.interrogator = interrogate.InterrogateModels("interrogate")
|
||||
|
||||
from modules import shared_total_tqdm
|
||||
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
|
||||
|
||||
from modules import memmon, devices
|
||||
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
|
||||
shared.mem_mon.start()
|
||||
|
@ -1,3 +1,6 @@
|
||||
import sys
|
||||
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
@ -41,13 +44,36 @@ def refresh_unet_list():
|
||||
modules.sd_unet.list_unets()
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.list_models()
|
||||
|
||||
|
||||
def list_samplers():
|
||||
import modules.sd_samplers
|
||||
return modules.sd_samplers.all_samplers
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules import shared
|
||||
|
||||
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
|
||||
|
||||
ui_reorder_categories_builtin_items = [
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"accordions",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"denoising",
|
||||
"seed",
|
||||
"batch",
|
||||
"override_settings",
|
||||
@ -61,9 +87,33 @@ def ui_reorder_categories():
|
||||
|
||||
sections = {}
|
||||
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||
if isinstance(script.section, str):
|
||||
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
|
||||
sections[script.section] = 1
|
||||
|
||||
yield from sections
|
||||
|
||||
yield "scripts"
|
||||
|
||||
|
||||
class Shared(sys.modules[__name__].__class__):
|
||||
"""
|
||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||
at program startup.
|
||||
"""
|
||||
|
||||
sd_model_val = None
|
||||
|
||||
@property
|
||||
def sd_model(self):
|
||||
import modules.sd_models
|
||||
|
||||
return modules.sd_models.model_data.get_sd_model()
|
||||
|
||||
@sd_model.setter
|
||||
def sd_model(self, value):
|
||||
import modules.sd_models
|
||||
|
||||
modules.sd_models.model_data.set_sd_model(value)
|
||||
|
||||
|
||||
sys.modules['modules.shared'].__class__ = Shared
|
||||
|
330
modules/shared_options.py
Normal file
330
modules/shared_options.py
Normal file
@ -0,0 +1,330 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.options import options_section, OptionInfo, OptionHTML
|
||||
|
||||
options_templates = {}
|
||||
hide_dirs = shared.hide_dirs
|
||||
|
||||
restricted_opts = {
|
||||
"samples_filename_pattern",
|
||||
"directories_filename_pattern",
|
||||
"outdir_samples",
|
||||
"outdir_txt2img_samples",
|
||||
"outdir_img2img_samples",
|
||||
"outdir_extras_samples",
|
||||
"outdir_grids",
|
||||
"outdir_txt2img_grids",
|
||||
"outdir_save",
|
||||
"outdir_init_images"
|
||||
}
|
||||
|
||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||
|
||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||
"font": OptionInfo("", "Font for image grids that have text"),
|
||||
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||
|
||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||
|
||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||
|
||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('API', "API"), {
|
||||
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('vae', "VAE"), {
|
||||
"sd_vae_explanation": OptionHTML("""
|
||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
|
||||
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
|
||||
"""),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('img2img', "img2img"), {
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
||||
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
|
||||
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
|
||||
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
|
||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
|
||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
|
||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
||||
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
||||
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
||||
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
||||
</ul>"""),
|
||||
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "Live previews"), {
|
||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||
"live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
|
||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
|
||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section((None, "Hidden options"), {
|
||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
159
modules/shared_state.py
Normal file
159
modules/shared_state.py
Normal file
@ -0,0 +1,159 @@
|
||||
import datetime
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules import errors, shared, devices
|
||||
from typing import Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class State:
|
||||
skipped = False
|
||||
interrupted = False
|
||||
job = ""
|
||||
job_no = 0
|
||||
job_count = 0
|
||||
processing_has_refined_job_count = False
|
||||
job_timestamp = '0'
|
||||
sampling_step = 0
|
||||
sampling_steps = 0
|
||||
current_latent = None
|
||||
current_image = None
|
||||
current_image_sampling_step = 0
|
||||
id_live_preview = 0
|
||||
textinfo = None
|
||||
time_start = None
|
||||
server_start = None
|
||||
_server_command_signal = threading.Event()
|
||||
_server_command: Optional[str] = None
|
||||
|
||||
def __init__(self):
|
||||
self.server_start = time.time()
|
||||
|
||||
@property
|
||||
def need_restart(self) -> bool:
|
||||
# Compatibility getter for need_restart.
|
||||
return self.server_command == "restart"
|
||||
|
||||
@need_restart.setter
|
||||
def need_restart(self, value: bool) -> None:
|
||||
# Compatibility setter for need_restart.
|
||||
if value:
|
||||
self.server_command = "restart"
|
||||
|
||||
@property
|
||||
def server_command(self):
|
||||
return self._server_command
|
||||
|
||||
@server_command.setter
|
||||
def server_command(self, value: Optional[str]) -> None:
|
||||
"""
|
||||
Set the server command to `value` and signal that it's been set.
|
||||
"""
|
||||
self._server_command = value
|
||||
self._server_command_signal.set()
|
||||
|
||||
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||
"""
|
||||
Wait for server command to get set; return and clear the value and signal.
|
||||
"""
|
||||
if self._server_command_signal.wait(timeout):
|
||||
self._server_command_signal.clear()
|
||||
req = self._server_command
|
||||
self._server_command = None
|
||||
return req
|
||||
return None
|
||||
|
||||
def request_restart(self) -> None:
|
||||
self.interrupt()
|
||||
self.server_command = "restart"
|
||||
log.info("Received restart request")
|
||||
|
||||
def skip(self):
|
||||
self.skipped = True
|
||||
log.info("Received skip request")
|
||||
|
||||
def interrupt(self):
|
||||
self.interrupted = True
|
||||
log.info("Received interrupt request")
|
||||
|
||||
def nextjob(self):
|
||||
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
self.job_no += 1
|
||||
self.sampling_step = 0
|
||||
self.current_image_sampling_step = 0
|
||||
|
||||
def dict(self):
|
||||
obj = {
|
||||
"skipped": self.skipped,
|
||||
"interrupted": self.interrupted,
|
||||
"job": self.job,
|
||||
"job_count": self.job_count,
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"job_no": self.job_no,
|
||||
"sampling_step": self.sampling_step,
|
||||
"sampling_steps": self.sampling_steps,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
self.job_no = 0
|
||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
self.current_latent = None
|
||||
self.current_image = None
|
||||
self.current_image_sampling_step = 0
|
||||
self.id_live_preview = 0
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
||||
def end(self):
|
||||
duration = time.time() - self.time_start
|
||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||
self.job = ""
|
||||
self.job_count = 0
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
def set_current_image(self):
|
||||
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
|
||||
if not shared.parallel_processing_allowed:
|
||||
return
|
||||
|
||||
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
|
||||
self.do_set_current_image()
|
||||
|
||||
def do_set_current_image(self):
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
|
||||
try:
|
||||
if shared.opts.show_progress_grid:
|
||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||
else:
|
||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
except Exception:
|
||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||
# we silently ignore this error
|
||||
errors.record_exception()
|
||||
|
||||
def assign_current_image(self, image):
|
||||
self.current_image = image
|
||||
self.id_live_preview += 1
|
37
modules/shared_total_tqdm.py
Normal file
37
modules/shared_total_tqdm.py
Normal file
@ -0,0 +1,37 @@
|
||||
import tqdm
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
class TotalTQDM:
|
||||
def __init__(self):
|
||||
self._tqdm = None
|
||||
|
||||
def reset(self):
|
||||
self._tqdm = tqdm.tqdm(
|
||||
desc="Total progress",
|
||||
total=shared.state.job_count * shared.state.sampling_steps,
|
||||
position=1,
|
||||
file=shared.progress_print_out
|
||||
)
|
||||
|
||||
def update(self):
|
||||
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||
return
|
||||
if self._tqdm is None:
|
||||
self.reset()
|
||||
self._tqdm.update()
|
||||
|
||||
def updateTotal(self, new_total):
|
||||
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||
return
|
||||
if self._tqdm is None:
|
||||
self.reset()
|
||||
self._tqdm.total = new_total
|
||||
|
||||
def clear(self):
|
||||
if self._tqdm is not None:
|
||||
self._tqdm.refresh()
|
||||
self._tqdm.close()
|
||||
self._tqdm = None
|
||||
|
@ -106,10 +106,7 @@ class StyleDatabase:
|
||||
if os.path.exists(path):
|
||||
shutil.copy(path, f"{path}.bak")
|
||||
|
||||
fd = os.open(path, os.O_RDWR | os.O_CREAT)
|
||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
||||
with open(path, "w", encoding="utf-8-sig", newline='') as file:
|
||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||
writer.writeheader()
|
||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
||||
|
@ -58,7 +58,7 @@ def _summarize_chunk(
|
||||
scale: float,
|
||||
) -> AttnChunk:
|
||||
attn_weights = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
scale: float,
|
||||
) -> Tensor:
|
||||
attn_scores = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
|
@ -10,7 +10,7 @@ import psutil
|
||||
import re
|
||||
|
||||
import launch
|
||||
from modules import paths_internal, timer
|
||||
from modules import paths_internal, timer, shared, extensions, errors
|
||||
|
||||
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
|
||||
environment_whitelist = {
|
||||
@ -23,7 +23,6 @@ environment_whitelist = {
|
||||
"TORCH_COMMAND",
|
||||
"REQS_FILE",
|
||||
"XFORMERS_PACKAGE",
|
||||
"GFPGAN_PACKAGE",
|
||||
"CLIP_PACKAGE",
|
||||
"OPENCLIP_PACKAGE",
|
||||
"STABLE_DIFFUSION_REPO",
|
||||
@ -83,7 +82,7 @@ def get_dict():
|
||||
"Data path": paths_internal.data_path,
|
||||
"Extensions dir": paths_internal.extensions_dir,
|
||||
"Checksum": checksum_token,
|
||||
"Commandline": sys.argv,
|
||||
"Commandline": get_argv(),
|
||||
"Torch env info": get_torch_sysinfo(),
|
||||
"Exceptions": get_exceptions(),
|
||||
"CPU": {
|
||||
@ -115,8 +114,6 @@ def format_exception(e, tb):
|
||||
|
||||
def get_exceptions():
|
||||
try:
|
||||
from modules import errors
|
||||
|
||||
return list(reversed(errors.exception_records))
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
@ -126,6 +123,22 @@ def get_environment():
|
||||
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
||||
|
||||
|
||||
def get_argv():
|
||||
res = []
|
||||
|
||||
for v in sys.argv:
|
||||
if shared.cmd_opts.gradio_auth and shared.cmd_opts.gradio_auth == v:
|
||||
res.append("<hidden>")
|
||||
continue
|
||||
|
||||
if shared.cmd_opts.api_auth and shared.cmd_opts.api_auth == v:
|
||||
res.append("<hidden>")
|
||||
continue
|
||||
|
||||
res.append(v)
|
||||
|
||||
return res
|
||||
|
||||
re_newline = re.compile(r"\r*\n")
|
||||
|
||||
|
||||
@ -142,8 +155,6 @@ def get_torch_sysinfo():
|
||||
def get_extensions(*, enabled):
|
||||
|
||||
try:
|
||||
from modules import extensions
|
||||
|
||||
def to_json(x: extensions.Extension):
|
||||
return {
|
||||
"name": x.name,
|
||||
@ -160,7 +171,6 @@ def get_extensions(*, enabled):
|
||||
|
||||
def get_config():
|
||||
try:
|
||||
from modules import shared
|
||||
return shared.opts.data
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
|
@ -13,7 +13,7 @@ import numpy as np
|
||||
from PIL import Image, PngImagePlugin
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
@ -181,29 +181,38 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||
vectors = data['clip_g'].shape[0]
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
embedding.vectors = vectors
|
||||
embedding.shape = shape
|
||||
embedding.filename = path
|
||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
||||
|
||||
@ -378,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
|
||||
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
from modules import processing
|
||||
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
template_file = textual_inversion_templates.get(template_filename, None)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import time
|
||||
import argparse
|
||||
|
||||
|
||||
class TimerSubcategory:
|
||||
@ -11,20 +12,27 @@ class TimerSubcategory:
|
||||
def __enter__(self):
|
||||
self.start = time.time()
|
||||
self.timer.base_category = self.original_base_category + self.category + "/"
|
||||
self.timer.subcategory_level += 1
|
||||
|
||||
if self.timer.print_log:
|
||||
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
elapsed_for_subcategroy = time.time() - self.start
|
||||
self.timer.base_category = self.original_base_category
|
||||
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
|
||||
self.timer.record(self.category)
|
||||
self.timer.subcategory_level -= 1
|
||||
self.timer.record(self.category, disable_log=True)
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
def __init__(self, print_log=False):
|
||||
self.start = time.time()
|
||||
self.records = {}
|
||||
self.total = 0
|
||||
self.base_category = ''
|
||||
self.print_log = print_log
|
||||
self.subcategory_level = 0
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
@ -38,13 +46,16 @@ class Timer:
|
||||
|
||||
self.records[category] += amount
|
||||
|
||||
def record(self, category, extra_time=0):
|
||||
def record(self, category, extra_time=0, disable_log=False):
|
||||
e = self.elapsed()
|
||||
|
||||
self.add_time_to_record(self.base_category + category, e + extra_time)
|
||||
|
||||
self.total += e + extra_time
|
||||
|
||||
if self.print_log and not disable_log:
|
||||
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
|
||||
|
||||
def subcategory(self, name):
|
||||
self.elapsed()
|
||||
|
||||
@ -71,6 +82,10 @@ class Timer:
|
||||
self.__init__()
|
||||
|
||||
|
||||
startup_timer = Timer()
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
|
||||
args = parser.parse_known_args()[0]
|
||||
|
||||
startup_timer = Timer(print_log=args.log_startup)
|
||||
|
||||
startup_record = None
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user