mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
Merge branch 'release_candidate'
This commit is contained in:
commit
5ef669de08
@ -87,5 +87,11 @@ module.exports = {
|
|||||||
modalNextImage: "readonly",
|
modalNextImage: "readonly",
|
||||||
// token-counters.js
|
// token-counters.js
|
||||||
setupTokenCounters: "readonly",
|
setupTokenCounters: "readonly",
|
||||||
|
// localStorage.js
|
||||||
|
localSet: "readonly",
|
||||||
|
localGet: "readonly",
|
||||||
|
localRemove: "readonly",
|
||||||
|
// resizeHandle.js
|
||||||
|
setupResizeHandle: "writable"
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
78
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
78
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -26,7 +26,7 @@ body:
|
|||||||
id: steps
|
id: steps
|
||||||
attributes:
|
attributes:
|
||||||
label: Steps to reproduce the problem
|
label: Steps to reproduce the problem
|
||||||
description: Please provide us with precise step by step information on how to reproduce the bug
|
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
||||||
value: |
|
value: |
|
||||||
1. Go to ....
|
1. Go to ....
|
||||||
2. Press ....
|
2. Press ....
|
||||||
@ -37,64 +37,14 @@ body:
|
|||||||
id: what-should
|
id: what-should
|
||||||
attributes:
|
attributes:
|
||||||
label: What should have happened?
|
label: What should have happened?
|
||||||
description: Tell what you think the normal behavior should be
|
description: Tell us what you think the normal behavior should be
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: input
|
- type: textarea
|
||||||
id: commit
|
id: sysinfo
|
||||||
attributes:
|
attributes:
|
||||||
label: Version or Commit where the problem happens
|
label: Sysinfo
|
||||||
description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)"
|
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: dropdown
|
|
||||||
id: py-version
|
|
||||||
attributes:
|
|
||||||
label: What Python version are you running on ?
|
|
||||||
multiple: false
|
|
||||||
options:
|
|
||||||
- Python 3.10.x
|
|
||||||
- Python 3.11.x (above, no supported yet)
|
|
||||||
- Python 3.9.x (below, no recommended)
|
|
||||||
- type: dropdown
|
|
||||||
id: platforms
|
|
||||||
attributes:
|
|
||||||
label: What platforms do you use to access the UI ?
|
|
||||||
multiple: true
|
|
||||||
options:
|
|
||||||
- Windows
|
|
||||||
- Linux
|
|
||||||
- MacOS
|
|
||||||
- iOS
|
|
||||||
- Android
|
|
||||||
- Other/Cloud
|
|
||||||
- type: dropdown
|
|
||||||
id: device
|
|
||||||
attributes:
|
|
||||||
label: What device are you running WebUI on?
|
|
||||||
multiple: true
|
|
||||||
options:
|
|
||||||
- Nvidia GPUs (RTX 20 above)
|
|
||||||
- Nvidia GPUs (GTX 16 below)
|
|
||||||
- AMD GPUs (RX 6000 above)
|
|
||||||
- AMD GPUs (RX 5000 below)
|
|
||||||
- CPU
|
|
||||||
- Other GPUs
|
|
||||||
- type: dropdown
|
|
||||||
id: cross_attention_opt
|
|
||||||
attributes:
|
|
||||||
label: Cross attention optimization
|
|
||||||
description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization
|
|
||||||
multiple: false
|
|
||||||
options:
|
|
||||||
- Automatic
|
|
||||||
- xformers
|
|
||||||
- sdp-no-mem
|
|
||||||
- sdp
|
|
||||||
- Doggettx
|
|
||||||
- V1
|
|
||||||
- InvokeAI
|
|
||||||
- "None "
|
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
@ -108,21 +58,7 @@ body:
|
|||||||
- Brave
|
- Brave
|
||||||
- Apple Safari
|
- Apple Safari
|
||||||
- Microsoft Edge
|
- Microsoft Edge
|
||||||
- type: textarea
|
- Other
|
||||||
id: cmdargs
|
|
||||||
attributes:
|
|
||||||
label: Command Line Arguments
|
|
||||||
description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
|
|
||||||
render: Shell
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
|
||||||
id: extensions
|
|
||||||
attributes:
|
|
||||||
label: List of extensions
|
|
||||||
description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise.
|
|
||||||
validations:
|
|
||||||
required: true
|
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: logs
|
id: logs
|
||||||
attributes:
|
attributes:
|
||||||
|
148
CHANGELOG.md
148
CHANGELOG.md
@ -1,3 +1,151 @@
|
|||||||
|
## 1.6.0
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)
|
||||||
|
* add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards
|
||||||
|
* add style editor dialog
|
||||||
|
* hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))
|
||||||
|
* option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))
|
||||||
|
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
||||||
|
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
||||||
|
* makes all of them work with img2img
|
||||||
|
* makes prompt composition posssible (AND)
|
||||||
|
* makes them available for SDXL
|
||||||
|
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
||||||
|
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
||||||
|
* textual inversion inference support for SDXL
|
||||||
|
* extra networks UI: show metadata for SD checkpoints
|
||||||
|
* checkpoint merger: add metadata support
|
||||||
|
* prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))
|
||||||
|
* VAE: allow selecting own VAE for each checkpoint (in user metadata editor)
|
||||||
|
* VAE: add selected VAE to infotext
|
||||||
|
* options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))
|
||||||
|
* add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723))
|
||||||
|
* change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it
|
||||||
|
* show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707))
|
||||||
|
* add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models
|
||||||
|
* prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457))
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))
|
||||||
|
* postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))
|
||||||
|
* XYZ: in the axis labels, remove pathnames from model filenames
|
||||||
|
* XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))
|
||||||
|
* XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))
|
||||||
|
* add gradio version warning
|
||||||
|
* sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))
|
||||||
|
* use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))
|
||||||
|
* move some settings to their own section: img2img, VAE
|
||||||
|
* add checkbox to show/hide dirs for extra networks
|
||||||
|
* Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))
|
||||||
|
* gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))
|
||||||
|
* sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))
|
||||||
|
* update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))
|
||||||
|
* option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))
|
||||||
|
* enable cond cache by default
|
||||||
|
* git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))
|
||||||
|
* allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))
|
||||||
|
* automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))
|
||||||
|
* put commonly used samplers on top, make DPM++ 2M Karras the default choice
|
||||||
|
* zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727))
|
||||||
|
* option to cache Lora networks in memory
|
||||||
|
* rework hires fix UI to use accordion
|
||||||
|
* face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back
|
||||||
|
* change quicksettings items to have variable width
|
||||||
|
* Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))
|
||||||
|
* Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
|
||||||
|
* support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))
|
||||||
|
* add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))
|
||||||
|
* support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584))
|
||||||
|
* make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634))
|
||||||
|
* configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648))
|
||||||
|
* make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645))
|
||||||
|
* more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639))
|
||||||
|
* make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635))
|
||||||
|
* make progress bar work independently from live preview display which results in it being updated a lot more often
|
||||||
|
* forbid Full live preview method for medvram and add a setting to undo the forbidding
|
||||||
|
* make it possible to localize tooltips and placeholders
|
||||||
|
* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
|
||||||
|
* Restore faces and Tiling generation parameters have been moved to settings out of main UI
|
||||||
|
* if you want to put them back into main UI, use `Options in main UI` setting on the UI page.
|
||||||
|
|
||||||
|
### Extensions and API:
|
||||||
|
* gradio 3.41.2
|
||||||
|
* also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd
|
||||||
|
* support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')
|
||||||
|
* properly clear the total console progressbar when using txt2img and img2img from API
|
||||||
|
* add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))
|
||||||
|
* shared.py and webui.py split into many files
|
||||||
|
* add --loglevel commandline argument for logging
|
||||||
|
* add a custom UI element that combines accordion and checkbox
|
||||||
|
* avoid importing gradio in tests because it spams warnings
|
||||||
|
* put infotext label for setting into OptionInfo definition rather than in a separate list
|
||||||
|
* make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))
|
||||||
|
* option to make scripts UI without gr.Group
|
||||||
|
* add a way for scripts to register a callback for before/after just a single component's creation
|
||||||
|
* use dataclass for StableDiffusionProcessing
|
||||||
|
* store patches for Lora in a specialized module instead of inside torch
|
||||||
|
* support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698))
|
||||||
|
* add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616))
|
||||||
|
* dump current stack traces when exiting with SIGINT
|
||||||
|
* add type annotations for extra fields of shared.sd_model
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* Don't crash if out of local storage quota for javascriot localStorage
|
||||||
|
* XYZ plot do not fail if an exception occurs
|
||||||
|
* fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))
|
||||||
|
* localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))
|
||||||
|
* fix sdxl model invalid configuration after the hijack
|
||||||
|
* correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))
|
||||||
|
* open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))
|
||||||
|
* prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))
|
||||||
|
* add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))
|
||||||
|
* fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))
|
||||||
|
* fix options in main UI misbehaving when there's just one element
|
||||||
|
* make it possible to use a sampler from infotext even if it's hidden in the dropdown
|
||||||
|
* fix styles missing from the prompt in infotext when making a grid of batch of multiplie images
|
||||||
|
* prevent bogus progress output in console when calculating hires fix dimensions
|
||||||
|
* fix --use-textbox-seed
|
||||||
|
* fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))
|
||||||
|
* properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))
|
||||||
|
* MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))
|
||||||
|
* add second_order to samplers that mistakenly didn't have it
|
||||||
|
* when refreshing cards in extra networks UI, do not discard user's custom resolution
|
||||||
|
* fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))
|
||||||
|
* fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588))
|
||||||
|
* fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586))
|
||||||
|
* fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596))
|
||||||
|
* auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603))
|
||||||
|
* fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607))
|
||||||
|
* fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638))
|
||||||
|
* fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633))
|
||||||
|
* attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630))
|
||||||
|
* implement missing undo hijack for SDXL
|
||||||
|
* fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684))
|
||||||
|
* fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689))
|
||||||
|
* fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685))
|
||||||
|
* fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667))
|
||||||
|
* create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717))
|
||||||
|
* prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version
|
||||||
|
* set devices.dtype_unet correctly
|
||||||
|
* run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||||
|
* prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))
|
||||||
|
* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
|
||||||
|
* fix defaults settings page breaking when any of main UI tabs are hidden
|
||||||
|
* fix incorrect save/display of new values in Defaults page in settings
|
||||||
|
* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
|
||||||
|
* fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||||
|
* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||||
|
* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))
|
||||||
|
* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
|
||||||
|
* fix a bug allowing users to bypass gradio and API authentication (reported by vysecurity)
|
||||||
|
* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
|
||||||
|
* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
|
||||||
|
* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
|
||||||
|
* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
|
||||||
|
* get progressbar to display correctly in extensions tab
|
||||||
|
|
||||||
|
|
||||||
## 1.5.2
|
## 1.5.2
|
||||||
|
|
||||||
### Bug Fixes:
|
### Bug Fixes:
|
||||||
|
7
CITATION.cff
Normal file
7
CITATION.cff
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
message: "If you use this software, please cite it as below."
|
||||||
|
authors:
|
||||||
|
- given-names: AUTOMATIC1111
|
||||||
|
title: "Stable Diffusion Web UI"
|
||||||
|
date-released: 2022-08-22
|
||||||
|
url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui"
|
14
README.md
14
README.md
@ -78,7 +78,7 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- Clip skip
|
- Clip skip
|
||||||
- Hypernetworks
|
- Hypernetworks
|
||||||
- Loras (same as Hypernetworks but more pretty)
|
- Loras (same as Hypernetworks but more pretty)
|
||||||
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
|
||||||
- Can select to load a different VAE from settings screen
|
- Can select to load a different VAE from settings screen
|
||||||
- Estimated completion time in progress bar
|
- Estimated completion time in progress bar
|
||||||
- API
|
- API
|
||||||
@ -88,12 +88,15 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
||||||
- Now without any bad letters!
|
- Now without any bad letters!
|
||||||
- Load checkpoints in safetensors format
|
- Load checkpoints in safetensors format
|
||||||
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
|
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
|
||||||
- Now with a license!
|
- Now with a license!
|
||||||
- Reorder elements in the UI from settings screen
|
- Reorder elements in the UI from settings screen
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||||
|
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
|
||||||
|
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
|
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
|
||||||
|
|
||||||
Alternatively, use online services (like Google Colab):
|
Alternatively, use online services (like Google Colab):
|
||||||
|
|
||||||
@ -115,7 +118,7 @@ Alternatively, use online services (like Google Colab):
|
|||||||
1. Install the dependencies:
|
1. Install the dependencies:
|
||||||
```bash
|
```bash
|
||||||
# Debian-based:
|
# Debian-based:
|
||||||
sudo apt install wget git python3 python3-venv
|
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||||
# Red Hat-based:
|
# Red Hat-based:
|
||||||
sudo dnf install wget git python3
|
sudo dnf install wget git python3
|
||||||
# Arch-based:
|
# Arch-based:
|
||||||
@ -123,7 +126,7 @@ sudo pacman -S wget git python3
|
|||||||
```
|
```
|
||||||
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
||||||
```bash
|
```bash
|
||||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
wget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh
|
||||||
```
|
```
|
||||||
3. Run `webui.sh`.
|
3. Run `webui.sh`.
|
||||||
4. Check `webui-user.sh` for options.
|
4. Check `webui-user.sh` for options.
|
||||||
@ -169,5 +172,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||||
- LyCORIS - KohakuBlueleaf
|
- LyCORIS - KohakuBlueleaf
|
||||||
|
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
|
|
||||||
|
self.errors = {}
|
||||||
|
"""mapping of network names to the number of errors the network had during operation"""
|
||||||
|
|
||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
|
||||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||||
|
|
||||||
def deactivate(self, p):
|
def deactivate(self, p):
|
||||||
pass
|
if self.errors:
|
||||||
|
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
31
extensions-builtin/Lora/lora_patches.py
Normal file
31
extensions-builtin/Lora/lora_patches.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
import networks
|
||||||
|
from modules import patches
|
||||||
|
|
||||||
|
|
||||||
|
class LoraPatches:
|
||||||
|
def __init__(self):
|
||||||
|
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
|
||||||
|
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
|
||||||
|
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
|
||||||
|
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
|
||||||
|
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
|
||||||
|
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
|
||||||
|
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
|
||||||
|
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
|
||||||
|
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
|
||||||
|
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
|
||||||
|
|
||||||
|
def undo(self):
|
||||||
|
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
|
||||||
|
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
|
||||||
|
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
|
||||||
|
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
|
||||||
|
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
|
||||||
|
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
|
||||||
|
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
|
||||||
|
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
|
||||||
|
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
|
||||||
|
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
|
||||||
|
|
@ -133,7 +133,7 @@ class NetworkModule:
|
|||||||
|
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
updown = updown.reshape(self.bias.shape)
|
updown = updown.reshape(self.bias.shape)
|
||||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
@ -145,7 +145,10 @@ class NetworkModule:
|
|||||||
if orig_weight.size().numel() == updown.size().numel():
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
updown = updown.reshape(orig_weight.shape)
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
return updown * self.calc_scale() * self.multiplier()
|
if ex_bias is not None:
|
||||||
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
||||||
|
|
||||||
def calc_updown(self, target):
|
def calc_updown(self, target):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -14,9 +14,14 @@ class NetworkModuleFull(network.NetworkModule):
|
|||||||
super().__init__(net, weights)
|
super().__init__(net, weights)
|
||||||
|
|
||||||
self.weight = weights.w.get("diff")
|
self.weight = weights.w.get("diff")
|
||||||
|
self.ex_bias = weights.w.get("diff_b")
|
||||||
|
|
||||||
def calc_updown(self, orig_weight):
|
def calc_updown(self, orig_weight):
|
||||||
output_shape = self.weight.shape
|
output_shape = self.weight.shape
|
||||||
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
if self.ex_bias is not None:
|
||||||
|
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
||||||
|
28
extensions-builtin/Lora/network_norm.py
Normal file
28
extensions-builtin/Lora/network_norm.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeNorm(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["w_norm", "b_norm"]):
|
||||||
|
return NetworkModuleNorm(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleNorm(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.w_norm = weights.w.get("w_norm")
|
||||||
|
self.b_norm = weights.w.get("b_norm")
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
output_shape = self.w_norm.shape
|
||||||
|
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
if self.b_norm is not None:
|
||||||
|
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
@ -1,12 +1,15 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import lora_patches
|
||||||
import network
|
import network
|
||||||
import network_lora
|
import network_lora
|
||||||
import network_hada
|
import network_hada
|
||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
|
import network_norm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -19,6 +22,7 @@ module_types = [
|
|||||||
network_ia3.ModuleTypeIa3(),
|
network_ia3.ModuleTypeIa3(),
|
||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
|
network_norm.ModuleTypeNorm(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +35,8 @@ suffix_conversion = {
|
|||||||
"resnets": {
|
"resnets": {
|
||||||
"conv1": "in_layers_2",
|
"conv1": "in_layers_2",
|
||||||
"conv2": "out_layers_3",
|
"conv2": "out_layers_3",
|
||||||
|
"norm1": "in_layers_0",
|
||||||
|
"norm2": "out_layers_0",
|
||||||
"time_emb_proj": "emb_layers_1",
|
"time_emb_proj": "emb_layers_1",
|
||||||
"conv_shortcut": "skip_connection",
|
"conv_shortcut": "skip_connection",
|
||||||
}
|
}
|
||||||
@ -190,11 +196,19 @@ def load_network(name, network_on_disk):
|
|||||||
net.modules[key] = net_module
|
net.modules[key] = net_module
|
||||||
|
|
||||||
if keys_failed_to_match:
|
if keys_failed_to_match:
|
||||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||||
|
|
||||||
return net
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def purge_networks_from_memory():
|
||||||
|
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
|
||||||
|
name = next(iter(networks_in_memory))
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
@ -212,15 +226,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
failed_to_load_networks = []
|
failed_to_load_networks = []
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
|
||||||
net = already_loaded.get(name, None)
|
net = already_loaded.get(name, None)
|
||||||
|
|
||||||
network_on_disk = networks_on_disk[i]
|
|
||||||
|
|
||||||
if network_on_disk is not None:
|
if network_on_disk is not None:
|
||||||
|
if net is None:
|
||||||
|
net = networks_in_memory.get(name)
|
||||||
|
|
||||||
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
|
||||||
try:
|
try:
|
||||||
net = load_network(name, network_on_disk)
|
net = load_network(name, network_on_disk)
|
||||||
|
|
||||||
|
networks_in_memory.pop(name, None)
|
||||||
|
networks_in_memory[name] = net
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"loading network {network_on_disk.filename}")
|
errors.display(e, f"loading network {network_on_disk.filename}")
|
||||||
continue
|
continue
|
||||||
@ -231,7 +249,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
if net is None:
|
if net is None:
|
||||||
failed_to_load_networks.append(name)
|
failed_to_load_networks.append(name)
|
||||||
print(f"Couldn't find network with name {name}")
|
logging.info(f"Couldn't find network with name {name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||||
@ -240,23 +258,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
|
purge_networks_from_memory()
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
|
||||||
if weights_backup is None:
|
if weights_backup is None and bias_backup is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if weights_backup is not None:
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
if bias_backup is not None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias.copy_(bias_backup)
|
||||||
|
else:
|
||||||
|
self.bias.copy_(bias_backup)
|
||||||
else:
|
else:
|
||||||
self.weight.copy_(weights_backup)
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias = None
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
|
||||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of networks to the weights of torch layer self.
|
Applies the currently selected set of networks to the weights of torch layer self.
|
||||||
If weights already have this particular set of networks applied, does nothing.
|
If weights already have this particular set of networks applied, does nothing.
|
||||||
@ -271,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
|
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
|
||||||
|
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
if weights_backup is None:
|
if weights_backup is None and wanted_names != ():
|
||||||
|
if current_names != ():
|
||||||
|
raise RuntimeError("no backup weights found and current weights are not unchanged")
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||||
else:
|
else:
|
||||||
@ -279,21 +315,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
|
|
||||||
self.network_weights_backup = weights_backup
|
self.network_weights_backup = weights_backup
|
||||||
|
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
if bias_backup is None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||||
|
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
|
||||||
|
elif getattr(self, 'bias', None) is not None:
|
||||||
|
bias_backup = self.bias.to(devices.cpu, copy=True)
|
||||||
|
else:
|
||||||
|
bias_backup = None
|
||||||
|
self.network_bias_backup = bias_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
network_restore_weights_from_backup(self)
|
network_restore_weights_from_backup(self)
|
||||||
|
|
||||||
for net in loaded_networks:
|
for net in loaded_networks:
|
||||||
module = net.modules.get(network_layer_name, None)
|
module = net.modules.get(network_layer_name, None)
|
||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown = module.calc_updown(self.weight)
|
with torch.no_grad():
|
||||||
|
updown, ex_bias = module.calc_updown(self.weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight += updown
|
||||||
continue
|
if ex_bias is not None and hasattr(self, 'bias'):
|
||||||
|
if self.bias is None:
|
||||||
|
self.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.bias += ex_bias
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||||
@ -301,21 +357,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
with torch.no_grad():
|
||||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
|
||||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
|
||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
|
||||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
|
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
||||||
|
|
||||||
self.in_proj_weight += updown_qkv
|
self.in_proj_weight += updown_qkv
|
||||||
self.out_proj.weight += updown_out
|
self.out_proj.weight += updown_out
|
||||||
continue
|
if ex_bias is not None:
|
||||||
|
if self.out_proj.bias is None:
|
||||||
|
self.out_proj.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.out_proj.bias += ex_bias
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
self.network_current_names = wanted_names
|
self.network_current_names = wanted_names
|
||||||
|
|
||||||
@ -342,7 +410,7 @@ def network_forward(module, input, original_forward):
|
|||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
y = module.forward(y, input)
|
y = module.forward(input, y)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -354,44 +422,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
|||||||
|
|
||||||
def network_Linear_forward(self, input):
|
def network_Linear_forward(self, input):
|
||||||
if shared.opts.lora_functional:
|
if shared.opts.lora_functional:
|
||||||
return network_forward(self, input, torch.nn.Linear_forward_before_network)
|
return network_forward(self, input, originals.Linear_forward)
|
||||||
|
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_network(self, input)
|
return originals.Linear_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
def network_Linear_load_state_dict(self, *args, **kwargs):
|
def network_Linear_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.Linear_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_Conv2d_forward(self, input):
|
def network_Conv2d_forward(self, input):
|
||||||
if shared.opts.lora_functional:
|
if shared.opts.lora_functional:
|
||||||
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
|
return network_forward(self, input, originals.Conv2d_forward)
|
||||||
|
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_forward_before_network(self, input)
|
return originals.Conv2d_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.Conv2d_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, originals.GroupNorm_forward)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return originals.GroupNorm_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, originals.LayerNorm_forward)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return originals.LayerNorm_forward(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
|
return originals.MultiheadAttention_forward(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||||
network_reset_cached_weight(self)
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
|
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def list_available_networks():
|
def list_available_networks():
|
||||||
@ -459,9 +557,14 @@ def infotext_pasted(infotext, params):
|
|||||||
params["Prompt"] += "\n" + "".join(added)
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
|
|
||||||
|
originals: lora_patches.LoraPatches = None
|
||||||
|
|
||||||
|
extra_network_lora = None
|
||||||
|
|
||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
loaded_networks = []
|
loaded_networks = []
|
||||||
|
networks_in_memory = {}
|
||||||
available_network_hash_lookup = {}
|
available_network_hash_lookup = {}
|
||||||
forbidden_network_aliases = {}
|
forbidden_network_aliases = {}
|
||||||
|
|
||||||
|
@ -1,57 +1,30 @@
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
import network
|
import network
|
||||||
import networks
|
import networks
|
||||||
import lora # noqa:F401
|
import lora # noqa:F401
|
||||||
|
import lora_patches
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
|
networks.originals.undo()
|
||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
|
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
|
|
||||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
|
|
||||||
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
|
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
|
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||||
|
|
||||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||||
extra_networks.register_extra_network(extra_network)
|
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
networks.originals = lora_patches.LoraPatches()
|
||||||
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
|
|
||||||
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
|
||||||
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
|
||||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
|
||||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
|
|
||||||
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
|
|
||||||
|
|
||||||
torch.nn.Linear.forward = networks.network_Linear_forward
|
|
||||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
|
||||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
|
||||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
|
||||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
@ -65,6 +38,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
|||||||
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
|
||||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||||
|
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@ -121,3 +95,5 @@ def infotext_pasted(infotext, d):
|
|||||||
|
|
||||||
|
|
||||||
script_callbacks.on_infotext_pasted(infotext_pasted)
|
script_callbacks.on_infotext_pasted(infotext_pasted)
|
||||||
|
|
||||||
|
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
|
||||||
|
@ -70,6 +70,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
|||||||
metadata = item.get("metadata") or {}
|
metadata = item.get("metadata") or {}
|
||||||
|
|
||||||
keys = {
|
keys = {
|
||||||
|
'ss_output_name': "Output name:",
|
||||||
'ss_sd_model_name': "Model:",
|
'ss_sd_model_name': "Model:",
|
||||||
'ss_clip_skip': "Clip skip:",
|
'ss_clip_skip': "Clip skip:",
|
||||||
'ss_network_module': "Kohya module:",
|
'ss_network_module': "Kohya module:",
|
||||||
@ -167,7 +168,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
|||||||
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
|
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
|
||||||
|
|
||||||
with gr.Column(scale=1, min_width=120):
|
with gr.Column(scale=1, min_width=120):
|
||||||
generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
|
generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
|
||||||
|
|
||||||
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|
||||||
|
|
||||||
|
@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
item = {
|
item = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": lora_on_disk.filename,
|
"filename": lora_on_disk.filename,
|
||||||
|
"shorthash": lora_on_disk.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": lora_on_disk.metadata,
|
"metadata": lora_on_disk.metadata,
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||||
|
@ -12,8 +12,22 @@ onUiLoaded(async() => {
|
|||||||
"Sketch": elementIDs.sketch
|
"Sketch": elementIDs.sketch
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
// Get active tab
|
// Get active tab
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Waits for an element to be present in the DOM.
|
||||||
|
*/
|
||||||
|
const waitForElement = (id) => new Promise(resolve => {
|
||||||
|
const checkForElement = () => {
|
||||||
|
const element = document.querySelector(id);
|
||||||
|
if (element) return resolve(element);
|
||||||
|
setTimeout(checkForElement, 100);
|
||||||
|
};
|
||||||
|
checkForElement();
|
||||||
|
});
|
||||||
|
|
||||||
function getActiveTab(elements, all = false) {
|
function getActiveTab(elements, all = false) {
|
||||||
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||||
|
|
||||||
@ -34,7 +48,7 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Wait until opts loaded
|
// Wait until opts loaded
|
||||||
async function waitForOpts() {
|
async function waitForOpts() {
|
||||||
for (;;) {
|
for (; ;) {
|
||||||
if (window.opts && Object.keys(window.opts).length) {
|
if (window.opts && Object.keys(window.opts).length) {
|
||||||
return window.opts;
|
return window.opts;
|
||||||
}
|
}
|
||||||
@ -42,6 +56,11 @@ onUiLoaded(async() => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Detect whether the element has a horizontal scroll bar
|
||||||
|
function hasHorizontalScrollbar(element) {
|
||||||
|
return element.scrollWidth > element.clientWidth;
|
||||||
|
}
|
||||||
|
|
||||||
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
// Function for defining the "Ctrl", "Shift" and "Alt" keys
|
||||||
function isModifierKey(event, key) {
|
function isModifierKey(event, key) {
|
||||||
switch (key) {
|
switch (key) {
|
||||||
@ -201,7 +220,8 @@ onUiLoaded(async() => {
|
|||||||
canvas_hotkey_overlap: "KeyO",
|
canvas_hotkey_overlap: "KeyO",
|
||||||
canvas_disabled_functions: [],
|
canvas_disabled_functions: [],
|
||||||
canvas_show_tooltip: true,
|
canvas_show_tooltip: true,
|
||||||
canvas_blur_prompt: false
|
canvas_auto_expand: true,
|
||||||
|
canvas_blur_prompt: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
const functionMap = {
|
const functionMap = {
|
||||||
@ -249,7 +269,7 @@ onUiLoaded(async() => {
|
|||||||
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
input?.addEventListener("input", () => restoreImgRedMask(elements));
|
||||||
}
|
}
|
||||||
|
|
||||||
function applyZoomAndPan(elemId) {
|
function applyZoomAndPan(elemId, isExtension = true) {
|
||||||
const targetElement = gradioApp().querySelector(elemId);
|
const targetElement = gradioApp().querySelector(elemId);
|
||||||
|
|
||||||
if (!targetElement) {
|
if (!targetElement) {
|
||||||
@ -361,6 +381,12 @@ onUiLoaded(async() => {
|
|||||||
panY: 0
|
panY: 0
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "hidden";
|
||||||
|
}
|
||||||
|
|
||||||
|
targetElement.isZoomed = false;
|
||||||
|
|
||||||
fixCanvas();
|
fixCanvas();
|
||||||
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;
|
||||||
|
|
||||||
@ -371,8 +397,27 @@ onUiLoaded(async() => {
|
|||||||
toggleOverlap("off");
|
toggleOverlap("off");
|
||||||
fullScreenMode = false;
|
fullScreenMode = false;
|
||||||
|
|
||||||
|
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
|
||||||
|
if (closeBtn) {
|
||||||
|
closeBtn.addEventListener("click", resetZoom);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (canvas && isExtension) {
|
||||||
|
const parentElement = targetElement.closest('[id^="component-"]');
|
||||||
|
if (
|
||||||
|
canvas &&
|
||||||
|
parseFloat(canvas.style.width) > parentElement.offsetWidth &&
|
||||||
|
parseFloat(targetElement.style.width) > parentElement.offsetWidth
|
||||||
|
) {
|
||||||
|
fitToElement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
canvas &&
|
canvas &&
|
||||||
|
!isExtension &&
|
||||||
parseFloat(canvas.style.width) > 865 &&
|
parseFloat(canvas.style.width) > 865 &&
|
||||||
parseFloat(targetElement.style.width) > 865
|
parseFloat(targetElement.style.width) > 865
|
||||||
) {
|
) {
|
||||||
@ -381,9 +426,6 @@ onUiLoaded(async() => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
targetElement.style.width = "";
|
targetElement.style.width = "";
|
||||||
if (canvas) {
|
|
||||||
targetElement.style.height = canvas.style.height;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
// Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements
|
||||||
@ -439,7 +481,7 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
// Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables
|
||||||
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
function updateZoom(newZoomLevel, mouseX, mouseY) {
|
||||||
newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15));
|
newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15));
|
||||||
|
|
||||||
elemData[elemId].panX +=
|
elemData[elemId].panX +=
|
||||||
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
|
mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;
|
||||||
@ -450,6 +492,10 @@ onUiLoaded(async() => {
|
|||||||
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;
|
||||||
|
|
||||||
toggleOverlap("on");
|
toggleOverlap("on");
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
|
}
|
||||||
|
|
||||||
return newZoomLevel;
|
return newZoomLevel;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -472,10 +518,12 @@ onUiLoaded(async() => {
|
|||||||
fullScreenMode = false;
|
fullScreenMode = false;
|
||||||
elemData[elemId].zoomLevel = updateZoom(
|
elemData[elemId].zoomLevel = updateZoom(
|
||||||
elemData[elemId].zoomLevel +
|
elemData[elemId].zoomLevel +
|
||||||
(operation === "+" ? delta : -delta),
|
(operation === "+" ? delta : -delta),
|
||||||
zoomPosX - targetElement.getBoundingClientRect().left,
|
zoomPosX - targetElement.getBoundingClientRect().left,
|
||||||
zoomPosY - targetElement.getBoundingClientRect().top
|
zoomPosY - targetElement.getBoundingClientRect().top
|
||||||
);
|
);
|
||||||
|
|
||||||
|
targetElement.isZoomed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -489,10 +537,19 @@ onUiLoaded(async() => {
|
|||||||
//Reset Zoom
|
//Reset Zoom
|
||||||
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;
|
||||||
|
|
||||||
|
let parentElement;
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
parentElement = targetElement.closest('[id^="component-"]');
|
||||||
|
} else {
|
||||||
|
parentElement = targetElement.parentElement;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// Get element and screen dimensions
|
// Get element and screen dimensions
|
||||||
const elementWidth = targetElement.offsetWidth;
|
const elementWidth = targetElement.offsetWidth;
|
||||||
const elementHeight = targetElement.offsetHeight;
|
const elementHeight = targetElement.offsetHeight;
|
||||||
const parentElement = targetElement.parentElement;
|
|
||||||
const screenWidth = parentElement.clientWidth;
|
const screenWidth = parentElement.clientWidth;
|
||||||
const screenHeight = parentElement.clientHeight;
|
const screenHeight = parentElement.clientHeight;
|
||||||
|
|
||||||
@ -545,8 +602,12 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
if (!canvas) return;
|
if (!canvas) return;
|
||||||
|
|
||||||
if (canvas.offsetWidth > 862) {
|
if (canvas.offsetWidth > 862 || isExtension) {
|
||||||
targetElement.style.width = canvas.offsetWidth + "px";
|
targetElement.style.width = (canvas.offsetWidth + 2) + "px";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (fullScreenMode) {
|
if (fullScreenMode) {
|
||||||
@ -648,8 +709,48 @@ onUiLoaded(async() => {
|
|||||||
mouseY = e.offsetY;
|
mouseY = e.offsetY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simulation of the function to put a long image into the screen.
|
||||||
|
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
|
||||||
|
// We hide the image and show it to the user when it is ready.
|
||||||
|
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
function autoExpand() {
|
||||||
|
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
|
||||||
|
if (canvas) {
|
||||||
|
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
|
||||||
|
targetElement.style.visibility = "hidden";
|
||||||
|
setTimeout(() => {
|
||||||
|
fitToScreen();
|
||||||
|
resetZoom();
|
||||||
|
targetElement.style.visibility = "visible";
|
||||||
|
targetElement.isExpanded = true;
|
||||||
|
}, 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
targetElement.addEventListener("mousemove", getMousePosition);
|
targetElement.addEventListener("mousemove", getMousePosition);
|
||||||
|
|
||||||
|
//observers
|
||||||
|
// Creating an observer with a callback function to handle DOM changes
|
||||||
|
const observer = new MutationObserver((mutationsList, observer) => {
|
||||||
|
for (let mutation of mutationsList) {
|
||||||
|
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
|
||||||
|
mutation.target.tagName.toLowerCase() === 'canvas') {
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
setTimeout(resetZoom, 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Apply auto expand if enabled
|
||||||
|
if (hotkeysConfig.canvas_auto_expand) {
|
||||||
|
targetElement.addEventListener("mousemove", autoExpand);
|
||||||
|
// Set up an observer to track attribute changes
|
||||||
|
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
|
||||||
|
}
|
||||||
|
|
||||||
// Handle events only inside the targetElement
|
// Handle events only inside the targetElement
|
||||||
let isKeyDownHandlerAttached = false;
|
let isKeyDownHandlerAttached = false;
|
||||||
|
|
||||||
@ -754,6 +855,11 @@ onUiLoaded(async() => {
|
|||||||
if (isMoving && elemId === activeElement) {
|
if (isMoving && elemId === activeElement) {
|
||||||
updatePanPosition(e.movementX, e.movementY);
|
updatePanPosition(e.movementX, e.movementY);
|
||||||
targetElement.style.pointerEvents = "none";
|
targetElement.style.pointerEvents = "none";
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.style.overflow = "visible";
|
||||||
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
targetElement.style.pointerEvents = "auto";
|
targetElement.style.pointerEvents = "auto";
|
||||||
}
|
}
|
||||||
@ -764,13 +870,93 @@ onUiLoaded(async() => {
|
|||||||
isMoving = false;
|
isMoving = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Checks for extension
|
||||||
|
function checkForOutBox() {
|
||||||
|
const parentElement = targetElement.closest('[id^="component-"]');
|
||||||
|
if (parentElement.offsetWidth < targetElement.offsetWidth && !targetElement.isExpanded) {
|
||||||
|
resetZoom();
|
||||||
|
targetElement.isExpanded = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) {
|
||||||
|
resetZoom();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) {
|
||||||
|
resetZoom();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.addEventListener("mousemove", checkForOutBox);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
window.addEventListener('resize', (e) => {
|
||||||
|
resetZoom();
|
||||||
|
|
||||||
|
if (isExtension) {
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
targetElement.isZoomed = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
gradioApp().addEventListener("mousemove", handleMoveByKey);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
applyZoomAndPan(elementIDs.sketch);
|
applyZoomAndPan(elementIDs.sketch, false);
|
||||||
applyZoomAndPan(elementIDs.inpaint);
|
applyZoomAndPan(elementIDs.inpaint, false);
|
||||||
applyZoomAndPan(elementIDs.inpaintSketch);
|
applyZoomAndPan(elementIDs.inpaintSketch, false);
|
||||||
|
|
||||||
// Make the function global so that other extensions can take advantage of this solution
|
// Make the function global so that other extensions can take advantage of this solution
|
||||||
window.applyZoomAndPan = applyZoomAndPan;
|
const applyZoomAndPanIntegration = async(id, elementIDs) => {
|
||||||
|
const mainEl = document.querySelector(id);
|
||||||
|
if (id.toLocaleLowerCase() === "none") {
|
||||||
|
for (const elementID of elementIDs) {
|
||||||
|
const el = await waitForElement(elementID);
|
||||||
|
if (!el) break;
|
||||||
|
applyZoomAndPan(elementID);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mainEl) return;
|
||||||
|
mainEl.addEventListener("click", async() => {
|
||||||
|
for (const elementID of elementIDs) {
|
||||||
|
const el = await waitForElement(elementID);
|
||||||
|
if (!el) break;
|
||||||
|
applyZoomAndPan(elementID);
|
||||||
|
}
|
||||||
|
}, {once: true});
|
||||||
|
};
|
||||||
|
|
||||||
|
window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image")
|
||||||
|
|
||||||
|
window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // for any extension
|
||||||
|
|
||||||
|
/*
|
||||||
|
The function `applyZoomAndPanIntegration` takes two arguments:
|
||||||
|
|
||||||
|
1. `id`: A string identifier for the element to which zoom and pan functionality will be applied on click.
|
||||||
|
If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event.
|
||||||
|
|
||||||
|
2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument.
|
||||||
|
If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||||
|
In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet".
|
||||||
|
*/
|
||||||
|
|
||||||
|
// More examples
|
||||||
|
// Add integration with ControlNet txt2img One TAB
|
||||||
|
// applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]);
|
||||||
|
|
||||||
|
// Add integration with ControlNet txt2img Tabs
|
||||||
|
// applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));
|
||||||
|
|
||||||
|
// Add integration with Inpaint Anything
|
||||||
|
// applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]);
|
||||||
});
|
});
|
||||||
|
@ -9,6 +9,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
|||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
|
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
||||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
}))
|
}))
|
||||||
|
@ -61,3 +61,6 @@
|
|||||||
to {opacity: 1;}
|
to {opacity: 1;}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.styler {
|
||||||
|
overflow:inherit !important;
|
||||||
|
}
|
@ -1,5 +1,7 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import scripts, shared, ui_components, ui_settings
|
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
|
||||||
from modules.ui_components import FormColumn
|
from modules.ui_components import FormColumn
|
||||||
|
|
||||||
|
|
||||||
@ -19,18 +21,38 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
self.comps = []
|
self.comps = []
|
||||||
self.setting_names = []
|
self.setting_names = []
|
||||||
|
self.infotext_fields = []
|
||||||
|
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||||
|
|
||||||
|
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||||
|
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row():
|
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
||||||
for setting_name in shared.opts.extra_options:
|
|
||||||
with FormColumn():
|
|
||||||
comp = ui_settings.create_setting_component(setting_name)
|
|
||||||
|
|
||||||
self.comps.append(comp)
|
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||||
self.setting_names.append(setting_name)
|
|
||||||
|
for row in range(row_count):
|
||||||
|
with gr.Row():
|
||||||
|
for col in range(shared.opts.extra_options_cols):
|
||||||
|
index = row * shared.opts.extra_options_cols + col
|
||||||
|
if index >= len(extra_options):
|
||||||
|
break
|
||||||
|
|
||||||
|
setting_name = extra_options[index]
|
||||||
|
|
||||||
|
with FormColumn():
|
||||||
|
comp = ui_settings.create_setting_component(setting_name)
|
||||||
|
|
||||||
|
self.comps.append(comp)
|
||||||
|
self.setting_names.append(setting_name)
|
||||||
|
|
||||||
|
setting_infotext_name = mapping.get(setting_name)
|
||||||
|
if setting_infotext_name is not None:
|
||||||
|
self.infotext_fields.append((comp, setting_infotext_name))
|
||||||
|
|
||||||
def get_settings_values():
|
def get_settings_values():
|
||||||
return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]
|
||||||
|
return res[0] if len(res) == 1 else res
|
||||||
|
|
||||||
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
|
||||||
|
|
||||||
@ -43,6 +65,10 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
||||||
"extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
|
"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||||
"extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
|
"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||||
|
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
|
||||||
|
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,13 @@ function reportWindowSize() {
|
|||||||
var button = gradioApp().getElementById(tab + '_generate_box');
|
var button = gradioApp().getElementById(tab + '_generate_box');
|
||||||
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
|
||||||
target.insertBefore(button, target.firstElementChild);
|
target.insertBefore(button, target.firstElementChild);
|
||||||
|
|
||||||
|
gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
window.addEventListener("resize", reportWindowSize);
|
window.addEventListener("resize", reportWindowSize);
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
reportWindowSize();
|
||||||
|
});
|
||||||
|
@ -33,7 +33,7 @@ function extensions_check() {
|
|||||||
|
|
||||||
|
|
||||||
var id = randomId();
|
var id = randomId();
|
||||||
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
|
requestProgress(id, gradioApp().getElementById('extensions_installed_html'), null, function() {
|
||||||
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,20 +1,38 @@
|
|||||||
|
function toggleCss(key, css, enable) {
|
||||||
|
var style = document.getElementById(key);
|
||||||
|
if (enable && !style) {
|
||||||
|
style = document.createElement('style');
|
||||||
|
style.id = key;
|
||||||
|
style.type = 'text/css';
|
||||||
|
document.head.appendChild(style);
|
||||||
|
}
|
||||||
|
if (style && !enable) {
|
||||||
|
document.head.removeChild(style);
|
||||||
|
}
|
||||||
|
if (style) {
|
||||||
|
style.innerHTML == '';
|
||||||
|
style.appendChild(document.createTextNode(css));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function setupExtraNetworksForTab(tabname) {
|
function setupExtraNetworksForTab(tabname) {
|
||||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||||
|
|
||||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||||
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
|
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
||||||
|
var search = searchDiv.querySelector('textarea');
|
||||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
||||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
||||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||||
|
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||||
|
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||||
|
|
||||||
search.classList.add('search');
|
|
||||||
sort.classList.add('sort');
|
|
||||||
sortOrder.classList.add('sortorder');
|
|
||||||
sort.dataset.sortkey = 'sortDefault';
|
sort.dataset.sortkey = 'sortDefault';
|
||||||
tabs.appendChild(search);
|
tabs.appendChild(searchDiv);
|
||||||
tabs.appendChild(sort);
|
tabs.appendChild(sort);
|
||||||
tabs.appendChild(sortOrder);
|
tabs.appendChild(sortOrder);
|
||||||
tabs.appendChild(refresh);
|
tabs.appendChild(refresh);
|
||||||
|
tabs.appendChild(showDirsDiv);
|
||||||
|
|
||||||
var applyFilter = function() {
|
var applyFilter = function() {
|
||||||
var searchTerm = search.value.toLowerCase();
|
var searchTerm = search.value.toLowerCase();
|
||||||
@ -80,6 +98,15 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
|
|
||||||
|
var showDirsUpdate = function() {
|
||||||
|
var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
|
||||||
|
toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
|
||||||
|
localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
|
||||||
|
};
|
||||||
|
showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
|
||||||
|
showDirs.addEventListener("change", showDirsUpdate);
|
||||||
|
showDirsUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
function applyExtraNetworkFilter(tabname) {
|
function applyExtraNetworkFilter(tabname) {
|
||||||
@ -179,7 +206,7 @@ function saveCardPreview(event, tabname, filename) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event) {
|
function extraNetworksSearchButton(tabs_id, event) {
|
||||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
|
||||||
var button = event.target;
|
var button = event.target;
|
||||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||||
|
|
||||||
@ -222,6 +249,15 @@ function popup(contents) {
|
|||||||
globalPopup.style.display = "flex";
|
globalPopup.style.display = "flex";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var storedPopupIds = {};
|
||||||
|
function popupId(id) {
|
||||||
|
if (!storedPopupIds[id]) {
|
||||||
|
storedPopupIds[id] = gradioApp().getElementById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
popup(storedPopupIds[id]);
|
||||||
|
}
|
||||||
|
|
||||||
function extraNetworksShowMetadata(text) {
|
function extraNetworksShowMetadata(text) {
|
||||||
var elem = document.createElement('pre');
|
var elem = document.createElement('pre');
|
||||||
elem.classList.add('popup-metadata');
|
elem.classList.add('popup-metadata');
|
||||||
@ -305,7 +341,7 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
|
|||||||
newDiv.innerHTML = data.html;
|
newDiv.innerHTML = data.html;
|
||||||
var newCard = newDiv.firstElementChild;
|
var newCard = newDiv.firstElementChild;
|
||||||
|
|
||||||
newCard.style = '';
|
newCard.style.display = '';
|
||||||
card.parentElement.insertBefore(newCard, card);
|
card.parentElement.insertBefore(newCard, card);
|
||||||
card.parentElement.removeChild(card);
|
card.parentElement.removeChild(card);
|
||||||
}
|
}
|
||||||
|
@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
|
|||||||
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
|
tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
for (var comp of window.gradio_config.components) {
|
||||||
|
if (comp.props.webui_tooltip && comp.props.elem_id) {
|
||||||
|
var elem = gradioApp().getElementById(comp.props.elem_id);
|
||||||
|
if (elem) {
|
||||||
|
elem.title = comp.props.webui_tooltip;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
@ -136,6 +136,11 @@ function setupImageForLightbox(e) {
|
|||||||
var event = isFirefox ? 'mousedown' : 'click';
|
var event = isFirefox ? 'mousedown' : 'click';
|
||||||
|
|
||||||
e.addEventListener(event, function(evt) {
|
e.addEventListener(event, function(evt) {
|
||||||
|
if (evt.button == 1) {
|
||||||
|
open(evt.target.src);
|
||||||
|
evt.preventDefault();
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
|
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||||
|
37
javascript/inputAccordion.js
Normal file
37
javascript/inputAccordion.js
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutationRecord) {
|
||||||
|
var elem = mutationRecord.target;
|
||||||
|
var open = elem.classList.contains('open');
|
||||||
|
|
||||||
|
var accordion = elem.parentNode;
|
||||||
|
accordion.classList.toggle('input-accordion-open', open);
|
||||||
|
|
||||||
|
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||||
|
checkbox.checked = open;
|
||||||
|
updateInput(checkbox);
|
||||||
|
|
||||||
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
|
if (extra) {
|
||||||
|
extra.style.display = open ? "" : "none";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
function inputAccordionChecked(id, checked) {
|
||||||
|
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
||||||
|
if (label.classList.contains('open') != checked) {
|
||||||
|
label.click();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||||
|
var labelWrap = accordion.querySelector('.label-wrap');
|
||||||
|
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||||
|
|
||||||
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
|
if (extra) {
|
||||||
|
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
26
javascript/localStorage.js
Normal file
26
javascript/localStorage.js
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
|
||||||
|
function localSet(k, v) {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(k, v);
|
||||||
|
} catch (e) {
|
||||||
|
console.warn(`Failed to save ${k} to localStorage: ${e}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function localGet(k, def) {
|
||||||
|
try {
|
||||||
|
return localStorage.getItem(k);
|
||||||
|
} catch (e) {
|
||||||
|
console.warn(`Failed to load ${k} from localStorage: ${e}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
return def;
|
||||||
|
}
|
||||||
|
|
||||||
|
function localRemove(k) {
|
||||||
|
try {
|
||||||
|
return localStorage.removeItem(k);
|
||||||
|
} catch (e) {
|
||||||
|
console.warn(`Failed to remove ${k} from localStorage: ${e}`);
|
||||||
|
}
|
||||||
|
}
|
@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
|
|||||||
train_hypernetwork: 'OPTION',
|
train_hypernetwork: 'OPTION',
|
||||||
txt2img_styles: 'OPTION',
|
txt2img_styles: 'OPTION',
|
||||||
img2img_styles: 'OPTION',
|
img2img_styles: 'OPTION',
|
||||||
setting_random_artist_categories: 'SPAN',
|
setting_random_artist_categories: 'OPTION',
|
||||||
setting_face_restoration_model: 'SPAN',
|
setting_face_restoration_model: 'OPTION',
|
||||||
setting_realesrgan_enabled_models: 'SPAN',
|
setting_realesrgan_enabled_models: 'OPTION',
|
||||||
extras_upscaler_1: 'SPAN',
|
extras_upscaler_1: 'OPTION',
|
||||||
extras_upscaler_2: 'SPAN',
|
extras_upscaler_2: 'OPTION',
|
||||||
};
|
};
|
||||||
|
|
||||||
var re_num = /^[.\d]+$/;
|
var re_num = /^[.\d]+$/;
|
||||||
@ -107,12 +107,41 @@ function processNode(node) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function localizeWholePage() {
|
||||||
|
processNode(gradioApp());
|
||||||
|
|
||||||
|
function elem(comp) {
|
||||||
|
var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id;
|
||||||
|
return gradioApp().getElementById(elem_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (var comp of window.gradio_config.components) {
|
||||||
|
if (comp.props.webui_tooltip) {
|
||||||
|
let e = elem(comp);
|
||||||
|
|
||||||
|
let tl = e ? getTranslation(e.title) : undefined;
|
||||||
|
if (tl !== undefined) {
|
||||||
|
e.title = tl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (comp.props.placeholder) {
|
||||||
|
let e = elem(comp);
|
||||||
|
let textbox = e ? e.querySelector('[placeholder]') : null;
|
||||||
|
|
||||||
|
let tl = textbox ? getTranslation(textbox.placeholder) : undefined;
|
||||||
|
if (tl !== undefined) {
|
||||||
|
textbox.placeholder = tl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function dumpTranslations() {
|
function dumpTranslations() {
|
||||||
if (!hasLocalization()) {
|
if (!hasLocalization()) {
|
||||||
// If we don't have any localization,
|
// If we don't have any localization,
|
||||||
// we will not have traversed the app to find
|
// we will not have traversed the app to find
|
||||||
// original_lines, so do that now.
|
// original_lines, so do that now.
|
||||||
processNode(gradioApp());
|
localizeWholePage();
|
||||||
}
|
}
|
||||||
var dumped = {};
|
var dumped = {};
|
||||||
if (localization.rtl) {
|
if (localization.rtl) {
|
||||||
@ -154,7 +183,7 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
processNode(gradioApp());
|
localizeWholePage();
|
||||||
|
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||||
|
@ -15,7 +15,7 @@ onAfterUiUpdate(function() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
|
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"] div[id$="_results"] .thumbnail-item > img');
|
||||||
|
|
||||||
if (galleryPreviews == null) return;
|
if (galleryPreviews == null) return;
|
||||||
|
|
||||||
|
@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
var dateStart = new Date();
|
var dateStart = new Date();
|
||||||
var wasEverActive = false;
|
var wasEverActive = false;
|
||||||
var parentProgressbar = progressbarContainer.parentNode;
|
var parentProgressbar = progressbarContainer.parentNode;
|
||||||
var parentGallery = gallery ? gallery.parentNode : null;
|
|
||||||
|
|
||||||
var divProgress = document.createElement('div');
|
var divProgress = document.createElement('div');
|
||||||
divProgress.className = 'progressDiv';
|
divProgress.className = 'progressDiv';
|
||||||
@ -80,32 +79,26 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
divProgress.appendChild(divInner);
|
divProgress.appendChild(divInner);
|
||||||
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
||||||
|
|
||||||
if (parentGallery) {
|
var livePreview = null;
|
||||||
var livePreview = document.createElement('div');
|
|
||||||
livePreview.className = 'livePreview';
|
|
||||||
parentGallery.insertBefore(livePreview, gallery);
|
|
||||||
}
|
|
||||||
|
|
||||||
var removeProgressBar = function() {
|
var removeProgressBar = function() {
|
||||||
|
if (!divProgress) return;
|
||||||
|
|
||||||
setTitle("");
|
setTitle("");
|
||||||
parentProgressbar.removeChild(divProgress);
|
parentProgressbar.removeChild(divProgress);
|
||||||
if (parentGallery) parentGallery.removeChild(livePreview);
|
if (gallery && livePreview) gallery.removeChild(livePreview);
|
||||||
atEnd();
|
atEnd();
|
||||||
|
|
||||||
|
divProgress = null;
|
||||||
};
|
};
|
||||||
|
|
||||||
var fun = function(id_task, id_live_preview) {
|
var funProgress = function(id_task) {
|
||||||
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
request("./internal/progress", {id_task: id_task, live_preview: false}, function(res) {
|
||||||
if (res.completed) {
|
if (res.completed) {
|
||||||
removeProgressBar();
|
removeProgressBar();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var rect = progressbarContainer.getBoundingClientRect();
|
|
||||||
|
|
||||||
if (rect.width) {
|
|
||||||
divProgress.style.width = rect.width + "px";
|
|
||||||
}
|
|
||||||
|
|
||||||
let progressText = "";
|
let progressText = "";
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
||||||
@ -119,7 +112,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
progressText += " ETA: " + formatTime(res.eta);
|
progressText += " ETA: " + formatTime(res.eta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
setTitle(progressText);
|
setTitle(progressText);
|
||||||
|
|
||||||
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
||||||
@ -142,16 +134,33 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (onProgress) {
|
||||||
|
onProgress(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
funProgress(id_task, res.id_live_preview);
|
||||||
|
}, opts.live_preview_refresh_period || 500);
|
||||||
|
}, function() {
|
||||||
|
removeProgressBar();
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
var funLivePreview = function(id_task, id_live_preview) {
|
||||||
|
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
||||||
|
if (!divProgress) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (res.live_preview && gallery) {
|
if (res.live_preview && gallery) {
|
||||||
rect = gallery.getBoundingClientRect();
|
|
||||||
if (rect.width) {
|
|
||||||
livePreview.style.width = rect.width + "px";
|
|
||||||
livePreview.style.height = rect.height + "px";
|
|
||||||
}
|
|
||||||
|
|
||||||
var img = new Image();
|
var img = new Image();
|
||||||
img.onload = function() {
|
img.onload = function() {
|
||||||
|
if (!livePreview) {
|
||||||
|
livePreview = document.createElement('div');
|
||||||
|
livePreview.className = 'livePreview';
|
||||||
|
gallery.insertBefore(livePreview, gallery.firstElementChild);
|
||||||
|
}
|
||||||
|
|
||||||
livePreview.appendChild(img);
|
livePreview.appendChild(img);
|
||||||
if (livePreview.childElementCount > 2) {
|
if (livePreview.childElementCount > 2) {
|
||||||
livePreview.removeChild(livePreview.firstElementChild);
|
livePreview.removeChild(livePreview.firstElementChild);
|
||||||
@ -160,18 +169,18 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
img.src = res.live_preview;
|
img.src = res.live_preview;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (onProgress) {
|
|
||||||
onProgress(res);
|
|
||||||
}
|
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
fun(id_task, res.id_live_preview);
|
funLivePreview(id_task, res.id_live_preview);
|
||||||
}, opts.live_preview_refresh_period || 500);
|
}, opts.live_preview_refresh_period || 500);
|
||||||
}, function() {
|
}, function() {
|
||||||
removeProgressBar();
|
removeProgressBar();
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
fun(id_task, 0);
|
funProgress(id_task, 0);
|
||||||
|
|
||||||
|
if (gallery) {
|
||||||
|
funLivePreview(id_task, 0);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
141
javascript/resizeHandle.js
Normal file
141
javascript/resizeHandle.js
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
(function() {
|
||||||
|
const GRADIO_MIN_WIDTH = 320;
|
||||||
|
const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
|
||||||
|
const PAD = 16;
|
||||||
|
const DEBOUNCE_TIME = 100;
|
||||||
|
|
||||||
|
const R = {
|
||||||
|
tracking: false,
|
||||||
|
parent: null,
|
||||||
|
parentWidth: null,
|
||||||
|
leftCol: null,
|
||||||
|
leftColStartWidth: null,
|
||||||
|
screenX: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
let resizeTimer;
|
||||||
|
let parents = [];
|
||||||
|
|
||||||
|
function setLeftColGridTemplate(el, width) {
|
||||||
|
el.style.gridTemplateColumns = `${width}px 16px 1fr`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function displayResizeHandle(parent) {
|
||||||
|
if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
|
||||||
|
parent.style.display = 'flex';
|
||||||
|
if (R.handle != null) {
|
||||||
|
R.handle.style.opacity = '0';
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
parent.style.display = 'grid';
|
||||||
|
if (R.handle != null) {
|
||||||
|
R.handle.style.opacity = '100';
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function afterResize(parent) {
|
||||||
|
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
|
||||||
|
const oldParentWidth = R.parentWidth;
|
||||||
|
const newParentWidth = parent.offsetWidth;
|
||||||
|
const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
|
||||||
|
|
||||||
|
const ratio = newParentWidth / oldParentWidth;
|
||||||
|
|
||||||
|
const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
|
||||||
|
setLeftColGridTemplate(parent, newWidthL);
|
||||||
|
|
||||||
|
R.parentWidth = newParentWidth;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setup(parent) {
|
||||||
|
const leftCol = parent.firstElementChild;
|
||||||
|
const rightCol = parent.lastElementChild;
|
||||||
|
|
||||||
|
parents.push(parent);
|
||||||
|
|
||||||
|
parent.style.display = 'grid';
|
||||||
|
parent.style.gap = '0';
|
||||||
|
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||||
|
|
||||||
|
const resizeHandle = document.createElement('div');
|
||||||
|
resizeHandle.classList.add('resize-handle');
|
||||||
|
parent.insertBefore(resizeHandle, rightCol);
|
||||||
|
|
||||||
|
resizeHandle.addEventListener('mousedown', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
document.body.classList.add('resizing');
|
||||||
|
|
||||||
|
R.tracking = true;
|
||||||
|
R.parent = parent;
|
||||||
|
R.parentWidth = parent.offsetWidth;
|
||||||
|
R.handle = resizeHandle;
|
||||||
|
R.leftCol = leftCol;
|
||||||
|
R.leftColStartWidth = leftCol.offsetWidth;
|
||||||
|
R.screenX = evt.screenX;
|
||||||
|
});
|
||||||
|
|
||||||
|
resizeHandle.addEventListener('dblclick', (evt) => {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||||
|
});
|
||||||
|
|
||||||
|
afterResize(parent);
|
||||||
|
}
|
||||||
|
|
||||||
|
window.addEventListener('mousemove', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
if (R.tracking) {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
const delta = R.screenX - evt.screenX;
|
||||||
|
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
|
||||||
|
setLeftColGridTemplate(R.parent, leftColWidth);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
window.addEventListener('mouseup', (evt) => {
|
||||||
|
if (evt.button !== 0) return;
|
||||||
|
|
||||||
|
if (R.tracking) {
|
||||||
|
evt.preventDefault();
|
||||||
|
evt.stopPropagation();
|
||||||
|
|
||||||
|
R.tracking = false;
|
||||||
|
|
||||||
|
document.body.classList.remove('resizing');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
window.addEventListener('resize', () => {
|
||||||
|
clearTimeout(resizeTimer);
|
||||||
|
|
||||||
|
resizeTimer = setTimeout(function() {
|
||||||
|
for (const parent of parents) {
|
||||||
|
afterResize(parent);
|
||||||
|
}
|
||||||
|
}, DEBOUNCE_TIME);
|
||||||
|
});
|
||||||
|
|
||||||
|
setupResizeHandle = setup;
|
||||||
|
})();
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
|
||||||
|
if (!elem.querySelector('.resize-handle')) {
|
||||||
|
setupResizeHandle(elem);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
@ -19,28 +19,11 @@ function all_gallery_buttons() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function selected_gallery_button() {
|
function selected_gallery_button() {
|
||||||
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
|
return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;
|
||||||
var visibleCurrentButton = null;
|
|
||||||
allCurrentButtons.forEach(function(elem) {
|
|
||||||
if (elem.parentElement.offsetParent) {
|
|
||||||
visibleCurrentButton = elem;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return visibleCurrentButton;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function selected_gallery_index() {
|
function selected_gallery_index() {
|
||||||
var buttons = all_gallery_buttons();
|
return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));
|
||||||
var button = selected_gallery_button();
|
|
||||||
|
|
||||||
var result = -1;
|
|
||||||
buttons.forEach(function(v, i) {
|
|
||||||
if (v == button) {
|
|
||||||
result = i;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function extract_image_from_gallery(gallery) {
|
function extract_image_from_gallery(gallery) {
|
||||||
@ -152,11 +135,11 @@ function submit() {
|
|||||||
showSubmitButtons('txt2img', false);
|
showSubmitButtons('txt2img', false);
|
||||||
|
|
||||||
var id = randomId();
|
var id = randomId();
|
||||||
localStorage.setItem("txt2img_task_id", id);
|
localSet("txt2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||||
showSubmitButtons('txt2img', true);
|
showSubmitButtons('txt2img', true);
|
||||||
localStorage.removeItem("txt2img_task_id");
|
localRemove("txt2img_task_id");
|
||||||
showRestoreProgressButton('txt2img', false);
|
showRestoreProgressButton('txt2img', false);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -171,11 +154,11 @@ function submit_img2img() {
|
|||||||
showSubmitButtons('img2img', false);
|
showSubmitButtons('img2img', false);
|
||||||
|
|
||||||
var id = randomId();
|
var id = randomId();
|
||||||
localStorage.setItem("img2img_task_id", id);
|
localSet("img2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||||
showSubmitButtons('img2img', true);
|
showSubmitButtons('img2img', true);
|
||||||
localStorage.removeItem("img2img_task_id");
|
localRemove("img2img_task_id");
|
||||||
showRestoreProgressButton('img2img', false);
|
showRestoreProgressButton('img2img', false);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -189,9 +172,7 @@ function submit_img2img() {
|
|||||||
|
|
||||||
function restoreProgressTxt2img() {
|
function restoreProgressTxt2img() {
|
||||||
showRestoreProgressButton("txt2img", false);
|
showRestoreProgressButton("txt2img", false);
|
||||||
var id = localStorage.getItem("txt2img_task_id");
|
var id = localGet("txt2img_task_id");
|
||||||
|
|
||||||
id = localStorage.getItem("txt2img_task_id");
|
|
||||||
|
|
||||||
if (id) {
|
if (id) {
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||||
@ -205,7 +186,7 @@ function restoreProgressTxt2img() {
|
|||||||
function restoreProgressImg2img() {
|
function restoreProgressImg2img() {
|
||||||
showRestoreProgressButton("img2img", false);
|
showRestoreProgressButton("img2img", false);
|
||||||
|
|
||||||
var id = localStorage.getItem("img2img_task_id");
|
var id = localGet("img2img_task_id");
|
||||||
|
|
||||||
if (id) {
|
if (id) {
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||||
@ -218,8 +199,8 @@ function restoreProgressImg2img() {
|
|||||||
|
|
||||||
|
|
||||||
onUiLoaded(function() {
|
onUiLoaded(function() {
|
||||||
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
|
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||||
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
|
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
15
launch.py
15
launch.py
@ -1,6 +1,5 @@
|
|||||||
from modules import launch_utils
|
from modules import launch_utils
|
||||||
|
|
||||||
|
|
||||||
args = launch_utils.args
|
args = launch_utils.args
|
||||||
python = launch_utils.python
|
python = launch_utils.python
|
||||||
git = launch_utils.git
|
git = launch_utils.git
|
||||||
@ -26,8 +25,18 @@ start = launch_utils.start
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if not args.skip_prepare_environment:
|
if args.dump_sysinfo:
|
||||||
prepare_environment()
|
filename = launch_utils.dump_sysinfo()
|
||||||
|
|
||||||
|
print(f"Sysinfo saved as {filename}. Exiting...")
|
||||||
|
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
launch_utils.startup_timer.record("initial startup")
|
||||||
|
|
||||||
|
with launch_utils.startup_timer.subcategory("prepare environment"):
|
||||||
|
if not args.skip_prepare_environment:
|
||||||
|
prepare_environment()
|
||||||
|
|
||||||
if args.test_server:
|
if args.test_server:
|
||||||
configure_for_tests()
|
configure_for_tests()
|
||||||
|
@ -4,6 +4,8 @@ import os
|
|||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
import ipaddress
|
||||||
|
import requests
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -15,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin,Image
|
||||||
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
|
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||||
from modules.sd_vae import vae_dict
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@ -56,7 +57,41 @@ def setUpscalers(req: dict):
|
|||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
|
def verify_url(url):
|
||||||
|
"""Returns True if the url refers to a global resource."""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
try:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
domain_name = parsed_url.netloc
|
||||||
|
host = socket.gethostbyname_ex(domain_name)
|
||||||
|
for ip in host[2]:
|
||||||
|
ip_addr = ipaddress.ip_address(ip)
|
||||||
|
if not ip_addr.is_global:
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
|
if encoding.startswith("http://") or encoding.startswith("https://"):
|
||||||
|
if not opts.api_enable_requests:
|
||||||
|
raise HTTPException(status_code=500, detail="Requests not allowed")
|
||||||
|
|
||||||
|
if opts.api_forbid_local_requests and not verify_url(encoding):
|
||||||
|
raise HTTPException(status_code=500, detail="Request to local resource not allowed")
|
||||||
|
|
||||||
|
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||||
|
response = requests.get(encoding, timeout=30, headers=headers)
|
||||||
|
try:
|
||||||
|
image = Image.open(BytesIO(response.content))
|
||||||
|
return image
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||||
|
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
@ -197,6 +232,7 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
|
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||||
@ -329,6 +365,7 @@ class Api:
|
|||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
|
p.is_api = True
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_txt2img_grids
|
p.outpath_grids = opts.outdir_txt2img_grids
|
||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
@ -343,6 +380,7 @@ class Api:
|
|||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
shared.total_tqdm.clear()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
@ -388,6 +426,7 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||||
|
p.is_api = True
|
||||||
p.scripts = script_runner
|
p.scripts = script_runner
|
||||||
p.outpath_grids = opts.outdir_img2img_grids
|
p.outpath_grids = opts.outdir_img2img_grids
|
||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
@ -402,6 +441,7 @@ class Api:
|
|||||||
processed = process_images(p)
|
processed = process_images(p)
|
||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
shared.total_tqdm.clear()
|
||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
@ -530,7 +570,7 @@ class Api:
|
|||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
shared.opts.set(k, v)
|
shared.opts.set(k, v, is_api=True)
|
||||||
|
|
||||||
shared.opts.save(shared.config_filename)
|
shared.opts.save(shared.config_filename)
|
||||||
return
|
return
|
||||||
@ -562,10 +602,12 @@ class Api:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
import modules.sd_models as sd_models
|
||||||
|
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
|
||||||
|
|
||||||
def get_sd_vaes(self):
|
def get_sd_vaes(self):
|
||||||
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
|
import modules.sd_vae as sd_vae
|
||||||
|
return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
@ -608,6 +650,10 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
shared.refresh_checkpoints()
|
shared.refresh_checkpoints()
|
||||||
|
|
||||||
|
def refresh_vae(self):
|
||||||
|
with self.queue_lock:
|
||||||
|
shared_items.refresh_vae_list()
|
||||||
|
|
||||||
def create_embedding(self, args: dict):
|
def create_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin(job="create_embedding")
|
shared.state.begin(job="create_embedding")
|
||||||
|
@ -50,10 +50,12 @@ class PydanticModelGenerator:
|
|||||||
additional_fields = None,
|
additional_fields = None,
|
||||||
):
|
):
|
||||||
def field_type_generator(k, v):
|
def field_type_generator(k, v):
|
||||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
|
||||||
# print(k, v.annotation, v.default)
|
|
||||||
field_type = v.annotation
|
field_type = v.annotation
|
||||||
|
|
||||||
|
if field_type == 'Image':
|
||||||
|
# images are sent as base64 strings via API
|
||||||
|
field_type = 'str'
|
||||||
|
|
||||||
return Optional[field_type]
|
return Optional[field_type]
|
||||||
|
|
||||||
def merge_class_params(class_):
|
def merge_class_params(class_):
|
||||||
@ -63,7 +65,6 @@ class PydanticModelGenerator:
|
|||||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._class_data = merge_class_params(class_instance)
|
self._class_data = merge_class_params(class_instance)
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ class PydanticModelGenerator:
|
|||||||
field=underscore(k),
|
field=underscore(k),
|
||||||
field_alias=k,
|
field_alias=k,
|
||||||
field_type=field_type_generator(k, v),
|
field_type=field_type_generator(k, v),
|
||||||
field_value=v.default
|
field_value=None if isinstance(v.default, property) else v.default
|
||||||
)
|
)
|
||||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
]
|
]
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules.paths import data_path, script_path
|
from modules.paths import data_path, script_path
|
||||||
|
|
||||||
cache_filename = os.path.join(data_path, "cache.json")
|
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
||||||
cache_data = None
|
cache_data = None
|
||||||
cache_lock = threading.Lock()
|
cache_lock = threading.Lock()
|
||||||
|
|
||||||
@ -29,9 +30,12 @@ def dump_cache():
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
with cache_lock:
|
with cache_lock:
|
||||||
with open(cache_filename, "w", encoding="utf8") as file:
|
cache_filename_tmp = cache_filename + "-"
|
||||||
|
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4)
|
||||||
|
|
||||||
|
os.replace(cache_filename_tmp, cache_filename)
|
||||||
|
|
||||||
dump_cache_after = None
|
dump_cache_after = None
|
||||||
dump_cache_thread = None
|
dump_cache_thread = None
|
||||||
|
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
import html
|
import html
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from modules import shared, progress, errors, devices
|
from modules import shared, progress, errors, devices, fifo_lock
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = fifo_lock.FIFOLock()
|
||||||
|
|
||||||
|
|
||||||
def wrap_queued_call(func):
|
def wrap_queued_call(func):
|
||||||
|
@ -13,8 +13,11 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
|
|||||||
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
|
||||||
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
|
||||||
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
|
||||||
|
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
|
||||||
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
|
||||||
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
|
||||||
|
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
||||||
|
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||||
@ -33,9 +36,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
|
|||||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||||
|
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
|
||||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
|
||||||
@ -66,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
|
|||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||||
|
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||||
@ -78,7 +83,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication l
|
|||||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it")
|
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
@ -110,3 +115,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
|
|||||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||||
|
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||||
|
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||||
|
@ -8,14 +8,12 @@ import time
|
|||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from collections import OrderedDict
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions, errors
|
from modules import shared, extensions, errors
|
||||||
from modules.paths_internal import script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
all_config_states = {}
|
||||||
all_config_states = OrderedDict()
|
|
||||||
|
|
||||||
|
|
||||||
def list_config_states():
|
def list_config_states():
|
||||||
@ -28,10 +26,14 @@ def list_config_states():
|
|||||||
for filename in os.listdir(config_states_dir):
|
for filename in os.listdir(config_states_dir):
|
||||||
if filename.endswith(".json"):
|
if filename.endswith(".json"):
|
||||||
path = os.path.join(config_states_dir, filename)
|
path = os.path.join(config_states_dir, filename)
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
try:
|
||||||
j = json.load(f)
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
j["filepath"] = path
|
j = json.load(f)
|
||||||
config_states.append(j)
|
assert "created_at" in j, '"created_at" does not exist'
|
||||||
|
j["filepath"] = path
|
||||||
|
config_states.append(j)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'[ERROR]: Config states {path}, {e}')
|
||||||
|
|
||||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import contextlib
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from modules import errors
|
from modules import errors, shared
|
||||||
|
|
||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
from modules import mac_specific
|
from modules import mac_specific
|
||||||
@ -17,8 +17,6 @@ def has_mps() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def get_cuda_device_string():
|
def get_cuda_device_string():
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if shared.cmd_opts.device_id is not None:
|
if shared.cmd_opts.device_id is not None:
|
||||||
return f"cuda:{shared.cmd_opts.device_id}"
|
return f"cuda:{shared.cmd_opts.device_id}"
|
||||||
|
|
||||||
@ -40,8 +38,6 @@ def get_optimal_device():
|
|||||||
|
|
||||||
|
|
||||||
def get_device_for(task):
|
def get_device_for(task):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if task in shared.cmd_opts.use_cpu:
|
if task in shared.cmd_opts.use_cpu:
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
@ -71,14 +67,17 @@ def enable_tf32():
|
|||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu: torch.device = torch.device("cpu")
|
||||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
device: torch.device = None
|
||||||
dtype = torch.float16
|
device_interrogate: torch.device = None
|
||||||
dtype_vae = torch.float16
|
device_gfpgan: torch.device = None
|
||||||
dtype_unet = torch.float16
|
device_esrgan: torch.device = None
|
||||||
|
device_codeformer: torch.device = None
|
||||||
|
dtype: torch.dtype = torch.float16
|
||||||
|
dtype_vae: torch.dtype = torch.float16
|
||||||
|
dtype_unet: torch.dtype = torch.float16
|
||||||
unet_needs_upcast = False
|
unet_needs_upcast = False
|
||||||
|
|
||||||
|
|
||||||
@ -90,26 +89,10 @@ def cond_cast_float(input):
|
|||||||
return input.float() if unet_needs_upcast else input
|
return input.float() if unet_needs_upcast else input
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
nv_rng = None
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
|
||||||
return torch.randn(shape, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU" or device.type == 'mps':
|
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
|
||||||
return torch.randn(shape, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def autocast(disable=False):
|
def autocast(disable=False):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if disable:
|
if disable:
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
@ -128,8 +111,6 @@ class NansException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
def test_for_nans(x, where):
|
def test_for_nans(x, where):
|
||||||
from modules import shared
|
|
||||||
|
|
||||||
if shared.cmd_opts.disable_nan_check:
|
if shared.cmd_opts.disable_nan_check:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -169,3 +150,4 @@ def first_time_calculation():
|
|||||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||||
conv2d(x)
|
conv2d(x)
|
||||||
|
|
||||||
|
@ -84,3 +84,53 @@ def run(code, task):
|
|||||||
code()
|
code()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
display(task, e)
|
display(task, e)
|
||||||
|
|
||||||
|
|
||||||
|
def check_versions():
|
||||||
|
from packaging import version
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import gradio
|
||||||
|
|
||||||
|
expected_torch_version = "2.0.0"
|
||||||
|
expected_xformers_version = "0.0.20"
|
||||||
|
expected_gradio_version = "3.41.2"
|
||||||
|
|
||||||
|
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||||
|
print_error_explanation(f"""
|
||||||
|
You are running torch {torch.__version__}.
|
||||||
|
The program is tested to work with torch {expected_torch_version}.
|
||||||
|
To reinstall the desired version, run with commandline flag --reinstall-torch.
|
||||||
|
Beware that this will cause a lot of large files to be downloaded, as well as
|
||||||
|
there are reports of issues with training tab on the latest version.
|
||||||
|
|
||||||
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
|
""".strip())
|
||||||
|
|
||||||
|
if shared.xformers_available:
|
||||||
|
import xformers
|
||||||
|
|
||||||
|
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
|
||||||
|
print_error_explanation(f"""
|
||||||
|
You are running xformers {xformers.__version__}.
|
||||||
|
The program is tested to work with xformers {expected_xformers_version}.
|
||||||
|
To reinstall the desired version, run with commandline flag --reinstall-xformers.
|
||||||
|
|
||||||
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
|
""".strip())
|
||||||
|
|
||||||
|
if gradio.__version__ != expected_gradio_version:
|
||||||
|
print_error_explanation(f"""
|
||||||
|
You are running gradio {gradio.__version__}.
|
||||||
|
The program is designed to work with gradio {expected_gradio_version}.
|
||||||
|
Using a different version of gradio is extremely likely to break the program.
|
||||||
|
|
||||||
|
Reasons why you have the mismatched gradio version can be:
|
||||||
|
- you use --skip-install flag.
|
||||||
|
- you use webui.py to start the program instead of launch.py.
|
||||||
|
- an extension installs the incompatible gradio version.
|
||||||
|
|
||||||
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
|
""".strip())
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from modules import shared, errors, cache
|
from modules import shared, errors, cache, scripts
|
||||||
from modules.gitpython_hack import Repo
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
|
|||||||
|
|
||||||
|
|
||||||
def active():
|
def active():
|
||||||
if shared.opts.disable_all_extensions == "all":
|
if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
|
||||||
return []
|
return []
|
||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
|
||||||
return [x for x in extensions if x.enabled and x.is_builtin]
|
return [x for x in extensions if x.enabled and x.is_builtin]
|
||||||
else:
|
else:
|
||||||
return [x for x in extensions if x.enabled]
|
return [x for x in extensions if x.enabled]
|
||||||
@ -90,8 +90,6 @@ class Extension:
|
|||||||
self.have_info_from_repo = True
|
self.have_info_from_repo = True
|
||||||
|
|
||||||
def list_files(self, subdir, extension):
|
def list_files(self, subdir, extension):
|
||||||
from modules import scripts
|
|
||||||
|
|
||||||
dirpath = os.path.join(self.path, subdir)
|
dirpath = os.path.join(self.path, subdir)
|
||||||
if not os.path.isdir(dirpath):
|
if not os.path.isdir(dirpath):
|
||||||
return []
|
return []
|
||||||
@ -141,8 +139,12 @@ def list_extensions():
|
|||||||
if not os.path.isdir(extensions_dir):
|
if not os.path.isdir(extensions_dir):
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.opts.disable_all_extensions == "all":
|
if shared.cmd_opts.disable_all_extensions:
|
||||||
|
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||||
|
elif shared.opts.disable_all_extensions == "all":
|
||||||
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
|
||||||
|
elif shared.cmd_opts.disable_extra_extensions:
|
||||||
|
print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
|
||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.opts.disable_all_extensions == "extra":
|
||||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||||
|
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors
|
||||||
@ -84,27 +87,55 @@ class ExtraNetwork:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_extra_networks(extra_network_data):
|
||||||
|
"""returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.
|
||||||
|
|
||||||
|
Example input:
|
||||||
|
{
|
||||||
|
'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],
|
||||||
|
'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||||
|
'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||||
|
}
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
|
||||||
|
{
|
||||||
|
<extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],
|
||||||
|
<modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
for extra_network_name, extra_network_args in list(extra_network_data.items()):
|
||||||
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
|
alias = extra_network_aliases.get(extra_network_name, None)
|
||||||
|
|
||||||
|
if alias is not None and extra_network is None:
|
||||||
|
extra_network = alias
|
||||||
|
|
||||||
|
if extra_network is None:
|
||||||
|
logging.info(f"Skipping unknown extra network: {extra_network_name}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
res.setdefault(extra_network, []).extend(extra_network_args)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def activate(p, extra_network_data):
|
def activate(p, extra_network_data):
|
||||||
"""call activate for extra networks in extra_network_data in specified order, then call
|
"""call activate for extra networks in extra_network_data in specified order, then call
|
||||||
activate for all remaining registered networks with an empty argument list"""
|
activate for all remaining registered networks with an empty argument list"""
|
||||||
|
|
||||||
activated = []
|
activated = []
|
||||||
|
|
||||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
for extra_network, extra_network_args in lookup_extra_networks(extra_network_data).items():
|
||||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
|
||||||
|
|
||||||
if extra_network is None:
|
|
||||||
extra_network = extra_network_aliases.get(extra_network_name, None)
|
|
||||||
|
|
||||||
if extra_network is None:
|
|
||||||
print(f"Skipping unknown extra network: {extra_network_name}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extra_network.activate(p, extra_network_args)
|
extra_network.activate(p, extra_network_args)
|
||||||
activated.append(extra_network)
|
activated.append(extra_network)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
|
errors.display(e, f"activating extra network {extra_network.name} with arguments {extra_network_args}")
|
||||||
|
|
||||||
for extra_network_name, extra_network in extra_network_registry.items():
|
for extra_network_name, extra_network in extra_network_registry.items():
|
||||||
if extra_network in activated:
|
if extra_network in activated:
|
||||||
@ -123,19 +154,16 @@ def deactivate(p, extra_network_data):
|
|||||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
deactivate for all remaining registered networks"""
|
deactivate for all remaining registered networks"""
|
||||||
|
|
||||||
for extra_network_name in extra_network_data:
|
data = lookup_extra_networks(extra_network_data)
|
||||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
|
||||||
if extra_network is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
for extra_network in data:
|
||||||
try:
|
try:
|
||||||
extra_network.deactivate(p)
|
extra_network.deactivate(p)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, f"deactivating extra network {extra_network_name}")
|
errors.display(e, f"deactivating extra network {extra_network.name}")
|
||||||
|
|
||||||
for extra_network_name, extra_network in extra_network_registry.items():
|
for extra_network_name, extra_network in extra_network_registry.items():
|
||||||
args = extra_network_data.get(extra_network_name, None)
|
if extra_network in data:
|
||||||
if args is not None:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -177,3 +205,20 @@ def parse_prompts(prompts):
|
|||||||
|
|
||||||
return res, extra_data
|
return res, extra_data
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_metadata(filename):
|
||||||
|
if filename is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
basename, ext = os.path.splitext(filename)
|
||||||
|
metadata_filename = basename + '.json'
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
try:
|
||||||
|
if os.path.isfile(metadata_filename):
|
||||||
|
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||||
|
metadata = json.load(file)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading extra network user metadata from {metadata_filename}")
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
@ -7,7 +7,7 @@ import json
|
|||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from modules import shared, images, sd_models, sd_vae, sd_models_config
|
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
|
||||||
from modules.ui_common import plaintext_to_html
|
from modules.ui_common import plaintext_to_html
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -72,7 +72,20 @@ def to_half(tensor, enable):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
|
||||||
|
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
|
||||||
|
if checkpoint_info is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata.update(checkpoint_info.metadata)
|
||||||
|
|
||||||
|
return json.dumps(metadata, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
|
||||||
shared.state.begin(job="model-merge")
|
shared.state.begin(job="model-merge")
|
||||||
|
|
||||||
def fail(message):
|
def fail(message):
|
||||||
@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
shared.state.textinfo = "Saving"
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
metadata = None
|
metadata = {}
|
||||||
|
|
||||||
|
if save_metadata and copy_metadata_fields:
|
||||||
|
if primary_model_info:
|
||||||
|
metadata.update(primary_model_info.metadata)
|
||||||
|
if secondary_model_info:
|
||||||
|
metadata.update(secondary_model_info.metadata)
|
||||||
|
if tertiary_model_info:
|
||||||
|
metadata.update(tertiary_model_info.metadata)
|
||||||
|
|
||||||
if save_metadata:
|
if save_metadata:
|
||||||
metadata = {"format": "pt"}
|
try:
|
||||||
|
metadata.update(json.loads(metadata_json))
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "readin metadata from json")
|
||||||
|
|
||||||
|
metadata["format"] = "pt"
|
||||||
|
|
||||||
|
if save_metadata and add_merge_recipe:
|
||||||
merge_recipe = {
|
merge_recipe = {
|
||||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||||
"primary_model_hash": primary_model_info.sha256,
|
"primary_model_hash": primary_model_info.sha256,
|
||||||
@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
"is_inpainting": result_is_inpainting_model,
|
"is_inpainting": result_is_inpainting_model,
|
||||||
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||||
}
|
}
|
||||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
|
||||||
|
|
||||||
sd_merge_models = {}
|
sd_merge_models = {}
|
||||||
|
|
||||||
@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
if tertiary_model_info:
|
if tertiary_model_info:
|
||||||
add_model_metadata(tertiary_model_info)
|
add_model_metadata(tertiary_model_info)
|
||||||
|
|
||||||
|
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||||
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
|
||||||
else:
|
else:
|
||||||
torch.save(theta_0, output_modelname)
|
torch.save(theta_0, output_modelname)
|
||||||
|
|
||||||
|
37
modules/fifo_lock.py
Normal file
37
modules/fifo_lock.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import threading
|
||||||
|
import collections
|
||||||
|
|
||||||
|
|
||||||
|
# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
|
||||||
|
class FIFOLock(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._inner_lock = threading.Lock()
|
||||||
|
self._pending_threads = collections.deque()
|
||||||
|
|
||||||
|
def acquire(self, blocking=True):
|
||||||
|
with self._inner_lock:
|
||||||
|
lock_acquired = self._lock.acquire(False)
|
||||||
|
if lock_acquired:
|
||||||
|
return True
|
||||||
|
elif not blocking:
|
||||||
|
return False
|
||||||
|
|
||||||
|
release_event = threading.Event()
|
||||||
|
self._pending_threads.append(release_event)
|
||||||
|
|
||||||
|
release_event.wait()
|
||||||
|
return self._lock.acquire()
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
with self._inner_lock:
|
||||||
|
if self._pending_threads:
|
||||||
|
release_event = self._pending_threads.popleft()
|
||||||
|
release_event.set()
|
||||||
|
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
__enter__ = acquire
|
||||||
|
|
||||||
|
def __exit__(self, t, v, tb):
|
||||||
|
self.release()
|
@ -6,10 +6,10 @@ import re
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||||
re_param = re.compile(re_param_code)
|
re_param = re.compile(re_param_code)
|
||||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||||
@ -32,6 +32,7 @@ class ParamBinding:
|
|||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
paste_fields.clear()
|
paste_fields.clear()
|
||||||
|
registered_param_bindings.clear()
|
||||||
|
|
||||||
|
|
||||||
def quote(text):
|
def quote(text):
|
||||||
@ -198,7 +199,6 @@ def restore_old_hires_fix_params(res):
|
|||||||
height = int(res.get("Size-2", 512))
|
height = int(res.get("Size-2", 512))
|
||||||
|
|
||||||
if firstpass_width == 0 or firstpass_height == 0:
|
if firstpass_width == 0 or firstpass_height == 0:
|
||||||
from modules import processing
|
|
||||||
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
|
||||||
|
|
||||||
res['Size-1'] = firstpass_width
|
res['Size-1'] = firstpass_width
|
||||||
@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Hires sampler" not in res:
|
if "Hires sampler" not in res:
|
||||||
res["Hires sampler"] = "Use same sampler"
|
res["Hires sampler"] = "Use same sampler"
|
||||||
|
|
||||||
|
if "Hires checkpoint" not in res:
|
||||||
|
res["Hires checkpoint"] = "Use same checkpoint"
|
||||||
|
|
||||||
if "Hires prompt" not in res:
|
if "Hires prompt" not in res:
|
||||||
res["Hires prompt"] = ""
|
res["Hires prompt"] = ""
|
||||||
|
|
||||||
@ -304,32 +307,28 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Schedule rho" not in res:
|
if "Schedule rho" not in res:
|
||||||
res["Schedule rho"] = 0
|
res["Schedule rho"] = 0
|
||||||
|
|
||||||
|
if "VAE Encoder" not in res:
|
||||||
|
res["VAE Encoder"] = "Full"
|
||||||
|
|
||||||
|
if "VAE Decoder" not in res:
|
||||||
|
res["VAE Decoder"] = "Full"
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
infotext_to_setting_name_mapping = [
|
infotext_to_setting_name_mapping = [
|
||||||
('Clip skip', 'CLIP_stop_at_last_layers', ),
|
|
||||||
|
]
|
||||||
|
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
|
||||||
|
Example content:
|
||||||
|
|
||||||
|
infotext_to_setting_name_mapping = [
|
||||||
('Conditional mask weight', 'inpainting_mask_weight'),
|
('Conditional mask weight', 'inpainting_mask_weight'),
|
||||||
('Model hash', 'sd_model_checkpoint'),
|
('Model hash', 'sd_model_checkpoint'),
|
||||||
('ENSD', 'eta_noise_seed_delta'),
|
('ENSD', 'eta_noise_seed_delta'),
|
||||||
('Schedule type', 'k_sched_type'),
|
('Schedule type', 'k_sched_type'),
|
||||||
('Schedule max sigma', 'sigma_max'),
|
|
||||||
('Schedule min sigma', 'sigma_min'),
|
|
||||||
('Schedule rho', 'rho'),
|
|
||||||
('Noise multiplier', 'initial_noise_multiplier'),
|
|
||||||
('Eta', 'eta_ancestral'),
|
|
||||||
('Eta DDIM', 'eta_ddim'),
|
|
||||||
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
|
|
||||||
('UniPC variant', 'uni_pc_variant'),
|
|
||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
|
||||||
('UniPC order', 'uni_pc_order'),
|
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
|
||||||
('Token merging ratio', 'token_merging_ratio'),
|
|
||||||
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
|
||||||
('RNG', 'randn_source'),
|
|
||||||
('NGMS', 's_min_uncond'),
|
|
||||||
('Pad conds', 'pad_cond_uncond'),
|
|
||||||
]
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_override_settings_dict(text_pairs):
|
def create_override_settings_dict(text_pairs):
|
||||||
@ -350,7 +349,8 @@ def create_override_settings_dict(text_pairs):
|
|||||||
|
|
||||||
params[k] = v.strip()
|
params[k] = v.strip()
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||||
|
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||||
value = params.get(param_name, None)
|
value = params.get(param_name, None)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -399,10 +399,16 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
if override_settings_component is not None:
|
if override_settings_component is not None:
|
||||||
|
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||||
|
|
||||||
def paste_settings(params):
|
def paste_settings(params):
|
||||||
vals = {}
|
vals = {}
|
||||||
|
|
||||||
for param_name, setting_name in infotext_to_setting_name_mapping:
|
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||||
|
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||||
|
if param_name in already_handled_fields:
|
||||||
|
continue
|
||||||
|
|
||||||
v = params.get(param_name, None)
|
v = params.get(param_name, None)
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
|
73
modules/gradio_extensons.py
Normal file
73
modules/gradio_extensons.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, ui_tempdir, patches
|
||||||
|
|
||||||
|
|
||||||
|
def add_classes_to_gradio_component(comp):
|
||||||
|
"""
|
||||||
|
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||||
|
"""
|
||||||
|
|
||||||
|
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||||
|
|
||||||
|
if getattr(comp, 'multiselect', False):
|
||||||
|
comp.elem_classes.append('multiselect')
|
||||||
|
|
||||||
|
|
||||||
|
def IOComponent_init(self, *args, **kwargs):
|
||||||
|
self.webui_tooltip = kwargs.pop('tooltip', None)
|
||||||
|
|
||||||
|
if scripts.scripts_current is not None:
|
||||||
|
scripts.scripts_current.before_component(self, **kwargs)
|
||||||
|
|
||||||
|
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||||
|
|
||||||
|
res = original_IOComponent_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
add_classes_to_gradio_component(self)
|
||||||
|
|
||||||
|
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||||
|
|
||||||
|
if scripts.scripts_current is not None:
|
||||||
|
scripts.scripts_current.after_component(self, **kwargs)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def Block_get_config(self):
|
||||||
|
config = original_Block_get_config(self)
|
||||||
|
|
||||||
|
webui_tooltip = getattr(self, 'webui_tooltip', None)
|
||||||
|
if webui_tooltip:
|
||||||
|
config["webui_tooltip"] = webui_tooltip
|
||||||
|
|
||||||
|
config.pop('example_inputs', None)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def BlockContext_init(self, *args, **kwargs):
|
||||||
|
res = original_BlockContext_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
add_classes_to_gradio_component(self)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def Blocks_get_config_file(self, *args, **kwargs):
|
||||||
|
config = original_Blocks_get_config_file(self, *args, **kwargs)
|
||||||
|
|
||||||
|
for comp_config in config["components"]:
|
||||||
|
if "example_inputs" in comp_config:
|
||||||
|
comp_config["example_inputs"] = {"serialized": []}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
|
||||||
|
original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
|
||||||
|
original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
|
||||||
|
original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
|
||||||
|
|
||||||
|
|
||||||
|
ui_tempdir.install_ui_tempdir_override()
|
@ -10,7 +10,7 @@ import torch
|
|||||||
import tqdm
|
import tqdm
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import default
|
from ldm.util import default
|
||||||
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
|
||||||
from modules.textual_inversion import textual_inversion, logging
|
from modules.textual_inversion import textual_inversion, logging
|
||||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
from modules import images, processing
|
||||||
from modules import images
|
|
||||||
|
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
create_image_every = create_image_every or 0
|
create_image_every = create_image_every or 0
|
||||||
|
@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
|
|||||||
from modules.paths_internal import roboto_ttf_file
|
from modules.paths_internal import roboto_ttf_file
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
import modules.sd_vae as sd_vae
|
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
@ -318,7 +316,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||||
invalid_filename_prefix = ' '
|
invalid_filename_prefix = ' '
|
||||||
invalid_filename_postfix = ' .'
|
invalid_filename_postfix = ' .'
|
||||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||||
@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
|
|||||||
|
|
||||||
|
|
||||||
class FilenameGenerator:
|
class FilenameGenerator:
|
||||||
def get_vae_filename(self): #get the name of the VAE file.
|
|
||||||
if sd_vae.loaded_vae_file is None:
|
|
||||||
return "NoneType"
|
|
||||||
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
|
||||||
split_file_name = file_name.split('.')
|
|
||||||
if len(split_file_name) > 1 and split_file_name[0] == '':
|
|
||||||
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
|
||||||
else:
|
|
||||||
return split_file_name[0]
|
|
||||||
|
|
||||||
replacements = {
|
replacements = {
|
||||||
'seed': lambda self: self.seed if self.seed is not None else '',
|
'seed': lambda self: self.seed if self.seed is not None else '',
|
||||||
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
|
||||||
@ -367,7 +355,9 @@ class FilenameGenerator:
|
|||||||
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
|
||||||
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
|
||||||
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
|
||||||
'prompt_hash': lambda self: hashlib.sha256(self.prompt.encode()).hexdigest()[0:8],
|
'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),
|
||||||
|
'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),
|
||||||
|
'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string
|
||||||
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
'prompt': lambda self: sanitize_filename_part(self.prompt),
|
||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
@ -380,7 +370,8 @@ class FilenameGenerator:
|
|||||||
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
'user': lambda self: self.p.user,
|
'user': lambda self: self.p.user,
|
||||||
'vae_filename': lambda self: self.get_vae_filename(),
|
'vae_filename': lambda self: self.get_vae_filename(),
|
||||||
'none': lambda self: '', # Overrides the default so you can get just the sequence number
|
'none': lambda self: '', # Overrides the default, so you can get just the sequence number
|
||||||
|
'image_hash': lambda self, *args: self.image_hash(*args) # accepts formats: [image_hash<length>] default full hash
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@ -391,6 +382,22 @@ class FilenameGenerator:
|
|||||||
self.image = image
|
self.image = image
|
||||||
self.zip = zip
|
self.zip = zip
|
||||||
|
|
||||||
|
def get_vae_filename(self):
|
||||||
|
"""Get the name of the VAE file."""
|
||||||
|
|
||||||
|
import modules.sd_vae as sd_vae
|
||||||
|
|
||||||
|
if sd_vae.loaded_vae_file is None:
|
||||||
|
return "NoneType"
|
||||||
|
|
||||||
|
file_name = os.path.basename(sd_vae.loaded_vae_file)
|
||||||
|
split_file_name = file_name.split('.')
|
||||||
|
if len(split_file_name) > 1 and split_file_name[0] == '':
|
||||||
|
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
|
||||||
|
else:
|
||||||
|
return split_file_name[0]
|
||||||
|
|
||||||
|
|
||||||
def hasprompt(self, *args):
|
def hasprompt(self, *args):
|
||||||
lower = self.prompt.lower()
|
lower = self.prompt.lower()
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
@ -444,6 +451,14 @@ class FilenameGenerator:
|
|||||||
|
|
||||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||||
|
|
||||||
|
def image_hash(self, *args):
|
||||||
|
length = int(args[0]) if (args and args[0] != "") else None
|
||||||
|
return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]
|
||||||
|
|
||||||
|
def string_hash(self, text, *args):
|
||||||
|
length = int(args[0]) if (args and args[0] != "") else 8
|
||||||
|
return hashlib.sha256(text.encode()).hexdigest()[0:length]
|
||||||
|
|
||||||
def apply(self, x):
|
def apply(self, x):
|
||||||
res = ''
|
res = ''
|
||||||
|
|
||||||
@ -585,6 +600,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
"""
|
"""
|
||||||
namegen = FilenameGenerator(p, seed, prompt, image)
|
namegen = FilenameGenerator(p, seed, prompt, image)
|
||||||
|
|
||||||
|
# WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit
|
||||||
|
if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp":
|
||||||
|
print('Image dimensions too large; saving as PNG')
|
||||||
|
extension = ".png"
|
||||||
|
|
||||||
if save_to_dirs is None:
|
if save_to_dirs is None:
|
||||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||||
|
|
||||||
|
@ -3,14 +3,13 @@ from contextlib import closing
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import sd_samplers, images as imgutil
|
from modules import images as imgutil
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
from modules.images import save_image
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
@ -18,9 +17,10 @@ import modules.scripts
|
|||||||
|
|
||||||
|
|
||||||
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
|
||||||
|
output_dir = output_dir.strip()
|
||||||
processing.fix_seed(p)
|
processing.fix_seed(p)
|
||||||
|
|
||||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
|
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||||
|
|
||||||
is_inpaint_batch = False
|
is_inpaint_batch = False
|
||||||
if inpaint_mask_dir:
|
if inpaint_mask_dir:
|
||||||
@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
|
|
||||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||||
|
|
||||||
save_normally = output_dir == ''
|
|
||||||
|
|
||||||
p.do_not_save_grid = True
|
|
||||||
p.do_not_save_samples = not save_normally
|
|
||||||
|
|
||||||
state.job_count = len(images) * p.n_iter
|
state.job_count = len(images) * p.n_iter
|
||||||
|
|
||||||
# extract "default" params to use in case getting png info fails
|
# extract "default" params to use in case getting png info fails
|
||||||
@ -111,40 +106,30 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
|
|
||||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
if proc is None:
|
if proc is None:
|
||||||
proc = process_images(p)
|
if output_dir:
|
||||||
|
p.outpath_samples = output_dir
|
||||||
for n, processed_image in enumerate(proc.images):
|
p.override_settings['save_to_dirs'] = False
|
||||||
filename = image_path.stem
|
if p.n_iter > 1 or p.batch_size > 1:
|
||||||
infotext = proc.infotext(p, n)
|
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||||
relpath = os.path.dirname(os.path.relpath(image, input_dir))
|
else:
|
||||||
|
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||||
if n > 0:
|
process_images(p)
|
||||||
filename += f"-{n}"
|
|
||||||
|
|
||||||
if not save_normally:
|
|
||||||
os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
|
|
||||||
if processed_image.mode == 'RGBA':
|
|
||||||
processed_image = processed_image.convert("RGB")
|
|
||||||
save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
|
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
|
|
||||||
if mode == 0: # img2img
|
if mode == 0: # img2img
|
||||||
image = init_img.convert("RGB")
|
image = init_img
|
||||||
mask = None
|
mask = None
|
||||||
elif mode == 1: # img2img sketch
|
elif mode == 1: # img2img sketch
|
||||||
image = sketch.convert("RGB")
|
image = sketch
|
||||||
mask = None
|
mask = None
|
||||||
elif mode == 2: # inpaint
|
elif mode == 2: # inpaint
|
||||||
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
|
||||||
alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
|
mask = processing.create_binary_mask(mask)
|
||||||
mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
|
|
||||||
mask = ImageChops.lighter(alpha_mask, mask).convert('L')
|
|
||||||
image = image.convert("RGB")
|
|
||||||
elif mode == 3: # inpaint sketch
|
elif mode == 3: # inpaint sketch
|
||||||
image = inpaint_color_sketch
|
image = inpaint_color_sketch
|
||||||
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
orig = inpaint_color_sketch_orig or inpaint_color_sketch
|
||||||
@ -153,7 +138,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
|
||||||
blur = ImageFilter.GaussianBlur(mask_blur)
|
blur = ImageFilter.GaussianBlur(mask_blur)
|
||||||
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
image = Image.composite(image.filter(blur), orig, mask.filter(blur))
|
||||||
image = image.convert("RGB")
|
|
||||||
elif mode == 4: # inpaint upload mask
|
elif mode == 4: # inpaint upload mask
|
||||||
image = init_img_inpaint
|
image = init_img_inpaint
|
||||||
mask = init_mask_inpaint
|
mask = init_mask_inpaint
|
||||||
@ -180,21 +164,13 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
seed=seed,
|
sampler_name=sampler_name,
|
||||||
subseed=subseed,
|
|
||||||
subseed_strength=subseed_strength,
|
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
|
||||||
seed_enable_extras=seed_enable_extras,
|
|
||||||
sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
restore_faces=restore_faces,
|
|
||||||
tiling=tiling,
|
|
||||||
init_images=[image],
|
init_images=[image],
|
||||||
mask=mask,
|
mask=mask,
|
||||||
mask_blur=mask_blur,
|
mask_blur=mask_blur,
|
||||||
|
168
modules/initialize.py
Normal file
168
modules/initialize.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
|
def imports():
|
||||||
|
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
|
||||||
|
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
|
||||||
|
|
||||||
|
import torch # noqa: F401
|
||||||
|
startup_timer.record("import torch")
|
||||||
|
import pytorch_lightning # noqa: F401
|
||||||
|
startup_timer.record("import torch")
|
||||||
|
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||||
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||||
|
|
||||||
|
import gradio # noqa: F401
|
||||||
|
startup_timer.record("import gradio")
|
||||||
|
|
||||||
|
from modules import paths, timer, import_hook, errors # noqa: F401
|
||||||
|
startup_timer.record("setup paths")
|
||||||
|
|
||||||
|
import ldm.modules.encoders.modules # noqa: F401
|
||||||
|
startup_timer.record("import ldm")
|
||||||
|
|
||||||
|
import sgm.modules.encoders.modules # noqa: F401
|
||||||
|
startup_timer.record("import sgm")
|
||||||
|
|
||||||
|
from modules import shared_init
|
||||||
|
shared_init.initialize()
|
||||||
|
startup_timer.record("initialize shared")
|
||||||
|
|
||||||
|
from modules import processing, gradio_extensons, ui # noqa: F401
|
||||||
|
startup_timer.record("other imports")
|
||||||
|
|
||||||
|
|
||||||
|
def check_versions():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if not cmd_opts.skip_version_check:
|
||||||
|
from modules import errors
|
||||||
|
errors.check_versions()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
from modules import initialize_util
|
||||||
|
initialize_util.fix_torch_version()
|
||||||
|
initialize_util.fix_asyncio_event_loop_policy()
|
||||||
|
initialize_util.validate_tls_options()
|
||||||
|
initialize_util.configure_sigint_handler()
|
||||||
|
initialize_util.configure_opts_onchange()
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
modelloader.cleanup_models()
|
||||||
|
|
||||||
|
from modules import sd_models
|
||||||
|
sd_models.setup_model()
|
||||||
|
startup_timer.record("setup SD model")
|
||||||
|
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
from modules import codeformer_model
|
||||||
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
|
||||||
|
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
|
||||||
|
startup_timer.record("setup codeformer")
|
||||||
|
|
||||||
|
from modules import gfpgan_model
|
||||||
|
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
|
startup_timer.record("setup gfpgan")
|
||||||
|
|
||||||
|
initialize_rest(reload_script_modules=False)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_rest(*, reload_script_modules=False):
|
||||||
|
"""
|
||||||
|
Called both from initialize() and when reloading the webui.
|
||||||
|
"""
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
from modules import sd_samplers
|
||||||
|
sd_samplers.set_samplers()
|
||||||
|
startup_timer.record("set samplers")
|
||||||
|
|
||||||
|
from modules import extensions
|
||||||
|
extensions.list_extensions()
|
||||||
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
|
from modules import initialize_util
|
||||||
|
initialize_util.restore_config_state_file()
|
||||||
|
startup_timer.record("restore config state file")
|
||||||
|
|
||||||
|
from modules import shared, upscaler, scripts
|
||||||
|
if cmd_opts.ui_debug_mode:
|
||||||
|
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||||
|
scripts.load_scripts()
|
||||||
|
return
|
||||||
|
|
||||||
|
from modules import sd_models
|
||||||
|
sd_models.list_models()
|
||||||
|
startup_timer.record("list SD models")
|
||||||
|
|
||||||
|
from modules import localization
|
||||||
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
startup_timer.record("list localizations")
|
||||||
|
|
||||||
|
with startup_timer.subcategory("load scripts"):
|
||||||
|
scripts.load_scripts()
|
||||||
|
|
||||||
|
if reload_script_modules:
|
||||||
|
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||||
|
importlib.reload(module)
|
||||||
|
startup_timer.record("reload script modules")
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
modelloader.load_upscalers()
|
||||||
|
startup_timer.record("load upscalers")
|
||||||
|
|
||||||
|
from modules import sd_vae
|
||||||
|
sd_vae.refresh_vae_list()
|
||||||
|
startup_timer.record("refresh VAE")
|
||||||
|
|
||||||
|
from modules import textual_inversion
|
||||||
|
textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
|
startup_timer.record("refresh textual inversion templates")
|
||||||
|
|
||||||
|
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
|
||||||
|
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
|
||||||
|
sd_hijack.list_optimizers()
|
||||||
|
startup_timer.record("scripts list_optimizers")
|
||||||
|
|
||||||
|
from modules import sd_unet
|
||||||
|
sd_unet.list_unets()
|
||||||
|
startup_timer.record("scripts list_unets")
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
"""
|
||||||
|
Accesses shared.sd_model property to load model.
|
||||||
|
After it's available, if it has been loaded before this access by some extension,
|
||||||
|
its optimization may be None because the list of optimizaers has neet been filled
|
||||||
|
by that time, so we apply optimization again.
|
||||||
|
"""
|
||||||
|
|
||||||
|
shared.sd_model # noqa: B018
|
||||||
|
|
||||||
|
if sd_hijack.current_optimizer is None:
|
||||||
|
sd_hijack.apply_optimizations()
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
|
devices.first_time_calculation()
|
||||||
|
|
||||||
|
Thread(target=load_model).start()
|
||||||
|
|
||||||
|
from modules import shared_items
|
||||||
|
shared_items.reload_hypernetworks()
|
||||||
|
startup_timer.record("reload hypernetworks")
|
||||||
|
|
||||||
|
from modules import ui_extra_networks
|
||||||
|
ui_extra_networks.initialize()
|
||||||
|
ui_extra_networks.register_default_pages()
|
||||||
|
|
||||||
|
from modules import extra_networks
|
||||||
|
extra_networks.initialize()
|
||||||
|
extra_networks.register_default_extra_networks()
|
||||||
|
startup_timer.record("initialize extra networks")
|
202
modules/initialize_util.py
Normal file
202
modules/initialize_util.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
|
||||||
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
|
|
||||||
|
def gradio_server_name():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if cmd_opts.server_name:
|
||||||
|
return cmd_opts.server_name
|
||||||
|
else:
|
||||||
|
return "0.0.0.0" if cmd_opts.listen else None
|
||||||
|
|
||||||
|
|
||||||
|
def fix_torch_version():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
|
||||||
|
if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
||||||
|
torch.__long_version__ = torch.__version__
|
||||||
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_asyncio_event_loop_policy():
|
||||||
|
"""
|
||||||
|
The default `asyncio` event loop policy only automatically creates
|
||||||
|
event loops in the main threads. Other threads must create event
|
||||||
|
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||||
|
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||||
|
loops to be created automatically on any thread, matching the
|
||||||
|
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||||
|
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||||
|
# interface for composing policies so pick the right base.
|
||||||
|
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||||
|
else:
|
||||||
|
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||||
|
|
||||||
|
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||||
|
"""Event loop policy that allows loop creation on any thread.
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||||
|
try:
|
||||||
|
return super().get_event_loop()
|
||||||
|
except (RuntimeError, AssertionError):
|
||||||
|
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||||
|
# and changed to a RuntimeError in 3.4.3.
|
||||||
|
# "There is no current event loop in thread %r"
|
||||||
|
loop = self.new_event_loop()
|
||||||
|
self.set_event_loop(loop)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
|
def restore_config_state_file():
|
||||||
|
from modules import shared, config_states
|
||||||
|
|
||||||
|
config_state_file = shared.opts.restore_config_state_file
|
||||||
|
if config_state_file == "":
|
||||||
|
return
|
||||||
|
|
||||||
|
shared.opts.restore_config_state_file = ""
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
if os.path.isfile(config_state_file):
|
||||||
|
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||||
|
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||||
|
config_state = json.load(f)
|
||||||
|
config_states.restore_extension_config(config_state)
|
||||||
|
startup_timer.record("restore extension config")
|
||||||
|
elif config_state_file:
|
||||||
|
print(f"!!! Config state backup not found: {config_state_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tls_options():
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not os.path.exists(cmd_opts.tls_keyfile):
|
||||||
|
print("Invalid path to TLS keyfile given")
|
||||||
|
if not os.path.exists(cmd_opts.tls_certfile):
|
||||||
|
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
|
||||||
|
except TypeError:
|
||||||
|
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
|
||||||
|
print("TLS setup invalid, running webui without TLS")
|
||||||
|
else:
|
||||||
|
print("Running with TLS")
|
||||||
|
startup_timer.record("TLS")
|
||||||
|
|
||||||
|
|
||||||
|
def get_gradio_auth_creds():
|
||||||
|
"""
|
||||||
|
Convert the gradio_auth and gradio_auth_path commandline arguments into
|
||||||
|
an iterable of (username, password) tuples.
|
||||||
|
"""
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
def process_credential_line(s):
|
||||||
|
s = s.strip()
|
||||||
|
if not s:
|
||||||
|
return None
|
||||||
|
return tuple(s.split(':', 1))
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth:
|
||||||
|
for cred in cmd_opts.gradio_auth.split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
if cmd_opts.gradio_auth_path:
|
||||||
|
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
|
||||||
|
for line in file.readlines():
|
||||||
|
for cred in line.strip().split(','):
|
||||||
|
cred = process_credential_line(cred)
|
||||||
|
if cred:
|
||||||
|
yield cred
|
||||||
|
|
||||||
|
|
||||||
|
def dumpstacks():
|
||||||
|
import threading
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
id2name = {th.ident: th.name for th in threading.enumerate()}
|
||||||
|
code = []
|
||||||
|
for threadId, stack in sys._current_frames().items():
|
||||||
|
code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
|
||||||
|
for filename, lineno, name, line in traceback.extract_stack(stack):
|
||||||
|
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
|
||||||
|
if line:
|
||||||
|
code.append(" " + line.strip())
|
||||||
|
|
||||||
|
print("\n".join(code))
|
||||||
|
|
||||||
|
|
||||||
|
def configure_sigint_handler():
|
||||||
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
|
def sigint_handler(sig, frame):
|
||||||
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
|
||||||
|
dumpstacks()
|
||||||
|
|
||||||
|
os._exit(0)
|
||||||
|
|
||||||
|
if not os.environ.get("COVERAGE_RUN"):
|
||||||
|
# Don't install the immediate-quit handler when running under coverage,
|
||||||
|
# as then the coverage report won't be generated.
|
||||||
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_opts_onchange():
|
||||||
|
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
||||||
|
from modules.call_queue import wrap_queued_call
|
||||||
|
|
||||||
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||||
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||||
|
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
|
||||||
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
|
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||||
|
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||||
|
startup_timer.record("opts onchange")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_middleware(app):
|
||||||
|
from starlette.middleware.gzip import GZipMiddleware
|
||||||
|
|
||||||
|
app.middleware_stack = None # reset current middleware to allow modifying user provided list
|
||||||
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
configure_cors_middleware(app)
|
||||||
|
app.build_middleware_stack() # rebuild middleware stack on-the-fly
|
||||||
|
|
||||||
|
|
||||||
|
def configure_cors_middleware(app):
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
cors_options = {
|
||||||
|
"allow_methods": ["*"],
|
||||||
|
"allow_headers": ["*"],
|
||||||
|
"allow_credentials": True,
|
||||||
|
}
|
||||||
|
if cmd_opts.cors_allow_origins:
|
||||||
|
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
|
||||||
|
if cmd_opts.cors_allow_origins_regex:
|
||||||
|
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_options)
|
||||||
|
|
@ -186,9 +186,8 @@ class InterrogateModels:
|
|||||||
res = ""
|
res = ""
|
||||||
shared.state.begin(job="interrogate")
|
shared.state.begin(job="interrogate")
|
||||||
try:
|
try:
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
lowvram.send_everything_to_cpu()
|
||||||
lowvram.send_everything_to_cpu()
|
devices.torch_gc()
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
self.load()
|
self.load()
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
# this scripts installs necessary requirements and launches main program in webui.py
|
# this scripts installs necessary requirements and launches main program in webui.py
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import platform
|
import platform
|
||||||
@ -10,11 +12,11 @@ from functools import lru_cache
|
|||||||
|
|
||||||
from modules import cmd_args, errors
|
from modules import cmd_args, errors
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
from modules import timer
|
from modules.timer import startup_timer
|
||||||
|
from modules import logging_config
|
||||||
timer.startup_timer.record("start")
|
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
logging_config.setup_logging(args.loglevel)
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
@ -141,6 +143,25 @@ def check_run_python(code: str) -> bool:
|
|||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def git_fix_workspace(dir, name):
|
||||||
|
run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True)
|
||||||
|
run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):
|
||||||
|
try:
|
||||||
|
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||||
|
except RuntimeError:
|
||||||
|
if not autofix:
|
||||||
|
raise
|
||||||
|
|
||||||
|
print(f"{errdesc}, attempting autofix...")
|
||||||
|
git_fix_workspace(dir, name)
|
||||||
|
|
||||||
|
return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
# TODO clone into temporary dir and move if successful
|
# TODO clone into temporary dir and move if successful
|
||||||
|
|
||||||
@ -148,15 +169,24 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
if commithash is None:
|
if commithash is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
|
||||||
if current_hash == commithash:
|
if current_hash == commithash:
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||||
|
|
||||||
|
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||||
|
|
||||||
|
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
try:
|
||||||
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||||
|
except RuntimeError:
|
||||||
|
shutil.rmtree(dir, ignore_errors=True)
|
||||||
|
raise
|
||||||
|
|
||||||
if commithash is not None:
|
if commithash is not None:
|
||||||
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
||||||
@ -198,7 +228,9 @@ def run_extension_installer(extension_dir):
|
|||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||||
|
|
||||||
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
stdout = run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env).strip()
|
||||||
|
if stdout:
|
||||||
|
print(stdout)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.report(str(e))
|
errors.report(str(e))
|
||||||
|
|
||||||
@ -216,7 +248,7 @@ def list_extensions(settings_file):
|
|||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
|
|
||||||
if disable_all_extensions != 'none':
|
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||||
@ -226,8 +258,15 @@ def run_extensions_installers(settings_file):
|
|||||||
if not os.path.isdir(extensions_dir):
|
if not os.path.isdir(extensions_dir):
|
||||||
return
|
return
|
||||||
|
|
||||||
for dirname_extension in list_extensions(settings_file):
|
with startup_timer.subcategory("run extensions installers"):
|
||||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
for dirname_extension in list_extensions(settings_file):
|
||||||
|
logging.debug(f"Installing {dirname_extension}")
|
||||||
|
|
||||||
|
path = os.path.join(extensions_dir, dirname_extension)
|
||||||
|
|
||||||
|
if os.path.isdir(path):
|
||||||
|
run_extension_installer(path)
|
||||||
|
startup_timer.record(dirname_extension)
|
||||||
|
|
||||||
|
|
||||||
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||||
@ -274,7 +313,6 @@ def prepare_environment():
|
|||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
@ -285,13 +323,13 @@ def prepare_environment():
|
|||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||||
except OSError:
|
except OSError:
|
||||||
@ -300,8 +338,11 @@ def prepare_environment():
|
|||||||
if not args.skip_python_version_check:
|
if not args.skip_python_version_check:
|
||||||
check_python_version()
|
check_python_version()
|
||||||
|
|
||||||
|
startup_timer.record("checks")
|
||||||
|
|
||||||
commit = commit_hash()
|
commit = commit_hash()
|
||||||
tag = git_tag()
|
tag = git_tag()
|
||||||
|
startup_timer.record("git version info")
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
print(f"Python {sys.version}")
|
||||||
print(f"Version: {tag}")
|
print(f"Version: {tag}")
|
||||||
@ -309,36 +350,30 @@ def prepare_environment():
|
|||||||
|
|
||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
startup_timer.record("install torch")
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Torch is not able to use GPU; '
|
'Torch is not able to use GPU; '
|
||||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
)
|
)
|
||||||
|
startup_timer.record("torch GPU test")
|
||||||
if not is_installed("gfpgan"):
|
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
|
||||||
|
|
||||||
if not is_installed("clip"):
|
if not is_installed("clip"):
|
||||||
run_pip(f"install {clip_package}", "clip")
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
startup_timer.record("install clip")
|
||||||
|
|
||||||
if not is_installed("open_clip"):
|
if not is_installed("open_clip"):
|
||||||
run_pip(f"install {openclip_package}", "open_clip")
|
run_pip(f"install {openclip_package}", "open_clip")
|
||||||
|
startup_timer.record("install open_clip")
|
||||||
|
|
||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
if platform.system() == "Windows":
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||||
if platform.python_version().startswith("3.10"):
|
startup_timer.record("install xformers")
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
|
||||||
else:
|
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
|
||||||
if not is_installed("xformers"):
|
|
||||||
exit(0)
|
|
||||||
elif platform.system() == "Linux":
|
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
|
||||||
|
|
||||||
if not is_installed("ngrok") and args.ngrok:
|
if not is_installed("ngrok") and args.ngrok:
|
||||||
run_pip("install ngrok", "ngrok")
|
run_pip("install ngrok", "ngrok")
|
||||||
|
startup_timer.record("install ngrok")
|
||||||
|
|
||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
@ -348,22 +383,29 @@ def prepare_environment():
|
|||||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
|
startup_timer.record("clone repositores")
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
if not is_installed("lpips"):
|
||||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||||
|
startup_timer.record("install CodeFormer requirements")
|
||||||
|
|
||||||
if not os.path.isfile(requirements_file):
|
if not os.path.isfile(requirements_file):
|
||||||
requirements_file = os.path.join(script_path, requirements_file)
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
|
|
||||||
if not requirements_met(requirements_file):
|
if not requirements_met(requirements_file):
|
||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
|
startup_timer.record("install requirements")
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
if not args.skip_install:
|
||||||
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
if args.update_check:
|
if args.update_check:
|
||||||
version_check(commit)
|
version_check(commit)
|
||||||
|
startup_timer.record("check version")
|
||||||
|
|
||||||
if args.update_all_extensions:
|
if args.update_all_extensions:
|
||||||
git_pull_recursive(extensions_dir)
|
git_pull_recursive(extensions_dir)
|
||||||
|
startup_timer.record("update extensions")
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
@ -392,3 +434,16 @@ def start():
|
|||||||
webui.api_only()
|
webui.api_only()
|
||||||
else:
|
else:
|
||||||
webui.webui()
|
webui.webui()
|
||||||
|
|
||||||
|
|
||||||
|
def dump_sysinfo():
|
||||||
|
from modules import sysinfo
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
text = sysinfo.get()
|
||||||
|
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
||||||
|
|
||||||
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
|
file.write(text)
|
||||||
|
|
||||||
|
return filename
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors, scripts
|
||||||
|
|
||||||
localizations = {}
|
localizations = {}
|
||||||
|
|
||||||
@ -16,7 +16,6 @@ def list_localizations(dirname):
|
|||||||
|
|
||||||
localizations[fn] = os.path.join(dirname, file)
|
localizations[fn] = os.path.join(dirname, file)
|
||||||
|
|
||||||
from modules import scripts
|
|
||||||
for file in scripts.list_scripts("localizations", ".json"):
|
for file in scripts.list_scripts("localizations", ".json"):
|
||||||
fn, ext = os.path.splitext(file.filename)
|
fn, ext = os.path.splitext(file.filename)
|
||||||
localizations[fn] = file.path
|
localizations[fn] = file.path
|
||||||
|
16
modules/logging_config.py
Normal file
16
modules/logging_config.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(loglevel):
|
||||||
|
if loglevel is None:
|
||||||
|
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
|
||||||
|
if loglevel:
|
||||||
|
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
)
|
||||||
|
|
@ -1,5 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
from modules import devices
|
from modules import devices, shared
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
@ -14,7 +14,24 @@ def send_everything_to_cpu():
|
|||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_needed(sd_model):
|
||||||
|
return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
|
||||||
|
|
||||||
|
|
||||||
|
def apply(sd_model):
|
||||||
|
enable = is_needed(sd_model)
|
||||||
|
shared.parallel_processing_allowed = not enable
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
|
||||||
|
else:
|
||||||
|
sd_model.lowvram = False
|
||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
|
if getattr(sd_model, 'lowvram', False):
|
||||||
|
return
|
||||||
|
|
||||||
sd_model.lowvram = True
|
sd_model.lowvram = True
|
||||||
|
|
||||||
parents = {}
|
parents = {}
|
||||||
@ -127,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
|
|
||||||
|
|
||||||
def is_enabled(sd_model):
|
def is_enabled(sd_model):
|
||||||
return getattr(sd_model, 'lowvram', False)
|
return sd_model.lowvram
|
||||||
|
@ -4,6 +4,7 @@ import torch
|
|||||||
import platform
|
import platform
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -30,8 +31,7 @@ has_mps = check_for_mps()
|
|||||||
|
|
||||||
def torch_mps_gc() -> None:
|
def torch_mps_gc() -> None:
|
||||||
try:
|
try:
|
||||||
from modules.shared import state
|
if shared.state.current_latent is not None:
|
||||||
if state.current_latent is not None:
|
|
||||||
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
log.debug("`current_latent` is set, skipping MPS garbage collection")
|
||||||
return
|
return
|
||||||
from torch.mps import empty_cache
|
from torch.mps import empty_cache
|
||||||
@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
# MPS fix for randn in torchsde
|
|
||||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
|
||||||
|
|
||||||
if platform.mac_ver()[0].startswith("13.2."):
|
if platform.mac_ver()[0].startswith("13.2."):
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||||
|
245
modules/options.py
Normal file
245
modules/options.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import errors
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
|
class OptionInfo:
|
||||||
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
||||||
|
self.default = default
|
||||||
|
self.label = label
|
||||||
|
self.component = component
|
||||||
|
self.component_args = component_args
|
||||||
|
self.onchange = onchange
|
||||||
|
self.section = section
|
||||||
|
self.refresh = refresh
|
||||||
|
self.do_not_save = False
|
||||||
|
|
||||||
|
self.comment_before = comment_before
|
||||||
|
"""HTML text that will be added after label in UI"""
|
||||||
|
|
||||||
|
self.comment_after = comment_after
|
||||||
|
"""HTML text that will be added before label in UI"""
|
||||||
|
|
||||||
|
self.infotext = infotext
|
||||||
|
|
||||||
|
self.restrict_api = restrict_api
|
||||||
|
"""If True, the setting will not be accessible via API"""
|
||||||
|
|
||||||
|
def link(self, label, url):
|
||||||
|
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def js(self, label, js_func):
|
||||||
|
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self, info):
|
||||||
|
self.comment_after += f"<span class='info'>({info})</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def html(self, html):
|
||||||
|
self.comment_after += html
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_restart(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_reload_ui(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class OptionHTML(OptionInfo):
|
||||||
|
def __init__(self, text):
|
||||||
|
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
|
||||||
|
|
||||||
|
self.do_not_save = True
|
||||||
|
|
||||||
|
|
||||||
|
def options_section(section_identifier, options_dict):
|
||||||
|
for v in options_dict.values():
|
||||||
|
v.section = section_identifier
|
||||||
|
|
||||||
|
return options_dict
|
||||||
|
|
||||||
|
|
||||||
|
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
|
||||||
|
|
||||||
|
|
||||||
|
class Options:
|
||||||
|
typemap = {int: float}
|
||||||
|
|
||||||
|
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||||
|
self.data_labels = data_labels
|
||||||
|
self.data = {k: v.default for k, v in self.data_labels.items()}
|
||||||
|
self.restricted_opts = restricted_opts
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key in options_builtin_fields:
|
||||||
|
return super(Options, self).__setattr__(key, value)
|
||||||
|
|
||||||
|
if self.data is not None:
|
||||||
|
if key in self.data or key in self.data_labels:
|
||||||
|
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||||
|
|
||||||
|
info = self.data_labels.get(key, None)
|
||||||
|
if info.do_not_save:
|
||||||
|
return
|
||||||
|
|
||||||
|
comp_args = info.component_args if info else None
|
||||||
|
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||||
|
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||||
|
|
||||||
|
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||||
|
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||||
|
|
||||||
|
self.data[key] = value
|
||||||
|
return
|
||||||
|
|
||||||
|
return super(Options, self).__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item in options_builtin_fields:
|
||||||
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
if self.data is not None:
|
||||||
|
if item in self.data:
|
||||||
|
return self.data[item]
|
||||||
|
|
||||||
|
if item in self.data_labels:
|
||||||
|
return self.data_labels[item].default
|
||||||
|
|
||||||
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
def set(self, key, value, is_api=False, run_callbacks=True):
|
||||||
|
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
||||||
|
|
||||||
|
oldval = self.data.get(key, None)
|
||||||
|
if oldval == value:
|
||||||
|
return False
|
||||||
|
|
||||||
|
option = self.data_labels[key]
|
||||||
|
if option.do_not_save:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if is_api and option.restrict_api:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
setattr(self, key, value)
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if run_callbacks and option.onchange is not None:
|
||||||
|
try:
|
||||||
|
option.onchange()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"changing setting {key} to {value}")
|
||||||
|
setattr(self, key, oldval)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_default(self, key):
|
||||||
|
"""returns the default value for the key"""
|
||||||
|
|
||||||
|
data_label = self.data_labels.get(key)
|
||||||
|
if data_label is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data_label.default
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||||
|
|
||||||
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
|
json.dump(self.data, file, indent=4)
|
||||||
|
|
||||||
|
def same_type(self, x, y):
|
||||||
|
if x is None or y is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
type_x = self.typemap.get(type(x), type(x))
|
||||||
|
type_y = self.typemap.get(type(y), type(y))
|
||||||
|
|
||||||
|
return type_x == type_y
|
||||||
|
|
||||||
|
def load(self, filename):
|
||||||
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
# 1.6.0 VAE defaults
|
||||||
|
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||||
|
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||||
|
|
||||||
|
# 1.1.1 quicksettings list migration
|
||||||
|
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
||||||
|
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
||||||
|
|
||||||
|
# 1.4.0 ui_reorder
|
||||||
|
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
||||||
|
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
||||||
|
|
||||||
|
bad_settings = 0
|
||||||
|
for k, v in self.data.items():
|
||||||
|
info = self.data_labels.get(k, None)
|
||||||
|
if info is not None and not self.same_type(info.default, v):
|
||||||
|
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
||||||
|
bad_settings += 1
|
||||||
|
|
||||||
|
if bad_settings > 0:
|
||||||
|
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
||||||
|
|
||||||
|
def onchange(self, key, func, call=True):
|
||||||
|
item = self.data_labels.get(key)
|
||||||
|
item.onchange = func
|
||||||
|
|
||||||
|
if call:
|
||||||
|
func()
|
||||||
|
|
||||||
|
def dumpjson(self):
|
||||||
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
|
return json.dumps(d)
|
||||||
|
|
||||||
|
def add_option(self, key, info):
|
||||||
|
self.data_labels[key] = info
|
||||||
|
|
||||||
|
def reorder(self):
|
||||||
|
"""reorder settings so that all items related to section always go together"""
|
||||||
|
|
||||||
|
section_ids = {}
|
||||||
|
settings_items = self.data_labels.items()
|
||||||
|
for _, item in settings_items:
|
||||||
|
if item.section not in section_ids:
|
||||||
|
section_ids[item.section] = len(section_ids)
|
||||||
|
|
||||||
|
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||||
|
|
||||||
|
def cast_value(self, key, value):
|
||||||
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
|
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
default_value = self.data_labels[key].default
|
||||||
|
if default_value is None:
|
||||||
|
default_value = getattr(self, key, None)
|
||||||
|
if default_value is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
expected_type = type(default_value)
|
||||||
|
if expected_type == bool and value == "False":
|
||||||
|
value = False
|
||||||
|
else:
|
||||||
|
value = expected_type(value)
|
||||||
|
|
||||||
|
return value
|
64
modules/patches.py
Normal file
64
modules/patches.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
def patch(key, obj, field, replacement):
|
||||||
|
"""Replaces a function in a module or a class.
|
||||||
|
|
||||||
|
Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
|
||||||
|
If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
key: identifying information for who is doing the replacement. You can use __name__.
|
||||||
|
obj: the module or the class
|
||||||
|
field: name of the function as a string
|
||||||
|
replacement: the new function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the original function
|
||||||
|
"""
|
||||||
|
|
||||||
|
patch_key = (obj, field)
|
||||||
|
if patch_key in originals[key]:
|
||||||
|
raise RuntimeError(f"patch for {field} is already applied")
|
||||||
|
|
||||||
|
original_func = getattr(obj, field)
|
||||||
|
originals[key][patch_key] = original_func
|
||||||
|
|
||||||
|
setattr(obj, field, replacement)
|
||||||
|
|
||||||
|
return original_func
|
||||||
|
|
||||||
|
|
||||||
|
def undo(key, obj, field):
|
||||||
|
"""Undoes the peplacement by the patch().
|
||||||
|
|
||||||
|
If the function is not replaced, raises an exception.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
key: identifying information for who is doing the replacement. You can use __name__.
|
||||||
|
obj: the module or the class
|
||||||
|
field: name of the function as a string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Always None
|
||||||
|
"""
|
||||||
|
|
||||||
|
patch_key = (obj, field)
|
||||||
|
|
||||||
|
if patch_key not in originals[key]:
|
||||||
|
raise RuntimeError(f"there is no patch for {field} to undo")
|
||||||
|
|
||||||
|
original_func = originals[key].pop(patch_key)
|
||||||
|
setattr(obj, field, original_func)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def original(key, obj, field):
|
||||||
|
"""Returns the original function for the patch created by the patch() function"""
|
||||||
|
patch_key = (obj, field)
|
||||||
|
|
||||||
|
return originals[key].get(patch_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
originals = defaultdict(dict)
|
@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
shared.state.begin(job="extras")
|
shared.state.begin(job="extras")
|
||||||
|
|
||||||
image_data = []
|
|
||||||
image_names = []
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
if extras_mode == 1:
|
def get_images(extras_mode, image, image_folder, input_dir):
|
||||||
for img in image_folder:
|
if extras_mode == 1:
|
||||||
if isinstance(img, Image.Image):
|
for img in image_folder:
|
||||||
image = img
|
if isinstance(img, Image.Image):
|
||||||
fn = ''
|
image = img
|
||||||
else:
|
fn = ''
|
||||||
image = Image.open(os.path.abspath(img.name))
|
else:
|
||||||
fn = os.path.splitext(img.orig_name)[0]
|
image = Image.open(os.path.abspath(img.name))
|
||||||
image_data.append(image)
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_names.append(fn)
|
yield image, fn
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
|
|
||||||
image_list = shared.listfiles(input_dir)
|
image_list = shared.listfiles(input_dir)
|
||||||
for filename in image_list:
|
for filename in image_list:
|
||||||
try:
|
try:
|
||||||
image = Image.open(filename)
|
image = Image.open(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
image_data.append(image)
|
yield image, filename
|
||||||
image_names.append(filename)
|
else:
|
||||||
else:
|
assert image, 'image not selected'
|
||||||
assert image, 'image not selected'
|
yield image, None
|
||||||
|
|
||||||
image_data.append(image)
|
|
||||||
image_names.append(None)
|
|
||||||
|
|
||||||
if extras_mode == 2 and output_dir != '':
|
if extras_mode == 2 and output_dir != '':
|
||||||
outpath = output_dir
|
outpath = output_dir
|
||||||
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
infotext = ''
|
infotext = ''
|
||||||
|
|
||||||
for image, name in zip(image_data, image_names):
|
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
||||||
|
image_data: Image.Image
|
||||||
|
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
|
||||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(pp, args)
|
scripts.scripts_postproc.run(pp, args)
|
||||||
|
|
||||||
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(pp.image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
File diff suppressed because it is too large
Load Diff
49
modules/processing_scripts/refiner.py
Normal file
49
modules/processing_scripts/refiner.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, sd_models
|
||||||
|
from modules.ui_common import create_refresh_button
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||||
|
section = "accordions"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Refiner"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||||
|
with gr.Row():
|
||||||
|
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
|
||||||
|
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||||
|
|
||||||
|
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||||
|
|
||||||
|
def lookup_checkpoint(title):
|
||||||
|
info = sd_models.get_closet_checkpoint_match(title)
|
||||||
|
return None if info is None else info.title
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(enable_refiner, lambda d: 'Refiner' in d),
|
||||||
|
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||||
|
(refiner_switch_at, 'Refiner switch at'),
|
||||||
|
]
|
||||||
|
|
||||||
|
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||||
|
|
||||||
|
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||||
|
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||||
|
|
||||||
|
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||||
|
p.refiner_checkpoint = None
|
||||||
|
p.refiner_switch_at = None
|
||||||
|
else:
|
||||||
|
p.refiner_checkpoint = refiner_checkpoint
|
||||||
|
p.refiner_switch_at = refiner_switch_at
|
111
modules/processing_scripts/seed.py
Normal file
111
modules/processing_scripts/seed.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, ui, errors
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||||
|
section = "seed"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.seed = None
|
||||||
|
self.reuse_seed = None
|
||||||
|
self.reuse_subseed = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Seed"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_row")):
|
||||||
|
if cmd_opts.use_textbox_seed:
|
||||||
|
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
|
||||||
|
else:
|
||||||
|
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||||
|
|
||||||
|
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
||||||
|
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
||||||
|
|
||||||
|
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||||
|
|
||||||
|
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
|
||||||
|
with gr.Row(elem_id=self.elem_id("subseed_row")):
|
||||||
|
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
|
||||||
|
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
|
||||||
|
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
|
||||||
|
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
|
||||||
|
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
|
||||||
|
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
|
||||||
|
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
|
||||||
|
|
||||||
|
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(self.seed, "Seed"),
|
||||||
|
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||||
|
(subseed, "Variation seed"),
|
||||||
|
(subseed_strength, "Variation seed strength"),
|
||||||
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
|
(seed_resize_from_h, "Seed resize from-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
|
||||||
|
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
|
||||||
|
|
||||||
|
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
|
||||||
|
p.seed = seed
|
||||||
|
|
||||||
|
if seed_checkbox and subseed_strength > 0:
|
||||||
|
p.subseed = subseed
|
||||||
|
p.subseed_strength = subseed_strength
|
||||||
|
|
||||||
|
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
|
||||||
|
p.seed_resize_from_w = seed_resize_from_w
|
||||||
|
p.seed_resize_from_h = seed_resize_from_h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||||
|
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||||
|
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||||
|
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
||||||
|
|
||||||
|
def copy_seed(gen_info_string: str, index):
|
||||||
|
res = -1
|
||||||
|
|
||||||
|
try:
|
||||||
|
gen_info = json.loads(gen_info_string)
|
||||||
|
index -= gen_info.get('index_of_first_image', 0)
|
||||||
|
|
||||||
|
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||||
|
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||||
|
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||||
|
else:
|
||||||
|
all_seeds = gen_info.get('all_seeds', [-1])
|
||||||
|
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||||
|
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
if gen_info_string:
|
||||||
|
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||||
|
|
||||||
|
return [res, gr.update()]
|
||||||
|
|
||||||
|
reuse_seed.click(
|
||||||
|
fn=copy_seed,
|
||||||
|
_js="(x, y) => [x, selected_gallery_index()]",
|
||||||
|
show_progress=False,
|
||||||
|
inputs=[generation_info, seed],
|
||||||
|
outputs=[seed, seed]
|
||||||
|
)
|
@ -48,6 +48,7 @@ def add_task_to_queue(id_job):
|
|||||||
class ProgressRequest(BaseModel):
|
class ProgressRequest(BaseModel):
|
||||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||||
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
|
||||||
|
live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")
|
||||||
|
|
||||||
|
|
||||||
class ProgressResponse(BaseModel):
|
class ProgressResponse(BaseModel):
|
||||||
@ -71,7 +72,12 @@ def progressapi(req: ProgressRequest):
|
|||||||
completed = req.id_task in finished_tasks
|
completed = req.id_task in finished_tasks
|
||||||
|
|
||||||
if not active:
|
if not active:
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
|
textinfo = "Waiting..."
|
||||||
|
if queued:
|
||||||
|
sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
|
||||||
|
queue_index = sorted_queued.index(req.id_task)
|
||||||
|
textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))
|
||||||
|
return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
|
||||||
|
|
||||||
progress = 0
|
progress = 0
|
||||||
|
|
||||||
@ -89,31 +95,30 @@ def progressapi(req: ProgressRequest):
|
|||||||
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
predicted_duration = elapsed_since_start / progress if progress > 0 else None
|
||||||
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
|
||||||
|
|
||||||
|
live_preview = None
|
||||||
id_live_preview = req.id_live_preview
|
id_live_preview = req.id_live_preview
|
||||||
shared.state.set_current_image()
|
|
||||||
if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
|
|
||||||
image = shared.state.current_image
|
|
||||||
if image is not None:
|
|
||||||
buffered = io.BytesIO()
|
|
||||||
|
|
||||||
if opts.live_previews_image_format == "png":
|
if opts.live_previews_enable and req.live_preview:
|
||||||
# using optimize for large images takes an enormous amount of time
|
shared.state.set_current_image()
|
||||||
if max(*image.size) <= 256:
|
if shared.state.id_live_preview != req.id_live_preview:
|
||||||
save_kwargs = {"optimize": True}
|
image = shared.state.current_image
|
||||||
|
if image is not None:
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
|
||||||
|
if opts.live_previews_image_format == "png":
|
||||||
|
# using optimize for large images takes an enormous amount of time
|
||||||
|
if max(*image.size) <= 256:
|
||||||
|
save_kwargs = {"optimize": True}
|
||||||
|
else:
|
||||||
|
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
save_kwargs = {"optimize": False, "compress_level": 1}
|
save_kwargs = {}
|
||||||
|
|
||||||
else:
|
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||||
save_kwargs = {}
|
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||||
|
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||||
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
id_live_preview = shared.state.id_live_preview
|
||||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
|
||||||
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
|
||||||
id_live_preview = shared.state.id_live_preview
|
|
||||||
else:
|
|
||||||
live_preview = None
|
|
||||||
else:
|
|
||||||
live_preview = None
|
|
||||||
|
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
|
@ -19,14 +19,14 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
|||||||
!emphasized: "(" prompt ")"
|
!emphasized: "(" prompt ")"
|
||||||
| "(" prompt ":" prompt ")"
|
| "(" prompt ":" prompt ")"
|
||||||
| "[" prompt "]"
|
| "[" prompt "]"
|
||||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||||
alternate: "[" prompt ("|" prompt)+ "]"
|
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||||
WHITESPACE: /\s+/
|
WHITESPACE: /\s+/
|
||||||
plain: /([^\\\[\]():|]|\\.)+/
|
plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""
|
"""
|
||||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||||
>>> g("test")
|
>>> g("test")
|
||||||
@ -53,18 +53,43 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||||
>>> g("[a|(b:1.1)]")
|
>>> g("[a|(b:1.1)]")
|
||||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||||
|
>>> g("[fe|]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g("[fe|||]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
|
||||||
|
>>> g("a [b:.5] c")
|
||||||
|
[[10, 'a b c']]
|
||||||
|
>>> g("a [b:1.5] c")
|
||||||
|
[[5, 'a c'], [10, 'a b c']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if hires_steps is None or use_old_scheduling:
|
||||||
|
int_offset = 0
|
||||||
|
flt_offset = 0
|
||||||
|
steps = base_steps
|
||||||
|
else:
|
||||||
|
int_offset = base_steps
|
||||||
|
flt_offset = 1.0
|
||||||
|
steps = hires_steps
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
res = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
s = tree.children[-2]
|
||||||
if tree.children[-1] < 1:
|
v = float(s)
|
||||||
tree.children[-1] *= steps
|
if use_old_scheduling:
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
v = v*steps if v<1 else v
|
||||||
res.append(tree.children[-1])
|
else:
|
||||||
|
if "." in s:
|
||||||
|
v = (v - flt_offset) * steps
|
||||||
|
else:
|
||||||
|
v = (v - int_offset)
|
||||||
|
tree.children[-2] = min(steps, int(v))
|
||||||
|
if tree.children[-2] >= 1:
|
||||||
|
res.append(tree.children[-2])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
res.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
@ -75,13 +100,14 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
def scheduled(self, args):
|
def scheduled(self, args):
|
||||||
before, after, _, when = args
|
before, after, _, when, _ = args
|
||||||
yield before or () if step <= when else after
|
yield before or () if step <= when else after
|
||||||
def alternate(self, args):
|
def alternate(self, args):
|
||||||
yield next(args[(step - 1)%len(args)])
|
args = ["" if not arg else arg for arg in args]
|
||||||
|
yield args[(step - 1) % len(args)]
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if isinstance(x, str):
|
||||||
yield x
|
yield x
|
||||||
else:
|
else:
|
||||||
for gen in x:
|
for gen in x:
|
||||||
@ -129,7 +155,7 @@ class SdConditioning(list):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||||
and the sampling step at which this condition is to be replaced by the next one.
|
and the sampling step at which this condition is to be replaced by the next one.
|
||||||
|
|
||||||
@ -149,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
|||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||||
cache = {}
|
cache = {}
|
||||||
|
|
||||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||||
@ -224,7 +250,7 @@ class MulticondLearnedConditioning:
|
|||||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
|
||||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||||
|
|
||||||
@ -233,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
|||||||
|
|
||||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||||
|
|
||||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for indexes in res_indexes:
|
for indexes in res_indexes:
|
||||||
@ -333,7 +359,7 @@ re_attention = re.compile(r"""
|
|||||||
\\|
|
\\|
|
||||||
\(|
|
\(|
|
||||||
\[|
|
\[|
|
||||||
:([+-]?[.\d]+)\)|
|
:\s*([+-]?[.\d]+)\s*\)|
|
||||||
\)|
|
\)|
|
||||||
]|
|
]|
|
||||||
[^\\()\[\]:]+|
|
[^\\()\[\]:]+|
|
||||||
|
@ -55,6 +55,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||||
tile=opts.ESRGAN_tile,
|
tile=opts.ESRGAN_tile,
|
||||||
tile_pad=opts.ESRGAN_tile_overlap,
|
tile_pad=opts.ESRGAN_tile_overlap,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||||
|
170
modules/rng.py
Normal file
170
modules/rng.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import devices, rng_philox, shared
|
||||||
|
|
||||||
|
|
||||||
|
def randn(seed, shape, generator=None):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
|
||||||
|
|
||||||
|
manual_seed(seed)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||||
|
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||||
|
|
||||||
|
return torch.randn(shape, device=devices.device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_local(seed, shape):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using seed.
|
||||||
|
|
||||||
|
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
rng = rng_philox.Generator(seed)
|
||||||
|
return torch.asarray(rng.randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||||
|
local_generator = torch.Generator(local_device).manual_seed(int(seed))
|
||||||
|
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_like(x):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||||
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||||
|
|
||||||
|
return torch.randn_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
def randn_without_seed(shape, generator=None):
|
||||||
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||||
|
|
||||||
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
|
||||||
|
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
|
||||||
|
|
||||||
|
return torch.randn(shape, device=devices.device, generator=generator)
|
||||||
|
|
||||||
|
|
||||||
|
def manual_seed(seed):
|
||||||
|
"""Set up a global random number generator using the specified seed."""
|
||||||
|
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
global nv_rng
|
||||||
|
nv_rng = rng_philox.Generator(seed)
|
||||||
|
return
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def create_generator(seed):
|
||||||
|
if shared.opts.randn_source == "NV":
|
||||||
|
return rng_philox.Generator(seed)
|
||||||
|
|
||||||
|
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
|
||||||
|
generator = torch.Generator(device).manual_seed(int(seed))
|
||||||
|
return generator
|
||||||
|
|
||||||
|
|
||||||
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||||
|
def slerp(val, low, high):
|
||||||
|
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||||
|
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
||||||
|
dot = (low_norm*high_norm).sum(1)
|
||||||
|
|
||||||
|
if dot.mean() > 0.9995:
|
||||||
|
return low * val + high * (1 - val)
|
||||||
|
|
||||||
|
omega = torch.acos(dot)
|
||||||
|
so = torch.sin(omega)
|
||||||
|
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRNG:
|
||||||
|
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||||
|
self.shape = tuple(map(int, shape))
|
||||||
|
self.seeds = seeds
|
||||||
|
self.subseeds = subseeds
|
||||||
|
self.subseed_strength = subseed_strength
|
||||||
|
self.seed_resize_from_h = seed_resize_from_h
|
||||||
|
self.seed_resize_from_w = seed_resize_from_w
|
||||||
|
|
||||||
|
self.generators = [create_generator(seed) for seed in seeds]
|
||||||
|
|
||||||
|
self.is_first = True
|
||||||
|
|
||||||
|
def first(self):
|
||||||
|
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
|
||||||
|
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
|
||||||
|
subnoise = None
|
||||||
|
if self.subseeds is not None and self.subseed_strength != 0:
|
||||||
|
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
|
||||||
|
subnoise = randn(subseed, noise_shape)
|
||||||
|
|
||||||
|
if noise_shape != self.shape:
|
||||||
|
noise = randn(seed, noise_shape)
|
||||||
|
else:
|
||||||
|
noise = randn(seed, self.shape, generator=generator)
|
||||||
|
|
||||||
|
if subnoise is not None:
|
||||||
|
noise = slerp(self.subseed_strength, noise, subnoise)
|
||||||
|
|
||||||
|
if noise_shape != self.shape:
|
||||||
|
x = randn(seed, self.shape, generator=generator)
|
||||||
|
dx = (self.shape[2] - noise_shape[2]) // 2
|
||||||
|
dy = (self.shape[1] - noise_shape[1]) // 2
|
||||||
|
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
||||||
|
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
||||||
|
tx = 0 if dx < 0 else dx
|
||||||
|
ty = 0 if dy < 0 else dy
|
||||||
|
dx = max(-dx, 0)
|
||||||
|
dy = max(-dy, 0)
|
||||||
|
|
||||||
|
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
|
||||||
|
noise = x
|
||||||
|
|
||||||
|
xs.append(noise)
|
||||||
|
|
||||||
|
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
|
||||||
|
if eta_noise_seed_delta:
|
||||||
|
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
|
||||||
|
|
||||||
|
return torch.stack(xs).to(shared.device)
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
if self.is_first:
|
||||||
|
self.is_first = False
|
||||||
|
return self.first()
|
||||||
|
|
||||||
|
xs = []
|
||||||
|
for generator in self.generators:
|
||||||
|
x = randn_without_seed(self.shape, generator=generator)
|
||||||
|
xs.append(x)
|
||||||
|
|
||||||
|
return torch.stack(xs).to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
devices.randn = randn
|
||||||
|
devices.randn_local = randn_local
|
||||||
|
devices.randn_like = randn_like
|
||||||
|
devices.randn_without_seed = randn_without_seed
|
||||||
|
devices.manual_seed = manual_seed
|
102
modules/rng_philox.py
Normal file
102
modules/rng_philox.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
"""RNG imitiating torch cuda randn on CPU. You are welcome.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
g = Generator(seed=0)
|
||||||
|
print(g.randn(shape=(3, 4)))
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected output:
|
||||||
|
```
|
||||||
|
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
|
||||||
|
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
|
||||||
|
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
philox_m = [0xD2511F53, 0xCD9E8D57]
|
||||||
|
philox_w = [0x9E3779B9, 0xBB67AE85]
|
||||||
|
|
||||||
|
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
|
||||||
|
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def uint32(x):
|
||||||
|
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
|
||||||
|
return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def philox4_round(counter, key):
|
||||||
|
"""A single round of the Philox 4x32 random number generator."""
|
||||||
|
|
||||||
|
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
|
||||||
|
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
|
||||||
|
|
||||||
|
counter[0] = v2[1] ^ counter[1] ^ key[0]
|
||||||
|
counter[1] = v2[0]
|
||||||
|
counter[2] = v1[1] ^ counter[3] ^ key[1]
|
||||||
|
counter[3] = v1[0]
|
||||||
|
|
||||||
|
|
||||||
|
def philox4_32(counter, key, rounds=10):
|
||||||
|
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
|
||||||
|
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
|
||||||
|
rounds (int): The number of rounds to perform.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for _ in range(rounds - 1):
|
||||||
|
philox4_round(counter, key)
|
||||||
|
|
||||||
|
key[0] = key[0] + philox_w[0]
|
||||||
|
key[1] = key[1] + philox_w[1]
|
||||||
|
|
||||||
|
philox4_round(counter, key)
|
||||||
|
return counter
|
||||||
|
|
||||||
|
|
||||||
|
def box_muller(x, y):
|
||||||
|
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
|
||||||
|
u = x * two_pow32_inv + two_pow32_inv / 2
|
||||||
|
v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
|
||||||
|
|
||||||
|
s = np.sqrt(-2.0 * np.log(u))
|
||||||
|
|
||||||
|
r1 = s * np.sin(v)
|
||||||
|
return r1.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class Generator:
|
||||||
|
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
|
||||||
|
|
||||||
|
def __init__(self, seed):
|
||||||
|
self.seed = seed
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def randn(self, shape):
|
||||||
|
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
|
||||||
|
|
||||||
|
n = 1
|
||||||
|
for x in shape:
|
||||||
|
n *= x
|
||||||
|
|
||||||
|
counter = np.zeros((4, n), dtype=np.uint32)
|
||||||
|
counter[0] = self.offset
|
||||||
|
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
|
||||||
|
self.offset += 1
|
||||||
|
|
||||||
|
key = np.empty(n, dtype=np.uint64)
|
||||||
|
key.fill(self.seed)
|
||||||
|
key = uint32(key)
|
||||||
|
|
||||||
|
g = philox4_32(counter, key)
|
||||||
|
|
||||||
|
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
|
@ -28,6 +28,18 @@ class ImageSaveParams:
|
|||||||
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
|
||||||
|
|
||||||
|
|
||||||
|
class ExtraNoiseParams:
|
||||||
|
def __init__(self, noise, x, xi):
|
||||||
|
self.noise = noise
|
||||||
|
"""Random noise generated by the seed"""
|
||||||
|
|
||||||
|
self.x = x
|
||||||
|
"""Latent representation of the image"""
|
||||||
|
|
||||||
|
self.xi = xi
|
||||||
|
"""Noisy latent representation of the image"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiserParams:
|
class CFGDenoiserParams:
|
||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
@ -100,6 +112,7 @@ callback_map = dict(
|
|||||||
callbacks_ui_settings=[],
|
callbacks_ui_settings=[],
|
||||||
callbacks_before_image_saved=[],
|
callbacks_before_image_saved=[],
|
||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
|
callbacks_extra_noise=[],
|
||||||
callbacks_cfg_denoiser=[],
|
callbacks_cfg_denoiser=[],
|
||||||
callbacks_cfg_denoised=[],
|
callbacks_cfg_denoised=[],
|
||||||
callbacks_cfg_after_cfg=[],
|
callbacks_cfg_after_cfg=[],
|
||||||
@ -189,6 +202,14 @@ def image_saved_callback(params: ImageSaveParams):
|
|||||||
report_exception(c, 'image_saved_callback')
|
report_exception(c, 'image_saved_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def extra_noise_callback(params: ExtraNoiseParams):
|
||||||
|
for c in callback_map['callbacks_extra_noise']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'callbacks_extra_noise')
|
||||||
|
|
||||||
|
|
||||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||||
for c in callback_map['callbacks_cfg_denoiser']:
|
for c in callback_map['callbacks_cfg_denoiser']:
|
||||||
try:
|
try:
|
||||||
@ -367,6 +388,14 @@ def on_image_saved(callback):
|
|||||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
add_callback(callback_map['callbacks_image_saved'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_extra_noise(callback):
|
||||||
|
"""register a function to be called before adding extra noise in img2img or hires fix;
|
||||||
|
The callback is called with one argument:
|
||||||
|
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_extra_noise'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_cfg_denoiser(callback):
|
def on_cfg_denoiser(callback):
|
||||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
|
@ -3,6 +3,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
|
|||||||
self.images = images
|
self.images = images
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OnComponent:
|
||||||
|
component: gr.blocks.Block
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
@ -35,9 +41,13 @@ class Script:
|
|||||||
|
|
||||||
is_txt2img = False
|
is_txt2img = False
|
||||||
is_img2img = False
|
is_img2img = False
|
||||||
|
tabname = None
|
||||||
|
|
||||||
group = None
|
group = None
|
||||||
"""A gr.Group component that has all script's UI inside it"""
|
"""A gr.Group component that has all script's UI inside it."""
|
||||||
|
|
||||||
|
create_group = True
|
||||||
|
"""If False, for alwayson scripts, a group component will not be created."""
|
||||||
|
|
||||||
infotext_fields = None
|
infotext_fields = None
|
||||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
@ -52,6 +62,15 @@ class Script:
|
|||||||
api_info = None
|
api_info = None
|
||||||
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||||
|
|
||||||
|
on_before_component_elem_id = None
|
||||||
|
"""list of callbacks to be called before a component with an elem_id is created"""
|
||||||
|
|
||||||
|
on_after_component_elem_id = None
|
||||||
|
"""list of callbacks to be called after a component with an elem_id is created"""
|
||||||
|
|
||||||
|
setup_for_ui_only = False
|
||||||
|
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@ -90,9 +109,16 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def setup(self, p, *args):
|
||||||
|
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||||
|
args contains all values returned by components from ui().
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def before_process(self, p, *args):
|
def before_process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called very early before processing begins for AlwaysVisible scripts.
|
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||||
You can modify the processing object (p) here, inject hooks, etc.
|
You can modify the processing object (p) here, inject hooks, etc.
|
||||||
args contains all values returned by components from ui()
|
args contains all values returned by components from ui()
|
||||||
"""
|
"""
|
||||||
@ -212,6 +238,29 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_before_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
|
||||||
|
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
|
||||||
|
|
||||||
|
This function is an alternative to before_component in that it also cllows to run before a component is created, but
|
||||||
|
it doesn't require to be called for every created component - just for the one you need.
|
||||||
|
"""
|
||||||
|
if self.on_before_component_elem_id is None:
|
||||||
|
self.on_before_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
|
def on_after_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
"""
|
||||||
|
if self.on_after_component_elem_id is None:
|
||||||
|
self.on_after_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
"""unused"""
|
"""unused"""
|
||||||
return ""
|
return ""
|
||||||
@ -220,7 +269,7 @@ class Script:
|
|||||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||||
|
|
||||||
need_tabname = self.show(True) == self.show(False)
|
need_tabname = self.show(True) == self.show(False)
|
||||||
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
tabkind = 'img2img' if self.is_img2img else 'txt2img'
|
||||||
tabname = f"{tabkind}_" if need_tabname else ""
|
tabname = f"{tabkind}_" if need_tabname else ""
|
||||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||||
|
|
||||||
@ -232,6 +281,19 @@ class Script:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptBuiltinUI(Script):
|
||||||
|
setup_for_ui_only = True
|
||||||
|
|
||||||
|
def elem_id(self, item_id):
|
||||||
|
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
|
||||||
|
|
||||||
|
need_tabname = self.show(True) == self.show(False)
|
||||||
|
tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else ""
|
||||||
|
|
||||||
|
return f'{tabname}{item_id}'
|
||||||
|
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
|
||||||
@ -250,7 +312,7 @@ postprocessing_scripts_data = []
|
|||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts_list = []
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
@ -258,8 +320,9 @@ def list_scripts(scriptdirname, extension):
|
|||||||
for filename in sorted(os.listdir(basedir)):
|
for filename in sorted(os.listdir(basedir)):
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||||
|
|
||||||
for ext in extensions.active():
|
if include_extensions:
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
for ext in extensions.active():
|
||||||
|
scripts_list += ext.list_files(scriptdirname, extension)
|
||||||
|
|
||||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||||
|
|
||||||
@ -288,7 +351,7 @@ def load_scripts():
|
|||||||
postprocessing_scripts_data.clear()
|
postprocessing_scripts_data.clear()
|
||||||
script_callbacks.clear_callbacks()
|
script_callbacks.clear_callbacks()
|
||||||
|
|
||||||
scripts_list = list_scripts("scripts", ".py")
|
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||||
|
|
||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
@ -349,10 +412,17 @@ class ScriptRunner:
|
|||||||
self.selectable_scripts = []
|
self.selectable_scripts = []
|
||||||
self.alwayson_scripts = []
|
self.alwayson_scripts = []
|
||||||
self.titles = []
|
self.titles = []
|
||||||
|
self.title_map = {}
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
self.paste_field_names = []
|
self.paste_field_names = []
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
|
self.on_before_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
|
self.on_after_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
from modules import scripts_auto_postprocessing
|
from modules import scripts_auto_postprocessing
|
||||||
|
|
||||||
@ -367,6 +437,7 @@ class ScriptRunner:
|
|||||||
script.filename = script_data.path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
script.tabname = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
visibility = script.show(script.is_img2img)
|
visibility = script.show(script.is_img2img)
|
||||||
|
|
||||||
@ -379,6 +450,28 @@ class ScriptRunner:
|
|||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
|
def apply_on_before_component_callbacks(self):
|
||||||
|
for script in self.scripts:
|
||||||
|
on_before = script.on_before_component_elem_id or []
|
||||||
|
on_after = script.on_after_component_elem_id or []
|
||||||
|
|
||||||
|
for elem_id, callback in on_before:
|
||||||
|
if elem_id not in self.on_before_component_elem_id:
|
||||||
|
self.on_before_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
for elem_id, callback in on_after:
|
||||||
|
if elem_id not in self.on_after_component_elem_id:
|
||||||
|
self.on_after_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
on_before.clear()
|
||||||
|
on_after.clear()
|
||||||
|
|
||||||
def create_script_ui(self, script):
|
def create_script_ui(self, script):
|
||||||
import modules.api.models as api_models
|
import modules.api.models as api_models
|
||||||
|
|
||||||
@ -429,15 +522,20 @@ class ScriptRunner:
|
|||||||
if script.alwayson and script.section != section:
|
if script.alwayson and script.section != section:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with gr.Group(visible=script.alwayson) as group:
|
if script.create_group:
|
||||||
self.create_script_ui(script)
|
with gr.Group(visible=script.alwayson) as group:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
script.group = group
|
script.group = group
|
||||||
|
else:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
def prepare_ui(self):
|
def prepare_ui(self):
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
def setup_ui(self):
|
def setup_ui(self):
|
||||||
|
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
|
||||||
|
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
self.setup_ui_for_section(None)
|
self.setup_ui_for_section(None)
|
||||||
@ -484,6 +582,8 @@ class ScriptRunner:
|
|||||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
return self.inputs
|
return self.inputs
|
||||||
|
|
||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
@ -577,6 +677,12 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
@ -584,12 +690,21 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def after_component(self, component, **kwargs):
|
def after_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def script(self, title):
|
||||||
|
return self.title_map.get(title.lower())
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
@ -608,7 +723,6 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_from = args_from
|
self.scripts[si].args_from = args_from
|
||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
|
||||||
def before_hr(self, p):
|
def before_hr(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@ -617,6 +731,17 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def setup_scrips(self, p, *, is_ui=True):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
if not is_ui and script.setup_for_ui_only:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.setup(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
scripts_txt2img: ScriptRunner = None
|
scripts_txt2img: ScriptRunner = None
|
||||||
scripts_img2img: ScriptRunner = None
|
scripts_img2img: ScriptRunner = None
|
||||||
@ -631,49 +756,3 @@ def reload_script_body_only():
|
|||||||
|
|
||||||
|
|
||||||
reload_scripts = load_scripts # compatibility alias
|
reload_scripts = load_scripts # compatibility alias
|
||||||
|
|
||||||
|
|
||||||
def add_classes_to_gradio_component(comp):
|
|
||||||
"""
|
|
||||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
|
||||||
"""
|
|
||||||
|
|
||||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
|
||||||
|
|
||||||
if getattr(comp, 'multiselect', False):
|
|
||||||
comp.elem_classes.append('multiselect')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def IOComponent_init(self, *args, **kwargs):
|
|
||||||
if scripts_current is not None:
|
|
||||||
scripts_current.before_component(self, **kwargs)
|
|
||||||
|
|
||||||
script_callbacks.before_component_callback(self, **kwargs)
|
|
||||||
|
|
||||||
res = original_IOComponent_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
add_classes_to_gradio_component(self)
|
|
||||||
|
|
||||||
script_callbacks.after_component_callback(self, **kwargs)
|
|
||||||
|
|
||||||
if scripts_current is not None:
|
|
||||||
scripts_current.after_component(self, **kwargs)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
|
||||||
gr.components.IOComponent.__init__ = IOComponent_init
|
|
||||||
|
|
||||||
|
|
||||||
def BlockContext_init(self, *args, **kwargs):
|
|
||||||
res = original_BlockContext_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
add_classes_to_gradio_component(self)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
|
||||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
|
||||||
|
@ -3,8 +3,31 @@ import open_clip
|
|||||||
import torch
|
import torch
|
||||||
import transformers.utils.hub
|
import transformers.utils.hub
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
class DisableInitialization:
|
|
||||||
|
class ReplaceHelper:
|
||||||
|
def __init__(self):
|
||||||
|
self.replaced = []
|
||||||
|
|
||||||
|
def replace(self, obj, field, func):
|
||||||
|
original = getattr(obj, field, None)
|
||||||
|
if original is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.replaced.append((obj, field, original))
|
||||||
|
setattr(obj, field, func)
|
||||||
|
|
||||||
|
return original
|
||||||
|
|
||||||
|
def restore(self):
|
||||||
|
for obj, field, original in self.replaced:
|
||||||
|
setattr(obj, field, original)
|
||||||
|
|
||||||
|
self.replaced.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class DisableInitialization(ReplaceHelper):
|
||||||
"""
|
"""
|
||||||
When an object of this class enters a `with` block, it starts:
|
When an object of this class enters a `with` block, it starts:
|
||||||
- preventing torch's layer initialization functions from working
|
- preventing torch's layer initialization functions from working
|
||||||
@ -21,7 +44,7 @@ class DisableInitialization:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, disable_clip=True):
|
def __init__(self, disable_clip=True):
|
||||||
self.replaced = []
|
super().__init__()
|
||||||
self.disable_clip = disable_clip
|
self.disable_clip = disable_clip
|
||||||
|
|
||||||
def replace(self, obj, field, func):
|
def replace(self, obj, field, func):
|
||||||
@ -86,8 +109,124 @@ class DisableInitialization:
|
|||||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
for obj, field, original in self.replaced:
|
self.restore()
|
||||||
setattr(obj, field, original)
|
|
||||||
|
|
||||||
self.replaced.clear()
|
|
||||||
|
|
||||||
|
class InitializeOnMeta(ReplaceHelper):
|
||||||
|
"""
|
||||||
|
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
|
||||||
|
which results in those parameters having no values and taking no memory. model.to() will be broken and
|
||||||
|
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_device(x):
|
||||||
|
x["device"] = "meta"
|
||||||
|
return x
|
||||||
|
|
||||||
|
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
|
||||||
|
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
|
||||||
|
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
|
||||||
|
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.restore()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadStateDictOnMeta(ReplaceHelper):
|
||||||
|
"""
|
||||||
|
Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.
|
||||||
|
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
|
||||||
|
Meant to be used together with InitializeOnMeta above.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```
|
||||||
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state_dict, device, weight_dtype_conversion=None):
|
||||||
|
super().__init__()
|
||||||
|
self.state_dict = state_dict
|
||||||
|
self.device = device
|
||||||
|
self.weight_dtype_conversion = weight_dtype_conversion or {}
|
||||||
|
self.default_dtype = self.weight_dtype_conversion.get('')
|
||||||
|
|
||||||
|
def get_weight_dtype(self, key):
|
||||||
|
key_first_term, _ = key.split('.', 1)
|
||||||
|
return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
if shared.cmd_opts.disable_model_loading_ram_optimization:
|
||||||
|
return
|
||||||
|
|
||||||
|
sd = self.state_dict
|
||||||
|
device = self.device
|
||||||
|
|
||||||
|
def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
|
||||||
|
used_param_keys = []
|
||||||
|
|
||||||
|
for name, param in module._parameters.items():
|
||||||
|
if param is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = prefix + name
|
||||||
|
sd_param = sd.pop(key, None)
|
||||||
|
if sd_param is not None:
|
||||||
|
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
|
||||||
|
used_param_keys.append(key)
|
||||||
|
|
||||||
|
if param.is_meta:
|
||||||
|
dtype = sd_param.dtype if sd_param is not None else param.dtype
|
||||||
|
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
|
||||||
|
|
||||||
|
for name in module._buffers:
|
||||||
|
key = prefix + name
|
||||||
|
|
||||||
|
sd_param = sd.pop(key, None)
|
||||||
|
if sd_param is not None:
|
||||||
|
state_dict[key] = sd_param
|
||||||
|
used_param_keys.append(key)
|
||||||
|
|
||||||
|
original(module, state_dict, prefix, *args, **kwargs)
|
||||||
|
|
||||||
|
for key in used_param_keys:
|
||||||
|
state_dict.pop(key, None)
|
||||||
|
|
||||||
|
def load_state_dict(original, module, state_dict, strict=True):
|
||||||
|
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
|
||||||
|
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
|
||||||
|
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
|
||||||
|
|
||||||
|
In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
|
||||||
|
|
||||||
|
The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
|
||||||
|
the function and does not call the original) the state dict will just fail to load because weights
|
||||||
|
would be on the meta device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if state_dict == sd:
|
||||||
|
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||||
|
|
||||||
|
original(module, state_dict, strict=strict)
|
||||||
|
|
||||||
|
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
|
||||||
|
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
|
||||||
|
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
|
||||||
|
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
|
||||||
|
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
|
||||||
|
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
|
||||||
|
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.restore()
|
||||||
|
@ -2,7 +2,6 @@ import torch
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
@ -30,8 +29,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
|||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
|
||||||
# silence new console spam from SD2
|
# silence new console spam from SD2
|
||||||
ldm.modules.attention.print = lambda *args: None
|
ldm.modules.attention.print = shared.ldm_print
|
||||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
ldm.util.print = shared.ldm_print
|
||||||
|
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
@ -164,12 +165,13 @@ class StableDiffusionModelHijack:
|
|||||||
clip = None
|
clip = None
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
|
||||||
self.extra_generation_params = {}
|
self.extra_generation_params = {}
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
|
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def apply_optimizations(self, option=None):
|
def apply_optimizations(self, option=None):
|
||||||
@ -197,7 +199,7 @@ class StableDiffusionModelHijack:
|
|||||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||||
text_cond_models.append(conditioner.embedders[i])
|
text_cond_models.append(conditioner.embedders[i])
|
||||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||||
text_cond_models.append(conditioner.embedders[i])
|
text_cond_models.append(conditioner.embedders[i])
|
||||||
|
|
||||||
@ -243,7 +245,21 @@ class StableDiffusionModelHijack:
|
|||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
|
if conditioner:
|
||||||
|
for i in range(len(conditioner.embedders)):
|
||||||
|
embedder = conditioner.embedders[i]
|
||||||
|
if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
|
||||||
|
embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
|
||||||
|
conditioner.embedders[i] = embedder.wrapped
|
||||||
|
if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
|
||||||
|
embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
|
||||||
|
conditioner.embedders[i] = embedder.wrapped
|
||||||
|
|
||||||
|
if hasattr(m, 'cond_stage_model'):
|
||||||
|
delattr(m, 'cond_stage_model')
|
||||||
|
|
||||||
|
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
@ -292,10 +308,11 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
def __init__(self, wrapped, embeddings):
|
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
|
self.textual_inversion_key = textual_inversion_key
|
||||||
|
|
||||||
def forward(self, input_ids):
|
def forward(self, input_ids):
|
||||||
batch_fixes = self.embeddings.fixes
|
batch_fixes = self.embeddings.fixes
|
||||||
@ -309,7 +326,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||||||
vecs = []
|
vecs = []
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, embedding in fixes:
|
for offset, embedding in fixes:
|
||||||
emb = devices.cond_cast_unet(embedding.vec)
|
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
|
emb = devices.cond_cast_unet(vec)
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||||
|
|
||||||
|
@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
position += 1
|
position += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
emb_len = int(embedding.vec.shape[0])
|
emb_len = int(embedding.vectors)
|
||||||
if len(chunk.tokens) + emb_len > self.chunk_length:
|
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||||
next_chunk()
|
next_chunk()
|
||||||
|
|
||||||
@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
hashes.append(f"{name}: {shorthash}")
|
hashes.append(f"{name}: {shorthash}")
|
||||||
|
|
||||||
if hashes:
|
if hashes:
|
||||||
|
if self.hijack.extra_generation_params.get("TI hashes"):
|
||||||
|
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||||
|
|
||||||
if getattr(self.wrapped, 'return_pooled', False):
|
if getattr(self.wrapped, 'return_pooled', False):
|
||||||
|
@ -1,97 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
import ldm.models.diffusion.ddpm
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
|
|
||||||
from ldm.models.diffusion.ddim import noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
|
||||||
b, *_, device = *x.shape, x.device
|
|
||||||
|
|
||||||
def get_model_output(x, t):
|
|
||||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
|
||||||
e_t = self.model.apply_model(x, t, c)
|
|
||||||
else:
|
|
||||||
x_in = torch.cat([x] * 2)
|
|
||||||
t_in = torch.cat([t] * 2)
|
|
||||||
|
|
||||||
if isinstance(c, dict):
|
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
|
||||||
c_in = {}
|
|
||||||
for k in c:
|
|
||||||
if isinstance(c[k], list):
|
|
||||||
c_in[k] = [
|
|
||||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
|
||||||
for i in range(len(c[k]))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
|
||||||
else:
|
|
||||||
c_in = torch.cat([unconditional_conditioning, c])
|
|
||||||
|
|
||||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
|
||||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
|
||||||
|
|
||||||
if score_corrector is not None:
|
|
||||||
assert self.model.parameterization == "eps"
|
|
||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
|
||||||
|
|
||||||
return e_t
|
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
|
||||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
|
||||||
|
|
||||||
def get_x_prev_and_pred_x0(e_t, index):
|
|
||||||
# select parameters corresponding to the currently considered timestep
|
|
||||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
|
||||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
|
||||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
|
||||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
|
||||||
|
|
||||||
# current prediction for x_0
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
|
||||||
if quantize_denoised:
|
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
|
||||||
if dynamic_threshold is not None:
|
|
||||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
|
||||||
# direction pointing to x_t
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
|
||||||
if noise_dropout > 0.:
|
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
|
||||||
return x_prev, pred_x0
|
|
||||||
|
|
||||||
e_t = get_model_output(x, t)
|
|
||||||
if len(old_eps) == 0:
|
|
||||||
# Pseudo Improved Euler (2nd order)
|
|
||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
|
||||||
e_t_next = get_model_output(x_prev, t_next)
|
|
||||||
e_t_prime = (e_t + e_t_next) / 2
|
|
||||||
elif len(old_eps) == 1:
|
|
||||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
|
||||||
elif len(old_eps) == 2:
|
|
||||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
|
||||||
elif len(old_eps) >= 3:
|
|
||||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
|
||||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
|
||||||
|
|
||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
|
||||||
|
|
||||||
return x_prev, pred_x0, e_t
|
|
||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import psutil
|
import psutil
|
||||||
|
import platform
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
|||||||
class SdOptimizationSubQuad(SdOptimization):
|
class SdOptimizationSubQuad(SdOptimization):
|
||||||
name = "sub-quadratic"
|
name = "sub-quadratic"
|
||||||
cmd_opt = "opt_sub_quad_attention"
|
cmd_opt = "opt_sub_quad_attention"
|
||||||
priority = 10
|
|
||||||
|
@property
|
||||||
|
def priority(self):
|
||||||
|
return 1000 if shared.device.type == 'mps' else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self):
|
def priority(self):
|
||||||
return 1000 if not torch.cuda.is_available() else 10
|
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||||
|
|
||||||
if chunk_threshold is None:
|
if chunk_threshold is None:
|
||||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
if q.device.type == 'mps':
|
||||||
|
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||||
|
else:
|
||||||
|
chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||||
elif chunk_threshold == 0:
|
elif chunk_threshold == 0:
|
||||||
chunk_threshold_bytes = None
|
chunk_threshold_bytes = None
|
||||||
else:
|
else:
|
||||||
|
@ -14,8 +14,7 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
|
||||||
@ -28,11 +27,31 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
|
|||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def replace_key(d, key, new_key, value):
|
||||||
|
keys = list(d.keys())
|
||||||
|
|
||||||
|
d[new_key] = value
|
||||||
|
|
||||||
|
if key not in keys:
|
||||||
|
return d
|
||||||
|
|
||||||
|
index = keys.index(key)
|
||||||
|
keys[index] = new_key
|
||||||
|
|
||||||
|
new_d = {k: d[k] for k in keys}
|
||||||
|
|
||||||
|
d.clear()
|
||||||
|
d.update(new_d)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
class CheckpointInfo:
|
class CheckpointInfo:
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
abspath = os.path.abspath(filename)
|
abspath = os.path.abspath(filename)
|
||||||
|
|
||||||
|
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||||
|
|
||||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||||
elif abspath.startswith(model_path):
|
elif abspath.startswith(model_path):
|
||||||
@ -43,6 +62,19 @@ class CheckpointInfo:
|
|||||||
if name.startswith("\\") or name.startswith("/"):
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
||||||
|
def read_metadata():
|
||||||
|
metadata = read_metadata_from_safetensors(filename)
|
||||||
|
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
if self.is_safetensors:
|
||||||
|
try:
|
||||||
|
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading metadata for {filename}")
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
|
||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
@ -52,17 +84,11 @@ class CheckpointInfo:
|
|||||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
|
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
|
||||||
|
if self.shorthash:
|
||||||
self.metadata = {}
|
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||||
|
|
||||||
_, ext = os.path.splitext(self.filename)
|
|
||||||
if ext.lower() == ".safetensors":
|
|
||||||
try:
|
|
||||||
self.metadata = read_metadata_from_safetensors(filename)
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"reading checkpoint metadata: {filename}")
|
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
@ -74,13 +100,20 @@ class CheckpointInfo:
|
|||||||
if self.sha256 is None:
|
if self.sha256 is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.shorthash = self.sha256[0:10]
|
shorthash = self.sha256[0:10]
|
||||||
|
if self.shorthash == self.sha256[0:10]:
|
||||||
|
return self.shorthash
|
||||||
|
|
||||||
|
self.shorthash = shorthash
|
||||||
|
|
||||||
if self.shorthash not in self.ids:
|
if self.shorthash not in self.ids:
|
||||||
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
|
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
|
||||||
|
|
||||||
checkpoints_list.pop(self.title)
|
old_title = self.title
|
||||||
self.title = f'{self.name} [{self.shorthash}]'
|
self.title = f'{self.name} [{self.shorthash}]'
|
||||||
|
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
|
||||||
|
|
||||||
|
replace_key(checkpoints_list, old_title, self.title, self)
|
||||||
self.register()
|
self.register()
|
||||||
|
|
||||||
return self.shorthash
|
return self.shorthash
|
||||||
@ -101,14 +134,8 @@ def setup_model():
|
|||||||
enable_midas_autodownload()
|
enable_midas_autodownload()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles():
|
def checkpoint_tiles(use_short=False):
|
||||||
def convert(name):
|
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
|
||||||
return int(name) if name.isdigit() else name.lower()
|
|
||||||
|
|
||||||
def alphanumeric_key(key):
|
|
||||||
return [convert(c) for c in re.split('([0-9]+)', key)]
|
|
||||||
|
|
||||||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
@ -131,12 +158,18 @@ def list_models():
|
|||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
|
||||||
for filename in sorted(model_list, key=str.lower):
|
for filename in model_list:
|
||||||
checkpoint_info = CheckpointInfo(filename)
|
checkpoint_info = CheckpointInfo(filename)
|
||||||
checkpoint_info.register()
|
checkpoint_info.register()
|
||||||
|
|
||||||
|
|
||||||
|
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
|
if not search_string:
|
||||||
|
return None
|
||||||
|
|
||||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
@ -145,6 +178,11 @@ def get_closet_checkpoint_match(search_string):
|
|||||||
if found:
|
if found:
|
||||||
return found[0]
|
return found[0]
|
||||||
|
|
||||||
|
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
|
||||||
|
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
|
||||||
|
if found:
|
||||||
|
return found[0]
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -280,11 +318,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
class SkipWritingToConfig:
|
||||||
|
"""This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
|
||||||
|
|
||||||
|
skip = False
|
||||||
|
previous = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.previous = SkipWritingToConfig.skip
|
||||||
|
SkipWritingToConfig.skip = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||||
|
SkipWritingToConfig.skip = self.previous
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
timer.record("calculate hash")
|
timer.record("calculate hash")
|
||||||
|
|
||||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
if not SkipWritingToConfig.skip:
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||||
|
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
@ -297,18 +351,23 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
sd_models_xl.extend_sdxl(model)
|
sd_models_xl.extend_sdxl(model)
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
del state_dict
|
|
||||||
timer.record("apply weights to model")
|
timer.record("apply weights to model")
|
||||||
|
|
||||||
if shared.opts.sd_checkpoint_cache > 0:
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
# cache newly loaded model
|
# cache newly loaded model
|
||||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
checkpoints_loaded[checkpoint_info] = state_dict
|
||||||
|
|
||||||
|
del state_dict
|
||||||
|
|
||||||
if shared.cmd_opts.opt_channelslast:
|
if shared.cmd_opts.opt_channelslast:
|
||||||
model.to(memory_format=torch.channels_last)
|
model.to(memory_format=torch.channels_last)
|
||||||
timer.record("apply channels_last")
|
timer.record("apply channels_last")
|
||||||
|
|
||||||
if not shared.cmd_opts.no_half:
|
if shared.cmd_opts.no_half:
|
||||||
|
model.float()
|
||||||
|
devices.dtype_unet = torch.float32
|
||||||
|
timer.record("apply float()")
|
||||||
|
else:
|
||||||
vae = model.first_stage_model
|
vae = model.first_stage_model
|
||||||
depth_model = getattr(model, 'depth_model', None)
|
depth_model = getattr(model, 'depth_model', None)
|
||||||
|
|
||||||
@ -324,9 +383,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
if depth_model:
|
if depth_model:
|
||||||
model.depth_model = depth_model
|
model.depth_model = depth_model
|
||||||
|
|
||||||
|
devices.dtype_unet = torch.float16
|
||||||
timer.record("apply half()")
|
timer.record("apply half()")
|
||||||
|
|
||||||
devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
|
|
||||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
@ -346,7 +405,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
|
|
||||||
sd_vae.delete_base_vae()
|
sd_vae.delete_base_vae()
|
||||||
sd_vae.clear_loaded_vae()
|
sd_vae.clear_loaded_vae()
|
||||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
|
||||||
sd_vae.load_vae(model, vae_file, vae_source)
|
sd_vae.load_vae(model, vae_file, vae_source)
|
||||||
timer.record("load VAE")
|
timer.record("load VAE")
|
||||||
|
|
||||||
@ -423,6 +482,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
|||||||
class SdModelData:
|
class SdModelData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sd_model = None
|
self.sd_model = None
|
||||||
|
self.loaded_sd_models = []
|
||||||
self.was_loaded_at_least_once = False
|
self.was_loaded_at_least_once = False
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
@ -437,6 +497,7 @@ class SdModelData:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
load_model()
|
load_model()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
@ -445,14 +506,30 @@ class SdModelData:
|
|||||||
|
|
||||||
return self.sd_model
|
return self.sd_model
|
||||||
|
|
||||||
def set_sd_model(self, v):
|
def set_sd_model(self, v, already_loaded=False):
|
||||||
self.sd_model = v
|
self.sd_model = v
|
||||||
|
if already_loaded:
|
||||||
|
sd_vae.base_vae = getattr(v, "base_vae", None)
|
||||||
|
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
|
||||||
|
sd_vae.checkpoint_info = v.sd_checkpoint_info
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.loaded_sd_models.remove(v)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if v is not None:
|
||||||
|
self.loaded_sd_models.insert(0, v)
|
||||||
|
|
||||||
|
|
||||||
model_data = SdModelData()
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
def get_empty_cond(sd_model):
|
def get_empty_cond(sd_model):
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img()
|
||||||
|
extra_networks.activate(p, {})
|
||||||
|
|
||||||
if hasattr(sd_model, 'conditioner'):
|
if hasattr(sd_model, 'conditioner'):
|
||||||
d = sd_model.get_learned_conditioning([""])
|
d = sd_model.get_learned_conditioning([""])
|
||||||
return d['crossattn']
|
return d['crossattn']
|
||||||
@ -460,20 +537,46 @@ def get_empty_cond(sd_model):
|
|||||||
return sd_model.cond_stage_model([""])
|
return sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_cpu(m):
|
||||||
|
if m.lowvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
else:
|
||||||
|
m.to(devices.cpu)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
def model_target_device(m):
|
||||||
|
if lowvram.is_needed(m):
|
||||||
|
return devices.cpu
|
||||||
|
else:
|
||||||
|
return devices.device
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_device(m):
|
||||||
|
lowvram.apply(m)
|
||||||
|
|
||||||
|
if not m.lowvram:
|
||||||
|
m.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_trash(m):
|
||||||
|
m.to(device="meta")
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
if model_data.sd_model:
|
if model_data.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
send_model_to_trash(model_data.sd_model)
|
||||||
model_data.sd_model = None
|
model_data.sd_model = None
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
do_inpainting_hijack()
|
timer.record("unload existing model")
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
if already_loaded_state_dict is not None:
|
if already_loaded_state_dict is not None:
|
||||||
state_dict = already_loaded_state_dict
|
state_dict = already_loaded_state_dict
|
||||||
@ -495,25 +598,35 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
sd_model = None
|
sd_model = None
|
||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
except Exception:
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
pass
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "creating model quickly", full_traceback=True)
|
||||||
|
|
||||||
if sd_model is None:
|
if sd_model is None:
|
||||||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
|
||||||
|
with sd_disable_initialization.InitializeOnMeta():
|
||||||
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
|
||||||
sd_model.used_config = checkpoint_config
|
sd_model.used_config = checkpoint_config
|
||||||
|
|
||||||
timer.record("create model")
|
timer.record("create model")
|
||||||
|
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
if shared.cmd_opts.no_half:
|
||||||
|
weight_dtype_conversion = None
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
|
||||||
else:
|
else:
|
||||||
sd_model.to(shared.device)
|
weight_dtype_conversion = {
|
||||||
|
'first_stage_model': None,
|
||||||
|
'': torch.float16,
|
||||||
|
}
|
||||||
|
|
||||||
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||||
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
|
timer.record("load weights from state dict")
|
||||||
|
|
||||||
|
send_model_to_device(sd_model)
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
@ -521,7 +634,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
model_data.sd_model = sd_model
|
model_data.set_sd_model(sd_model)
|
||||||
model_data.was_loaded_at_least_once = True
|
model_data.was_loaded_at_least_once = True
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
@ -542,10 +655,70 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||||
|
"""
|
||||||
|
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||||
|
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||||
|
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||||
|
If no such model exists, returns None.
|
||||||
|
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||||
|
"""
|
||||||
|
|
||||||
|
already_loaded = None
|
||||||
|
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||||
|
loaded_model = model_data.loaded_sd_models[i]
|
||||||
|
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
already_loaded = loaded_model
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||||
|
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||||
|
model_data.loaded_sd_models.pop()
|
||||||
|
send_model_to_trash(loaded_model)
|
||||||
|
timer.record("send model to trash")
|
||||||
|
|
||||||
|
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
|
timer.record("send model to cpu")
|
||||||
|
|
||||||
|
if already_loaded is not None:
|
||||||
|
send_model_to_device(already_loaded)
|
||||||
|
timer.record("send model to device")
|
||||||
|
|
||||||
|
model_data.set_sd_model(already_loaded, already_loaded=True)
|
||||||
|
|
||||||
|
if not SkipWritingToConfig.skip:
|
||||||
|
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
|
||||||
|
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
|
||||||
|
|
||||||
|
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||||
|
sd_vae.reload_vae_weights(already_loaded)
|
||||||
|
return model_data.sd_model
|
||||||
|
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||||
|
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||||
|
|
||||||
|
model_data.sd_model = None
|
||||||
|
load_model(checkpoint_info)
|
||||||
|
return model_data.sd_model
|
||||||
|
elif len(model_data.loaded_sd_models) > 0:
|
||||||
|
sd_model = model_data.loaded_sd_models.pop()
|
||||||
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
|
sd_vae.base_vae = getattr(sd_model, "base_vae", None)
|
||||||
|
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
|
||||||
|
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
|
|
||||||
|
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||||
|
return sd_model
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = model_data.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
@ -554,19 +727,17 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
else:
|
else:
|
||||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return sd_model
|
||||||
|
|
||||||
|
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||||
|
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
return sd_model
|
||||||
|
|
||||||
|
if sd_model is not None:
|
||||||
sd_unet.apply_unet("None")
|
sd_unet.apply_unet("None")
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
|
||||||
else:
|
|
||||||
sd_model.to(devices.cpu)
|
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
@ -574,7 +745,9 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
timer.record("find config")
|
timer.record("find config")
|
||||||
|
|
||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
del sd_model
|
if sd_model is not None:
|
||||||
|
send_model_to_trash(sd_model)
|
||||||
|
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return model_data.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
@ -591,17 +764,19 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
timer.record("script callbacks")
|
timer.record("script callbacks")
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
|
model_data.set_sd_model(sd_model)
|
||||||
|
sd_unet.apply_unet()
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import devices, sd_hijack
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if model_data.sd_model:
|
if model_data.sd_model:
|
||||||
|
@ -2,7 +2,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import shared, paths, sd_disable_initialization
|
from modules import shared, paths, sd_disable_initialization, devices
|
||||||
|
|
||||||
sd_configs_path = shared.sd_configs_path
|
sd_configs_path = shared.sd_configs_path
|
||||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||||
@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
from modules import devices
|
|
||||||
|
|
||||||
device = devices.cpu
|
device = devices.cpu
|
||||||
|
|
||||||
|
31
modules/sd_models_types.py
Normal file
31
modules/sd_models_types.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from modules.sd_models import CheckpointInfo
|
||||||
|
|
||||||
|
|
||||||
|
class WebuiSdModel(LatentDiffusion):
|
||||||
|
"""This class is not actually instantinated, but its fields are created and fieeld by webui"""
|
||||||
|
|
||||||
|
lowvram: bool
|
||||||
|
"""True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info"""
|
||||||
|
|
||||||
|
sd_model_hash: str
|
||||||
|
"""short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used"""
|
||||||
|
|
||||||
|
sd_model_checkpoint: str
|
||||||
|
"""path to the file on disk that model weights were obtained from"""
|
||||||
|
|
||||||
|
sd_checkpoint_info: 'CheckpointInfo'
|
||||||
|
"""structure with additional information about the file with model's weights"""
|
||||||
|
|
||||||
|
is_sdxl: bool
|
||||||
|
"""True if the model's architecture is SDXL"""
|
||||||
|
|
||||||
|
is_sd2: bool
|
||||||
|
"""True if the model's architecture is SD 2.x"""
|
||||||
|
|
||||||
|
is_sd1: bool
|
||||||
|
"""True if the model's architecture is SD 1.x"""
|
@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
|
|||||||
return torch.cat(res, dim=1)
|
return torch.cat(res, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(self: sgm.modules.GeneralConditioner, texts):
|
||||||
|
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
|
||||||
|
return embedder.tokenize(texts)
|
||||||
|
|
||||||
|
raise AssertionError('no tokenizer available')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def process_texts(self, texts):
|
def process_texts(self, texts):
|
||||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||||
return embedder.process_texts(texts)
|
return embedder.process_texts(texts)
|
||||||
@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
|
|||||||
|
|
||||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||||
|
sgm.modules.GeneralConditioner.tokenize = tokenize
|
||||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||||
|
|
||||||
@ -89,10 +98,10 @@ def extend_sdxl(model):
|
|||||||
model.conditioner.wrapped = torch.nn.Module()
|
model.conditioner.wrapped = torch.nn.Module()
|
||||||
|
|
||||||
|
|
||||||
sgm.modules.attention.print = lambda *args: None
|
sgm.modules.attention.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||||
sgm.modules.encoders.modules.print = lambda *args: None
|
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||||
|
|
||||||
# this gets the code to load the vanilla attention that we override
|
# this gets the code to load the vanilla attention that we override
|
||||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
*sd_samplers_compvis.samplers_data_compvis,
|
*sd_samplers_timesteps.samplers_data_timesteps,
|
||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
samplers = []
|
samplers = []
|
||||||
samplers_for_img2img = []
|
samplers_for_img2img = []
|
||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
|
samplers_hidden = {}
|
||||||
|
|
||||||
|
|
||||||
def find_sampler_config(name):
|
def find_sampler_config(name):
|
||||||
@ -38,13 +39,11 @@ def create_sampler(name, model):
|
|||||||
|
|
||||||
|
|
||||||
def set_samplers():
|
def set_samplers():
|
||||||
global samplers, samplers_for_img2img
|
global samplers, samplers_for_img2img, samplers_hidden
|
||||||
|
|
||||||
hidden = set(shared.opts.hide_samplers)
|
samplers_hidden = set(shared.opts.hide_samplers)
|
||||||
hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC'])
|
samplers = all_samplers
|
||||||
|
samplers_for_img2img = all_samplers
|
||||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
|
||||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
|
||||||
|
|
||||||
samplers_map.clear()
|
samplers_map.clear()
|
||||||
for sampler in all_samplers:
|
for sampler in all_samplers:
|
||||||
@ -53,4 +52,8 @@ def set_samplers():
|
|||||||
samplers_map[alias.lower()] = sampler.name
|
samplers_map[alias.lower()] = sampler.name
|
||||||
|
|
||||||
|
|
||||||
|
def visible_sampler_names():
|
||||||
|
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
230
modules/sd_samplers_cfg_denoiser.py
Normal file
230
modules/sd_samplers_cfg_denoiser.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
import torch
|
||||||
|
from modules import prompt_parser, devices, sd_samplers_common
|
||||||
|
|
||||||
|
from modules.shared import opts, state
|
||||||
|
import modules.shared as shared
|
||||||
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||||
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||||
|
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||||
|
|
||||||
|
|
||||||
|
def catenate_conds(conds):
|
||||||
|
if not isinstance(conds[0], dict):
|
||||||
|
return torch.cat(conds)
|
||||||
|
|
||||||
|
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||||
|
|
||||||
|
|
||||||
|
def subscript_cond(cond, a, b):
|
||||||
|
if not isinstance(cond, dict):
|
||||||
|
return cond[a:b]
|
||||||
|
|
||||||
|
return {key: vec[a:b] for key, vec in cond.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def pad_cond(tensor, repeats, empty):
|
||||||
|
if not isinstance(tensor, dict):
|
||||||
|
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||||
|
|
||||||
|
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiser(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||||
|
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||||
|
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||||
|
negative prompt.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sampler):
|
||||||
|
super().__init__()
|
||||||
|
self.model_wrap = None
|
||||||
|
self.mask = None
|
||||||
|
self.nmask = None
|
||||||
|
self.init_latent = None
|
||||||
|
self.steps = None
|
||||||
|
"""number of steps as specified by user in UI"""
|
||||||
|
|
||||||
|
self.total_steps = None
|
||||||
|
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||||
|
|
||||||
|
self.step = 0
|
||||||
|
self.image_cfg_scale = None
|
||||||
|
self.padded_cond_uncond = False
|
||||||
|
self.sampler = sampler
|
||||||
|
self.model_wrap = None
|
||||||
|
self.p = None
|
||||||
|
self.mask_before_denoising = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
|
for i, conds in enumerate(conds_list):
|
||||||
|
for cond_index, weight in conds:
|
||||||
|
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||||
|
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||||
|
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
|
return x_out
|
||||||
|
|
||||||
|
def update_inner_model(self):
|
||||||
|
self.model_wrap = None
|
||||||
|
|
||||||
|
c, uc = self.p.get_conds()
|
||||||
|
self.sampler.sampler_extra_args['cond'] = c
|
||||||
|
self.sampler.sampler_extra_args['uncond'] = uc
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||||
|
if state.interrupted or state.skipped:
|
||||||
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
|
if sd_samplers_common.apply_refiner(self):
|
||||||
|
cond = self.sampler.sampler_extra_args['cond']
|
||||||
|
uncond = self.sampler.sampler_extra_args['uncond']
|
||||||
|
|
||||||
|
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
||||||
|
# so is_edit_model is set to False to support AND composition.
|
||||||
|
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
||||||
|
|
||||||
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||||
|
|
||||||
|
if self.mask_before_denoising and self.mask is not None:
|
||||||
|
x = self.init_latent * self.mask + self.nmask * x
|
||||||
|
|
||||||
|
batch_size = len(conds_list)
|
||||||
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
|
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||||
|
image_uncond = torch.zeros_like(image_cond)
|
||||||
|
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||||
|
else:
|
||||||
|
image_uncond = image_cond
|
||||||
|
if isinstance(uncond, dict):
|
||||||
|
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||||
|
else:
|
||||||
|
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||||
|
|
||||||
|
if not is_edit_model:
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||||
|
|
||||||
|
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
||||||
|
cfg_denoiser_callback(denoiser_params)
|
||||||
|
x_in = denoiser_params.x
|
||||||
|
image_cond_in = denoiser_params.image_cond
|
||||||
|
sigma_in = denoiser_params.sigma
|
||||||
|
tensor = denoiser_params.text_cond
|
||||||
|
uncond = denoiser_params.text_uncond
|
||||||
|
skip_uncond = False
|
||||||
|
|
||||||
|
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||||
|
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||||
|
skip_uncond = True
|
||||||
|
x_in = x_in[:-batch_size]
|
||||||
|
sigma_in = sigma_in[:-batch_size]
|
||||||
|
|
||||||
|
self.padded_cond_uncond = False
|
||||||
|
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
||||||
|
empty = shared.sd_model.cond_stage_model_empty_prompt
|
||||||
|
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||||
|
|
||||||
|
if num_repeats < 0:
|
||||||
|
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
|
elif num_repeats > 0:
|
||||||
|
uncond = pad_cond(uncond, num_repeats, empty)
|
||||||
|
self.padded_cond_uncond = True
|
||||||
|
|
||||||
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||||
|
if is_edit_model:
|
||||||
|
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||||
|
elif skip_uncond:
|
||||||
|
cond_in = tensor
|
||||||
|
else:
|
||||||
|
cond_in = catenate_conds([tensor, uncond])
|
||||||
|
|
||||||
|
if shared.opts.batch_cond_uncond:
|
||||||
|
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||||
|
else:
|
||||||
|
x_out = torch.zeros_like(x_in)
|
||||||
|
batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
|
||||||
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||||
|
a = batch_offset
|
||||||
|
b = min(a + batch_size, tensor.shape[0])
|
||||||
|
|
||||||
|
if not is_edit_model:
|
||||||
|
c_crossattn = subscript_cond(tensor, a, b)
|
||||||
|
else:
|
||||||
|
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||||
|
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||||
|
|
||||||
|
if not skip_uncond:
|
||||||
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||||
|
|
||||||
|
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||||
|
if skip_uncond:
|
||||||
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||||
|
cfg_denoised_callback(denoised_params)
|
||||||
|
|
||||||
|
devices.test_for_nans(x_out, "unet")
|
||||||
|
|
||||||
|
if is_edit_model:
|
||||||
|
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||||
|
elif skip_uncond:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||||
|
else:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
|
if not self.mask_before_denoising and self.mask is not None:
|
||||||
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||||
|
|
||||||
|
if opts.live_preview_content == "Prompt":
|
||||||
|
preview = self.sampler.last_latent
|
||||||
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
|
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||||
|
else:
|
||||||
|
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||||
|
|
||||||
|
sd_samplers_common.store_latent(preview)
|
||||||
|
|
||||||
|
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||||
|
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||||
|
denoised = after_cfg_callback_params.x
|
||||||
|
|
||||||
|
self.step += 1
|
||||||
|
return denoised
|
||||||
|
|
@ -1,13 +1,22 @@
|
|||||||
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||||
|
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import k_diffusion.sampling
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
|
||||||
|
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerData(SamplerDataTuple):
|
||||||
|
def total_steps(self, steps):
|
||||||
|
if self.options.get("second_order", False):
|
||||||
|
steps = steps * 2
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
@ -25,19 +34,34 @@ def setup_img2img_steps(p, steps=None):
|
|||||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||||
|
|
||||||
|
|
||||||
def single_sample_to_image(sample, approximation=None):
|
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||||
if approximation is None:
|
"""Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1]."""
|
||||||
|
|
||||||
|
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
|
||||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||||
|
|
||||||
|
from modules import lowvram
|
||||||
|
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
|
||||||
|
approximation = 1
|
||||||
|
|
||||||
if approximation == 2:
|
if approximation == 2:
|
||||||
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
|
x_sample = sd_vae_approx.cheap_approximation(sample)
|
||||||
elif approximation == 1:
|
elif approximation == 1:
|
||||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
|
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()
|
||||||
elif approximation == 3:
|
elif approximation == 3:
|
||||||
x_sample = sample * 1.5
|
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||||
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
x_sample = x_sample * 2 - 1
|
||||||
else:
|
else:
|
||||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
|
if model is None:
|
||||||
|
model = shared.sd_model
|
||||||
|
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
|
||||||
|
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
||||||
|
|
||||||
|
return x_sample
|
||||||
|
|
||||||
|
|
||||||
|
def single_sample_to_image(sample, approximation=None):
|
||||||
|
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
|
||||||
|
|
||||||
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
@ -46,6 +70,12 @@ def single_sample_to_image(sample, approximation=None):
|
|||||||
return Image.fromarray(x_sample)
|
return Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_first_stage(model, x):
|
||||||
|
x = x.to(devices.dtype_vae)
|
||||||
|
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
|
||||||
|
return samples_to_images_tensor(x, approx_index, model)
|
||||||
|
|
||||||
|
|
||||||
def sample_to_image(samples, index=0, approximation=None):
|
def sample_to_image(samples, index=0, approximation=None):
|
||||||
return single_sample_to_image(samples[index], approximation)
|
return single_sample_to_image(samples[index], approximation)
|
||||||
|
|
||||||
@ -54,6 +84,34 @@ def samples_to_image_grid(samples, approximation=None):
|
|||||||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||||
|
|
||||||
|
|
||||||
|
def images_tensor_to_samples(image, approximation=None, model=None):
|
||||||
|
'''image[0, 1] -> latent'''
|
||||||
|
if approximation is None:
|
||||||
|
approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)
|
||||||
|
|
||||||
|
if approximation == 3:
|
||||||
|
image = image.to(devices.device, devices.dtype)
|
||||||
|
x_latent = sd_vae_taesd.encoder_model()(image)
|
||||||
|
else:
|
||||||
|
if model is None:
|
||||||
|
model = shared.sd_model
|
||||||
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
|
|
||||||
|
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||||
|
image = image * 2 - 1
|
||||||
|
if len(image) > 1:
|
||||||
|
x_latent = torch.stack([
|
||||||
|
model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(torch.unsqueeze(img, 0))
|
||||||
|
)[0]
|
||||||
|
for img in image
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||||
|
|
||||||
|
return x_latent
|
||||||
|
|
||||||
|
|
||||||
def store_latent(decoded):
|
def store_latent(decoded):
|
||||||
state.current_latent = decoded
|
state.current_latent = decoded
|
||||||
|
|
||||||
@ -85,11 +143,195 @@ class InterruptedException(BaseException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU":
|
def replace_torchsde_browinan():
|
||||||
import torchsde._brownian.brownian_interval
|
import torchsde._brownian.brownian_interval
|
||||||
|
|
||||||
def torchsde_randn(size, dtype, device, seed):
|
def torchsde_randn(size, dtype, device, seed):
|
||||||
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
return devices.randn_local(seed, size).to(device=device, dtype=dtype)
|
||||||
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
|
||||||
|
|
||||||
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||||
|
|
||||||
|
|
||||||
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_refiner(cfg_denoiser):
|
||||||
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
|
|
||||||
|
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if getattr(cfg_denoiser.p, "enable_hr", False):
|
||||||
|
is_second_pass = cfg_denoiser.p.is_hr_pass
|
||||||
|
|
||||||
|
if opts.hires_fix_refiner_pass == "first pass" and is_second_pass:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if opts.hires_fix_refiner_pass == "second pass" and not is_second_pass:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if opts.hires_fix_refiner_pass != "second pass":
|
||||||
|
cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass
|
||||||
|
|
||||||
|
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||||
|
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||||
|
|
||||||
|
with sd_models.SkipWritingToConfig():
|
||||||
|
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
cfg_denoiser.p.setup_conds()
|
||||||
|
cfg_denoiser.update_inner_model()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class TorchHijack:
|
||||||
|
"""This is here to replace torch.randn_like of k-diffusion.
|
||||||
|
|
||||||
|
k-diffusion has random_sampler argument for most samplers, but not for all, so
|
||||||
|
this is needed to properly replace every use of torch.randn_like.
|
||||||
|
|
||||||
|
We need to replace to make images generated in batches to be same as images generated individually."""
|
||||||
|
|
||||||
|
def __init__(self, p):
|
||||||
|
self.rng = p.rng
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if item == 'randn_like':
|
||||||
|
return self.randn_like
|
||||||
|
|
||||||
|
if hasattr(torch, item):
|
||||||
|
return getattr(torch, item)
|
||||||
|
|
||||||
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
|
||||||
|
def randn_like(self, x):
|
||||||
|
return self.rng.next()
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
def __init__(self, funcname):
|
||||||
|
self.funcname = funcname
|
||||||
|
self.func = funcname
|
||||||
|
self.extra_params = []
|
||||||
|
self.sampler_noises = None
|
||||||
|
self.stop_at = None
|
||||||
|
self.eta = None
|
||||||
|
self.config: SamplerData = None # set by the function calling the constructor
|
||||||
|
self.last_latent = None
|
||||||
|
self.s_min_uncond = None
|
||||||
|
self.s_churn = 0.0
|
||||||
|
self.s_tmin = 0.0
|
||||||
|
self.s_tmax = float('inf')
|
||||||
|
self.s_noise = 1.0
|
||||||
|
|
||||||
|
self.eta_option_field = 'eta_ancestral'
|
||||||
|
self.eta_infotext_field = 'Eta'
|
||||||
|
self.eta_default = 1.0
|
||||||
|
|
||||||
|
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||||
|
|
||||||
|
self.p = None
|
||||||
|
self.model_wrap_cfg = None
|
||||||
|
self.sampler_extra_args = None
|
||||||
|
self.options = {}
|
||||||
|
|
||||||
|
def callback_state(self, d):
|
||||||
|
step = d['i']
|
||||||
|
|
||||||
|
if self.stop_at is not None and step > self.stop_at:
|
||||||
|
raise InterruptedException
|
||||||
|
|
||||||
|
state.sampling_step = step
|
||||||
|
shared.total_tqdm.update()
|
||||||
|
|
||||||
|
def launch_sampling(self, steps, func):
|
||||||
|
self.model_wrap_cfg.steps = steps
|
||||||
|
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||||
|
state.sampling_steps = steps
|
||||||
|
state.sampling_step = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
return func()
|
||||||
|
except RecursionError:
|
||||||
|
print(
|
||||||
|
'Encountered RecursionError during sampling, returning last latent. '
|
||||||
|
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||||
|
'You should try to use a smaller rho value instead.'
|
||||||
|
)
|
||||||
|
return self.last_latent
|
||||||
|
except InterruptedException:
|
||||||
|
return self.last_latent
|
||||||
|
|
||||||
|
def number_of_needed_noises(self, p):
|
||||||
|
return p.steps
|
||||||
|
|
||||||
|
def initialize(self, p) -> dict:
|
||||||
|
self.p = p
|
||||||
|
self.model_wrap_cfg.p = p
|
||||||
|
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
|
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
self.model_wrap_cfg.step = 0
|
||||||
|
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||||
|
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||||
|
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||||
|
|
||||||
|
k_diffusion.sampling.torch = TorchHijack(p)
|
||||||
|
|
||||||
|
extra_params_kwargs = {}
|
||||||
|
for param_name in self.extra_params:
|
||||||
|
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||||
|
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||||
|
|
||||||
|
if 'eta' in inspect.signature(self.func).parameters:
|
||||||
|
if self.eta != self.eta_default:
|
||||||
|
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||||
|
|
||||||
|
extra_params_kwargs['eta'] = self.eta
|
||||||
|
|
||||||
|
if len(self.extra_params) > 0:
|
||||||
|
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||||
|
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||||
|
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||||
|
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||||
|
|
||||||
|
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||||
|
extra_params_kwargs['s_churn'] = s_churn
|
||||||
|
p.s_churn = s_churn
|
||||||
|
p.extra_generation_params['Sigma churn'] = s_churn
|
||||||
|
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||||
|
extra_params_kwargs['s_tmin'] = s_tmin
|
||||||
|
p.s_tmin = s_tmin
|
||||||
|
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||||
|
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||||
|
extra_params_kwargs['s_tmax'] = s_tmax
|
||||||
|
p.s_tmax = s_tmax
|
||||||
|
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||||
|
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||||
|
extra_params_kwargs['s_noise'] = s_noise
|
||||||
|
p.s_noise = s_noise
|
||||||
|
p.extra_generation_params['Sigma noise'] = s_noise
|
||||||
|
|
||||||
|
return extra_params_kwargs
|
||||||
|
|
||||||
|
def create_noise_sampler(self, x, sigmas, p):
|
||||||
|
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||||
|
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||||
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||||
|
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||||
|
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
@ -1,224 +0,0 @@
|
|||||||
import math
|
|
||||||
import ldm.models.diffusion.ddim
|
|
||||||
import ldm.models.diffusion.plms
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from modules.shared import state
|
|
||||||
from modules import sd_samplers_common, prompt_parser, shared
|
|
||||||
import modules.models.diffusion.uni_pc
|
|
||||||
|
|
||||||
|
|
||||||
samplers_data_compvis = [
|
|
||||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
|
||||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
|
||||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class VanillaStableDiffusionSampler:
|
|
||||||
def __init__(self, constructor, sd_model):
|
|
||||||
self.sampler = constructor(sd_model)
|
|
||||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
|
||||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
|
||||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
|
||||||
self.orig_p_sample_ddim = None
|
|
||||||
if self.is_plms:
|
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
|
||||||
elif self.is_ddim:
|
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.step = 0
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None
|
|
||||||
self.last_latent = None
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
|
||||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
|
||||||
|
|
||||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
if self.stop_at is not None and self.step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
|
||||||
image_conditioning = None
|
|
||||||
uc_image_conditioning = None
|
|
||||||
if isinstance(cond, dict):
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
image_conditioning = cond["c_adm"]
|
|
||||||
uc_image_conditioning = unconditional_conditioning["c_adm"]
|
|
||||||
else:
|
|
||||||
image_conditioning = cond["c_concat"][0]
|
|
||||||
cond = cond["c_crossattn"][0]
|
|
||||||
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
|
||||||
|
|
||||||
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
|
||||||
cond = tensor
|
|
||||||
|
|
||||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
|
||||||
# filling unconditional_conditioning with repeats of the last vector to match length is
|
|
||||||
# not 100% correct but should work well enough
|
|
||||||
if unconditional_conditioning.shape[1] < cond.shape[1]:
|
|
||||||
last_vector = unconditional_conditioning[:, -1:]
|
|
||||||
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
|
|
||||||
unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
|
|
||||||
elif unconditional_conditioning.shape[1] > cond.shape[1]:
|
|
||||||
unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
|
||||||
x = img_orig * self.mask + self.nmask * x
|
|
||||||
|
|
||||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
|
||||||
# Note that they need to be lists because it just concatenates them later.
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
|
|
||||||
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
|
|
||||||
else:
|
|
||||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
return x, ts, cond, unconditional_conditioning
|
|
||||||
|
|
||||||
def update_step(self, last_latent):
|
|
||||||
if self.mask is not None:
|
|
||||||
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
|
|
||||||
else:
|
|
||||||
self.last_latent = last_latent
|
|
||||||
|
|
||||||
sd_samplers_common.store_latent(self.last_latent)
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
state.sampling_step = self.step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
def after_sample(self, x, ts, cond, uncond, res):
|
|
||||||
if not self.is_unipc:
|
|
||||||
self.update_step(res[1])
|
|
||||||
|
|
||||||
return x, ts, cond, uncond, res
|
|
||||||
|
|
||||||
def unipc_after_update(self, x, model_x):
|
|
||||||
self.update_step(x)
|
|
||||||
|
|
||||||
def initialize(self, p):
|
|
||||||
if self.is_ddim:
|
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
|
||||||
else:
|
|
||||||
self.eta = 0.0
|
|
||||||
|
|
||||||
if self.eta != 0.0:
|
|
||||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
|
||||||
|
|
||||||
if self.is_unipc:
|
|
||||||
keys = [
|
|
||||||
('UniPC variant', 'uni_pc_variant'),
|
|
||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
|
||||||
('UniPC order', 'uni_pc_order'),
|
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, key in keys:
|
|
||||||
v = getattr(shared.opts, key)
|
|
||||||
if v != shared.opts.get_default(key):
|
|
||||||
p.extra_generation_params[name] = v
|
|
||||||
|
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
|
||||||
if hasattr(self.sampler, fieldname):
|
|
||||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
|
||||||
if self.is_unipc:
|
|
||||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
|
|
||||||
|
|
||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_steps_if_invalid(self, p, num_steps):
|
|
||||||
if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'):
|
|
||||||
if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
|
|
||||||
num_steps = shared.opts.uni_pc_order
|
|
||||||
valid_step = 999 / (1000 // num_steps)
|
|
||||||
if valid_step == math.floor(valid_step):
|
|
||||||
return int(valid_step) + 1
|
|
||||||
|
|
||||||
return num_steps
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps)
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
|
|
||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
|
||||||
|
|
||||||
self.init_latent = x
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
|
|
||||||
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
|
|
||||||
else:
|
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
|
||||||
self.initialize(p)
|
|
||||||
|
|
||||||
self.init_latent = None
|
|
||||||
self.last_latent = x
|
|
||||||
self.step = 0
|
|
||||||
|
|
||||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
|
||||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
|
||||||
if image_conditioning is not None:
|
|
||||||
if self.conditioning_key == "crossattn-adm":
|
|
||||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
|
|
||||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
|
|
||||||
else:
|
|
||||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
|
||||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
|
||||||
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
|
||||||
|
|
||||||
return samples_ddim
|
|
74
modules/sd_samplers_extra.py
Normal file
74
modules/sd_samplers_extra.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import k_diffusion.sampling
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
|
||||||
|
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
|
||||||
|
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
|
||||||
|
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
|
||||||
|
"""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
step_id = 0
|
||||||
|
from k_diffusion.sampling import to_d, get_sigmas_karras
|
||||||
|
|
||||||
|
def heun_step(x, old_sigma, new_sigma, second_order=True):
|
||||||
|
nonlocal step_id
|
||||||
|
denoised = model(x, old_sigma * s_in, **extra_args)
|
||||||
|
d = to_d(x, old_sigma, denoised)
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
|
||||||
|
dt = new_sigma - old_sigma
|
||||||
|
if new_sigma == 0 or not second_order:
|
||||||
|
# Euler method
|
||||||
|
x = x + d * dt
|
||||||
|
else:
|
||||||
|
# Heun's method
|
||||||
|
x_2 = x + d * dt
|
||||||
|
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
|
||||||
|
d_2 = to_d(x_2, new_sigma, denoised_2)
|
||||||
|
d_prime = (d + d_2) / 2
|
||||||
|
x = x + d_prime * dt
|
||||||
|
step_id += 1
|
||||||
|
return x
|
||||||
|
|
||||||
|
steps = sigmas.shape[0] - 1
|
||||||
|
if restart_list is None:
|
||||||
|
if steps >= 20:
|
||||||
|
restart_steps = 9
|
||||||
|
restart_times = 1
|
||||||
|
if steps >= 36:
|
||||||
|
restart_steps = steps // 4
|
||||||
|
restart_times = 2
|
||||||
|
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
|
||||||
|
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
|
||||||
|
else:
|
||||||
|
restart_list = {}
|
||||||
|
|
||||||
|
restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
|
||||||
|
|
||||||
|
step_list = []
|
||||||
|
for i in range(len(sigmas) - 1):
|
||||||
|
step_list.append((sigmas[i], sigmas[i + 1]))
|
||||||
|
if i + 1 in restart_list:
|
||||||
|
restart_steps, restart_times, restart_max = restart_list[i + 1]
|
||||||
|
min_idx = i + 1
|
||||||
|
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
|
||||||
|
if max_idx < min_idx:
|
||||||
|
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||||
|
while restart_times > 0:
|
||||||
|
restart_times -= 1
|
||||||
|
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
||||||
|
|
||||||
|
last_sigma = None
|
||||||
|
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||||
|
if last_sigma is None:
|
||||||
|
last_sigma = old_sigma
|
||||||
|
elif last_sigma < old_sigma:
|
||||||
|
x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
|
||||||
|
x = heun_step(x, old_sigma, new_sigma)
|
||||||
|
last_sigma = new_sigma
|
||||||
|
|
||||||
|
return x
|
@ -1,47 +1,60 @@
|
|||||||
from collections import deque
|
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||||
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||||
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
|
||||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
|
||||||
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
samplers_data_k_diffusion = [
|
samplers_data_k_diffusion = [
|
||||||
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||||
for label, funcname, aliases, options in samplers_k_diffusion
|
for label, funcname, aliases, options in samplers_k_diffusion
|
||||||
if hasattr(k_diffusion.sampling, funcname)
|
if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
|
||||||
]
|
]
|
||||||
|
|
||||||
sampler_extra_params = {
|
sampler_extra_params = {
|
||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
'sample_dpm_fast': ['s_noise'],
|
||||||
|
'sample_dpm_2_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||||
}
|
}
|
||||||
|
|
||||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
@ -53,289 +66,27 @@ k_diffusion_scheduler = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def catenate_conds(conds):
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
if not isinstance(conds[0], dict):
|
@property
|
||||||
return torch.cat(conds)
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
|
||||||
|
|
||||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
def subscript_cond(cond, a, b):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
if not isinstance(cond, dict):
|
def __init__(self, funcname, sd_model, options=None):
|
||||||
return cond[a:b]
|
super().__init__(funcname)
|
||||||
|
|
||||||
return {key: vec[a:b] for key, vec in cond.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def pad_cond(tensor, repeats, empty):
|
|
||||||
if not isinstance(tensor, dict):
|
|
||||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
|
||||||
|
|
||||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiser(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
|
||||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
|
||||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
|
||||||
negative prompt.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__()
|
|
||||||
self.inner_model = model
|
|
||||||
self.mask = None
|
|
||||||
self.nmask = None
|
|
||||||
self.init_latent = None
|
|
||||||
self.step = 0
|
|
||||||
self.image_cfg_scale = None
|
|
||||||
self.padded_cond_uncond = False
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
|
||||||
denoised = torch.clone(denoised_uncond)
|
|
||||||
|
|
||||||
for i, conds in enumerate(conds_list):
|
|
||||||
for cond_index, weight in conds:
|
|
||||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
|
||||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
|
||||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
|
||||||
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
|
||||||
if state.interrupted or state.skipped:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
|
|
||||||
# so is_edit_model is set to False to support AND composition.
|
|
||||||
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
|
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
|
||||||
|
|
||||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
|
||||||
|
|
||||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
|
||||||
image_uncond = torch.zeros_like(image_cond)
|
|
||||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
|
||||||
else:
|
|
||||||
image_uncond = image_cond
|
|
||||||
if isinstance(uncond, dict):
|
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
|
||||||
else:
|
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
|
||||||
else:
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
|
||||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
|
||||||
|
|
||||||
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
|
|
||||||
cfg_denoiser_callback(denoiser_params)
|
|
||||||
x_in = denoiser_params.x
|
|
||||||
image_cond_in = denoiser_params.image_cond
|
|
||||||
sigma_in = denoiser_params.sigma
|
|
||||||
tensor = denoiser_params.text_cond
|
|
||||||
uncond = denoiser_params.text_uncond
|
|
||||||
skip_uncond = False
|
|
||||||
|
|
||||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
|
||||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
|
||||||
skip_uncond = True
|
|
||||||
x_in = x_in[:-batch_size]
|
|
||||||
sigma_in = sigma_in[:-batch_size]
|
|
||||||
|
|
||||||
self.padded_cond_uncond = False
|
|
||||||
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
|
|
||||||
empty = shared.sd_model.cond_stage_model_empty_prompt
|
|
||||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
|
||||||
|
|
||||||
if num_repeats < 0:
|
|
||||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
|
||||||
self.padded_cond_uncond = True
|
|
||||||
elif num_repeats > 0:
|
|
||||||
uncond = pad_cond(uncond, num_repeats, empty)
|
|
||||||
self.padded_cond_uncond = True
|
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
|
||||||
if is_edit_model:
|
|
||||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
|
||||||
elif skip_uncond:
|
|
||||||
cond_in = tensor
|
|
||||||
else:
|
|
||||||
cond_in = catenate_conds([tensor, uncond])
|
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = a + batch_size
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
|
||||||
else:
|
|
||||||
x_out = torch.zeros_like(x_in)
|
|
||||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
|
||||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
|
||||||
a = batch_offset
|
|
||||||
b = min(a + batch_size, tensor.shape[0])
|
|
||||||
|
|
||||||
if not is_edit_model:
|
|
||||||
c_crossattn = subscript_cond(tensor, a, b)
|
|
||||||
else:
|
|
||||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
|
||||||
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
|
||||||
|
|
||||||
if not skip_uncond:
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
|
||||||
|
|
||||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
|
||||||
if skip_uncond:
|
|
||||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
|
||||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
|
||||||
cfg_denoised_callback(denoised_params)
|
|
||||||
|
|
||||||
devices.test_for_nans(x_out, "unet")
|
|
||||||
|
|
||||||
if opts.live_preview_content == "Prompt":
|
|
||||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
|
||||||
elif opts.live_preview_content == "Negative prompt":
|
|
||||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
|
||||||
|
|
||||||
if is_edit_model:
|
|
||||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
|
||||||
elif skip_uncond:
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
|
||||||
else:
|
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
|
||||||
|
|
||||||
if self.mask is not None:
|
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
|
||||||
|
|
||||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
|
||||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
|
||||||
denoised = after_cfg_callback_params.x
|
|
||||||
|
|
||||||
self.step += 1
|
|
||||||
return denoised
|
|
||||||
|
|
||||||
|
|
||||||
class TorchHijack:
|
|
||||||
def __init__(self, sampler_noises):
|
|
||||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
|
||||||
# implementation.
|
|
||||||
self.sampler_noises = deque(sampler_noises)
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if item == 'randn_like':
|
|
||||||
return self.randn_like
|
|
||||||
|
|
||||||
if hasattr(torch, item):
|
|
||||||
return getattr(torch, item)
|
|
||||||
|
|
||||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
|
||||||
|
|
||||||
def randn_like(self, x):
|
|
||||||
if self.sampler_noises:
|
|
||||||
noise = self.sampler_noises.popleft()
|
|
||||||
if noise.shape == x.shape:
|
|
||||||
return noise
|
|
||||||
|
|
||||||
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
|
||||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
|
||||||
else:
|
|
||||||
return torch.randn_like(x)
|
|
||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler:
|
|
||||||
def __init__(self, funcname, sd_model):
|
|
||||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
|
||||||
|
|
||||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
|
||||||
self.funcname = funcname
|
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
|
||||||
self.sampler_noises = None
|
|
||||||
self.stop_at = None
|
|
||||||
self.eta = None
|
|
||||||
self.config = None # set by the function calling the constructor
|
|
||||||
self.last_latent = None
|
|
||||||
self.s_min_uncond = None
|
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
self.options = options or {}
|
||||||
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
|
||||||
def callback_state(self, d):
|
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||||
step = d['i']
|
self.model_wrap = self.model_wrap_cfg.inner_model
|
||||||
latent = d["denoised"]
|
|
||||||
if opts.live_preview_content == "Combined":
|
|
||||||
sd_samplers_common.store_latent(latent)
|
|
||||||
self.last_latent = latent
|
|
||||||
|
|
||||||
if self.stop_at is not None and step > self.stop_at:
|
|
||||||
raise sd_samplers_common.InterruptedException
|
|
||||||
|
|
||||||
state.sampling_step = step
|
|
||||||
shared.total_tqdm.update()
|
|
||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
|
||||||
state.sampling_steps = steps
|
|
||||||
state.sampling_step = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except RecursionError:
|
|
||||||
print(
|
|
||||||
'Encountered RecursionError during sampling, returning last latent. '
|
|
||||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
|
||||||
'You should try to use a smaller rho value instead.'
|
|
||||||
)
|
|
||||||
return self.last_latent
|
|
||||||
except sd_samplers_common.InterruptedException:
|
|
||||||
return self.last_latent
|
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
|
||||||
return p.steps
|
|
||||||
|
|
||||||
def initialize(self, p):
|
|
||||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
|
||||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
|
||||||
self.model_wrap_cfg.step = 0
|
|
||||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
|
||||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
|
||||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
|
||||||
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
|
||||||
|
|
||||||
extra_params_kwargs = {}
|
|
||||||
for param_name in self.extra_params:
|
|
||||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
|
||||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
|
||||||
|
|
||||||
if 'eta' in inspect.signature(self.func).parameters:
|
|
||||||
if self.eta != 1.0:
|
|
||||||
p.extra_generation_params["Eta"] = self.eta
|
|
||||||
|
|
||||||
extra_params_kwargs['eta'] = self.eta
|
|
||||||
|
|
||||||
return extra_params_kwargs
|
|
||||||
|
|
||||||
def get_sigmas(self, p, steps):
|
def get_sigmas(self, p, steps):
|
||||||
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
@ -376,6 +127,9 @@ class KDiffusionSampler:
|
|||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||||
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
||||||
|
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
|
||||||
|
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
||||||
|
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
|
||||||
else:
|
else:
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
@ -384,24 +138,21 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
return sigmas
|
return sigmas
|
||||||
|
|
||||||
def create_noise_sampler(self, x, sigmas, p):
|
|
||||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
|
||||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
|
||||||
return None
|
|
||||||
|
|
||||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
|
||||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
|
||||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
|
||||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps)
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
|
|
||||||
xi = x + noise * sigma_sched[0]
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
|
if opts.img2img_extra_noise > 0:
|
||||||
|
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||||
|
extra_noise_params = ExtraNoiseParams(noise, x, xi)
|
||||||
|
extra_noise_callback(extra_noise_params)
|
||||||
|
noise = extra_noise_params.noise
|
||||||
|
xi += noise * opts.img2img_extra_noise
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
@ -421,9 +172,12 @@ class KDiffusionSampler:
|
|||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args = {
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
@ -431,7 +185,7 @@ class KDiffusionSampler:
|
|||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
@ -443,34 +197,46 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps)
|
sigmas = self.get_sigmas(p, steps)
|
||||||
|
|
||||||
x = x * sigmas[0]
|
if opts.sgm_noise_multiplier:
|
||||||
|
p.extra_generation_params["SGM noise multiplier"] = True
|
||||||
|
x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||||
|
else:
|
||||||
|
x = x * sigmas[0]
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'n' in parameters:
|
||||||
|
extra_params_kwargs['n'] = steps
|
||||||
|
|
||||||
if 'sigma_min' in parameters:
|
if 'sigma_min' in parameters:
|
||||||
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
|
||||||
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
|
||||||
if 'n' in parameters:
|
|
||||||
extra_params_kwargs['n'] = steps
|
if 'sigmas' in parameters:
|
||||||
else:
|
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
if self.config.options.get('brownian_noise', False):
|
if self.config.options.get('brownian_noise', False):
|
||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}
|
||||||
|
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
p.extra_generation_params["Pad conds"] = True
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
167
modules/sd_samplers_timesteps.py
Normal file
167
modules/sd_samplers_timesteps.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
import torch
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
|
||||||
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
|
||||||
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
|
from modules.shared import opts
|
||||||
|
import modules.shared as shared
|
||||||
|
|
||||||
|
samplers_timesteps = [
|
||||||
|
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
|
||||||
|
('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
|
||||||
|
('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
samplers_data_timesteps = [
|
||||||
|
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
|
||||||
|
for label, funcname, aliases, options in samplers_timesteps
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisTimestepsDenoiser(torch.nn.Module):
|
||||||
|
def __init__(self, model, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, input, timesteps, **kwargs):
|
||||||
|
return self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisTimestepsVDenoiser(torch.nn.Module):
|
||||||
|
def __init__(self, model, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def predict_eps_from_z_and_v(self, x_t, t, v):
|
||||||
|
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
|
||||||
|
|
||||||
|
def forward(self, input, timesteps, **kwargs):
|
||||||
|
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
|
||||||
|
e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
|
||||||
|
class CFGDenoiserTimesteps(CFGDenoiser):
|
||||||
|
|
||||||
|
def __init__(self, sampler):
|
||||||
|
super().__init__(sampler)
|
||||||
|
|
||||||
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
|
self.mask_before_denoising = True
|
||||||
|
|
||||||
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
|
ts = sigma.to(dtype=int)
|
||||||
|
|
||||||
|
a_t = self.alphas[ts][:, None, None, None]
|
||||||
|
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||||
|
|
||||||
|
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||||
|
|
||||||
|
return pred_x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inner_model(self):
|
||||||
|
if self.model_wrap is None:
|
||||||
|
denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser
|
||||||
|
self.model_wrap = denoiser(shared.sd_model)
|
||||||
|
|
||||||
|
return self.model_wrap
|
||||||
|
|
||||||
|
|
||||||
|
class CompVisSampler(sd_samplers_common.Sampler):
|
||||||
|
def __init__(self, funcname, sd_model):
|
||||||
|
super().__init__(funcname)
|
||||||
|
|
||||||
|
self.eta_option_field = 'eta_ddim'
|
||||||
|
self.eta_infotext_field = 'Eta DDIM'
|
||||||
|
self.eta_default = 0.0
|
||||||
|
|
||||||
|
self.model_wrap_cfg = CFGDenoiserTimesteps(self)
|
||||||
|
|
||||||
|
def get_timesteps(self, p, steps):
|
||||||
|
discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
|
||||||
|
if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
|
||||||
|
discard_next_to_last_sigma = True
|
||||||
|
p.extra_generation_params["Discard penultimate sigma"] = True
|
||||||
|
|
||||||
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
|
timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)
|
||||||
|
|
||||||
|
return timesteps
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
timesteps = self.get_timesteps(p, steps)
|
||||||
|
timesteps_sched = timesteps[:t_enc]
|
||||||
|
|
||||||
|
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||||
|
sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
|
||||||
|
sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])
|
||||||
|
|
||||||
|
xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod
|
||||||
|
|
||||||
|
if opts.img2img_extra_noise > 0:
|
||||||
|
p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
|
||||||
|
extra_noise_params = ExtraNoiseParams(noise, x, xi)
|
||||||
|
extra_noise_callback(extra_noise_params)
|
||||||
|
noise = extra_noise_params.noise
|
||||||
|
xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'timesteps' in parameters:
|
||||||
|
extra_params_kwargs['timesteps'] = timesteps_sched
|
||||||
|
if 'is_img2img' in parameters:
|
||||||
|
extra_params_kwargs['is_img2img'] = True
|
||||||
|
|
||||||
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
|
self.sampler_extra_args = {
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
|
}
|
||||||
|
|
||||||
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
steps = steps or p.steps
|
||||||
|
timesteps = self.get_timesteps(p, steps)
|
||||||
|
|
||||||
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
|
if 'timesteps' in parameters:
|
||||||
|
extra_params_kwargs['timesteps'] = timesteps
|
||||||
|
|
||||||
|
self.last_latent = x
|
||||||
|
self.sampler_extra_args = {
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
|
}
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
p.extra_generation_params["Pad conds"] = True
|
||||||
|
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
|
||||||
|
VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
|
137
modules/sd_samplers_timesteps_impl.py
Normal file
137
modules/sd_samplers_timesteps_impl.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import k_diffusion.sampling
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.models.diffusion.uni_pc import uni_pc
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
alphas = alphas_cumprod[timesteps]
|
||||||
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones((x.shape[0]))
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
|
index = len(timesteps) - 1 - i
|
||||||
|
|
||||||
|
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||||
|
|
||||||
|
a_t = alphas[index].item() * s_x
|
||||||
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
|
sigma_t = sigmas[index].item() * s_x
|
||||||
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
|
||||||
|
x = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
alphas = alphas_cumprod[timesteps]
|
||||||
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = alphas[index].item() * s_x
|
||||||
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev).sqrt() * e_t
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
|
index = len(timesteps) - 1 - i
|
||||||
|
ts = timesteps[index].item() * s_in
|
||||||
|
t_next = timesteps[max(index - 1, 0)].item() * s_in
|
||||||
|
|
||||||
|
e_t = model(x, ts, **extra_args)
|
||||||
|
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = model(x_prev, t_next, **extra_args)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
else:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
|
||||||
|
x = x_prev
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class UniPCCFG(uni_pc.UniPC):
|
||||||
|
def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):
|
||||||
|
super().__init__(None, *args, **kwargs)
|
||||||
|
|
||||||
|
def after_update(x, model_x):
|
||||||
|
callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})
|
||||||
|
self.index += 1
|
||||||
|
|
||||||
|
self.cfg_model = cfg_model
|
||||||
|
self.extra_args = extra_args
|
||||||
|
self.callback = callback
|
||||||
|
self.index = 0
|
||||||
|
self.after_update = after_update
|
||||||
|
|
||||||
|
def get_model_input_time(self, t_continuous):
|
||||||
|
return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.
|
||||||
|
|
||||||
|
def model(self, x, t):
|
||||||
|
t_input = self.get_model_input_time(t)
|
||||||
|
|
||||||
|
res = self.cfg_model(x, t_input, **self.extra_args)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):
|
||||||
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
|
|
||||||
|
ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
||||||
|
t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means
|
||||||
|
unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)
|
||||||
|
x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)
|
||||||
|
|
||||||
|
return x
|
@ -47,7 +47,7 @@ def apply_unet(option=None):
|
|||||||
if current_unet_option is None:
|
if current_unet_option is None:
|
||||||
current_unet = None
|
current_unet = None
|
||||||
|
|
||||||
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
if not shared.sd_model.lowvram:
|
||||||
shared.sd_model.model.diffusion_model.to(devices.device)
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
|
||||||
|
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
@ -16,6 +19,23 @@ checkpoint_info = None
|
|||||||
|
|
||||||
checkpoints_loaded = collections.OrderedDict()
|
checkpoints_loaded = collections.OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaded_vae_name():
|
||||||
|
if loaded_vae_file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return os.path.basename(loaded_vae_file)
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaded_vae_hash():
|
||||||
|
if loaded_vae_file is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sha256 = hashes.sha256(loaded_vae_file, 'vae')
|
||||||
|
|
||||||
|
return sha256[0:10] if sha256 else None
|
||||||
|
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
|
||||||
return base_vae
|
return base_vae
|
||||||
@ -83,6 +103,8 @@ def refresh_vae_list():
|
|||||||
name = get_filename(filepath)
|
name = get_filename(filepath)
|
||||||
vae_dict[name] = filepath
|
vae_dict[name] = filepath
|
||||||
|
|
||||||
|
vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))))
|
||||||
|
|
||||||
|
|
||||||
def find_vae_near_checkpoint(checkpoint_file):
|
def find_vae_near_checkpoint(checkpoint_file):
|
||||||
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
||||||
@ -93,27 +115,74 @@ def find_vae_near_checkpoint(checkpoint_file):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def resolve_vae(checkpoint_file):
|
@dataclass
|
||||||
if shared.cmd_opts.vae_path is not None:
|
class VaeResolution:
|
||||||
return shared.cmd_opts.vae_path, 'from commandline argument'
|
vae: str = None
|
||||||
|
source: str = None
|
||||||
|
resolved: bool = True
|
||||||
|
|
||||||
is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
def tuple(self):
|
||||||
|
return self.vae, self.source
|
||||||
|
|
||||||
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
|
||||||
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
|
|
||||||
return vae_near_checkpoint, 'found near the checkpoint'
|
|
||||||
|
|
||||||
|
def is_automatic():
|
||||||
|
return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_from_setting() -> VaeResolution:
|
||||||
if shared.opts.sd_vae == "None":
|
if shared.opts.sd_vae == "None":
|
||||||
return None, None
|
return VaeResolution()
|
||||||
|
|
||||||
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
|
||||||
if vae_from_options is not None:
|
if vae_from_options is not None:
|
||||||
return vae_from_options, 'specified in settings'
|
return VaeResolution(vae_from_options, 'specified in settings')
|
||||||
|
|
||||||
if not is_automatic:
|
if not is_automatic():
|
||||||
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
|
||||||
|
|
||||||
return None, None
|
return VaeResolution(resolved=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
|
||||||
|
metadata = extra_networks.get_user_metadata(checkpoint_file)
|
||||||
|
vae_metadata = metadata.get("vae", None)
|
||||||
|
if vae_metadata is not None and vae_metadata != "Automatic":
|
||||||
|
if vae_metadata == "None":
|
||||||
|
return VaeResolution()
|
||||||
|
|
||||||
|
vae_from_metadata = vae_dict.get(vae_metadata, None)
|
||||||
|
if vae_from_metadata is not None:
|
||||||
|
return VaeResolution(vae_from_metadata, "from user metadata")
|
||||||
|
|
||||||
|
return VaeResolution(resolved=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
|
||||||
|
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
|
||||||
|
if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()):
|
||||||
|
return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
|
||||||
|
|
||||||
|
return VaeResolution(resolved=False)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_vae(checkpoint_file) -> VaeResolution:
|
||||||
|
if shared.cmd_opts.vae_path is not None:
|
||||||
|
return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
|
||||||
|
|
||||||
|
if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
|
||||||
|
return resolve_vae_from_setting()
|
||||||
|
|
||||||
|
res = resolve_vae_from_user_metadata(checkpoint_file)
|
||||||
|
if res.resolved:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res = resolve_vae_near_checkpoint(checkpoint_file)
|
||||||
|
if res.resolved:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res = resolve_vae_from_setting()
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def load_vae_dict(filename, map_location):
|
def load_vae_dict(filename, map_location):
|
||||||
@ -123,7 +192,7 @@ def load_vae_dict(filename, map_location):
|
|||||||
|
|
||||||
|
|
||||||
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
||||||
global vae_dict, loaded_vae_file
|
global vae_dict, base_vae, loaded_vae_file
|
||||||
# save_settings = False
|
# save_settings = False
|
||||||
|
|
||||||
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
|
||||||
@ -161,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
|
|||||||
restore_base_vae(model)
|
restore_base_vae(model)
|
||||||
|
|
||||||
loaded_vae_file = vae_file
|
loaded_vae_file = vae_file
|
||||||
|
model.base_vae = base_vae
|
||||||
|
model.loaded_vae_file = loaded_vae_file
|
||||||
|
|
||||||
|
|
||||||
# don't call this from outside
|
# don't call this from outside
|
||||||
@ -178,8 +249,6 @@ unspecified = object()
|
|||||||
|
|
||||||
|
|
||||||
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
||||||
from modules import lowvram, devices, sd_hijack
|
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = shared.sd_model
|
||||||
|
|
||||||
@ -187,14 +256,14 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
|
|
||||||
if vae_file == unspecified:
|
if vae_file == unspecified:
|
||||||
vae_file, vae_source = resolve_vae(checkpoint_file)
|
vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
|
||||||
else:
|
else:
|
||||||
vae_source = "from function argument"
|
vae_source = "from function argument"
|
||||||
|
|
||||||
if loaded_vae_file == vae_file:
|
if loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if sd_model.lowvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
sd_model.to(devices.cpu)
|
sd_model.to(devices.cpu)
|
||||||
@ -206,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
print("VAE weights loaded.")
|
print("VAE weights loaded.")
|
||||||
|
@ -81,6 +81,6 @@ def cheap_approximation(sample):
|
|||||||
|
|
||||||
coefs = torch.tensor(coeffs).to(sample.device)
|
coefs = torch.tensor(coeffs).to(sample.device)
|
||||||
|
|
||||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs)
|
||||||
|
|
||||||
return x_sample
|
return x_sample
|
||||||
|
@ -44,7 +44,17 @@ def decoder():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TAESD(nn.Module):
|
def encoder():
|
||||||
|
return nn.Sequential(
|
||||||
|
conv(3, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||||
|
conv(64, 4),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TAESDDecoder(nn.Module):
|
||||||
latent_magnitude = 3
|
latent_magnitude = 3
|
||||||
latent_shift = 0.5
|
latent_shift = 0.5
|
||||||
|
|
||||||
@ -55,21 +65,28 @@ class TAESD(nn.Module):
|
|||||||
self.decoder.load_state_dict(
|
self.decoder.load_state_dict(
|
||||||
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def unscale_latents(x):
|
class TAESDEncoder(nn.Module):
|
||||||
"""[0, 1] -> raw latents"""
|
latent_magnitude = 3
|
||||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
latent_shift = 0.5
|
||||||
|
|
||||||
|
def __init__(self, encoder_path="taesd_encoder.pth"):
|
||||||
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder()
|
||||||
|
self.encoder.load_state_dict(
|
||||||
|
torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||||
|
|
||||||
|
|
||||||
def download_model(model_path, model_url):
|
def download_model(model_path, model_url):
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||||
|
|
||||||
print(f'Downloading TAESD decoder to: {model_path}')
|
print(f'Downloading TAESD model to: {model_path}')
|
||||||
torch.hub.download_url_to_file(model_url, model_path)
|
torch.hub.download_url_to_file(model_url, model_path)
|
||||||
|
|
||||||
|
|
||||||
def model():
|
def decoder_model():
|
||||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||||
|
|
||||||
@ -78,7 +95,7 @@ def model():
|
|||||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||||
|
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
loaded_model = TAESD(model_path)
|
loaded_model = TAESDDecoder(model_path)
|
||||||
loaded_model.eval()
|
loaded_model.eval()
|
||||||
loaded_model.to(devices.device, devices.dtype)
|
loaded_model.to(devices.device, devices.dtype)
|
||||||
sd_vae_taesd_models[model_name] = loaded_model
|
sd_vae_taesd_models[model_name] = loaded_model
|
||||||
@ -86,3 +103,22 @@ def model():
|
|||||||
raise FileNotFoundError('TAESD model not found')
|
raise FileNotFoundError('TAESD model not found')
|
||||||
|
|
||||||
return loaded_model.decoder
|
return loaded_model.decoder
|
||||||
|
|
||||||
|
|
||||||
|
def encoder_model():
|
||||||
|
model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
|
||||||
|
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||||
|
|
||||||
|
if loaded_model is None:
|
||||||
|
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||||
|
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||||
|
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
loaded_model = TAESDEncoder(model_path)
|
||||||
|
loaded_model.eval()
|
||||||
|
loaded_model.to(devices.device, devices.dtype)
|
||||||
|
sd_vae_taesd_models[model_name] = loaded_model
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError('TAESD model not found')
|
||||||
|
|
||||||
|
return loaded_model.encoder
|
||||||
|
@ -1,771 +1,51 @@
|
|||||||
import datetime
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
import launch
|
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
||||||
import modules.interrogate
|
|
||||||
import modules.memmon
|
|
||||||
import modules.styles
|
|
||||||
import modules.devices as devices
|
|
||||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from modules import util
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
cmd_opts = shared_cmd_options.cmd_opts
|
||||||
|
parser = shared_cmd_options.parser
|
||||||
|
|
||||||
|
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||||
|
parallel_processing_allowed = True
|
||||||
|
styles_filename = cmd_opts.styles_file
|
||||||
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
parser = cmd_args.parser
|
device = None
|
||||||
|
|
||||||
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
weight_load_location = None
|
||||||
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
|
||||||
|
|
||||||
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
|
||||||
cmd_opts = parser.parse_args()
|
|
||||||
else:
|
|
||||||
cmd_opts, _ = parser.parse_known_args()
|
|
||||||
|
|
||||||
|
|
||||||
restricted_opts = {
|
|
||||||
"samples_filename_pattern",
|
|
||||||
"directories_filename_pattern",
|
|
||||||
"outdir_samples",
|
|
||||||
"outdir_txt2img_samples",
|
|
||||||
"outdir_img2img_samples",
|
|
||||||
"outdir_extras_samples",
|
|
||||||
"outdir_grids",
|
|
||||||
"outdir_txt2img_grids",
|
|
||||||
"outdir_save",
|
|
||||||
"outdir_init_images"
|
|
||||||
}
|
|
||||||
|
|
||||||
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
|
||||||
gradio_hf_hub_themes = [
|
|
||||||
"gradio/glass",
|
|
||||||
"gradio/monochrome",
|
|
||||||
"gradio/seafoam",
|
|
||||||
"gradio/soft",
|
|
||||||
"freddyaboulton/dracula_revamped",
|
|
||||||
"gradio/dracula_test",
|
|
||||||
"abidlabs/dracula_test",
|
|
||||||
"abidlabs/pakistan",
|
|
||||||
"dawood/microsoft_windows",
|
|
||||||
"ysharma/steampunk"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
|
||||||
|
|
||||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
|
||||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
|
||||||
|
|
||||||
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
|
||||||
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
|
||||||
|
|
||||||
device = devices.device
|
|
||||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
|
||||||
xformers_available = False
|
xformers_available = False
|
||||||
config_filename = cmd_opts.ui_settings_file
|
|
||||||
|
|
||||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
|
||||||
hypernetworks = {}
|
hypernetworks = {}
|
||||||
|
|
||||||
loaded_hypernetworks = []
|
loaded_hypernetworks = []
|
||||||
|
|
||||||
|
state = None
|
||||||
|
|
||||||
def reload_hypernetworks():
|
prompt_styles = None
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
global hypernetworks
|
|
||||||
|
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
interrogator = None
|
||||||
|
|
||||||
|
|
||||||
class State:
|
|
||||||
skipped = False
|
|
||||||
interrupted = False
|
|
||||||
job = ""
|
|
||||||
job_no = 0
|
|
||||||
job_count = 0
|
|
||||||
processing_has_refined_job_count = False
|
|
||||||
job_timestamp = '0'
|
|
||||||
sampling_step = 0
|
|
||||||
sampling_steps = 0
|
|
||||||
current_latent = None
|
|
||||||
current_image = None
|
|
||||||
current_image_sampling_step = 0
|
|
||||||
id_live_preview = 0
|
|
||||||
textinfo = None
|
|
||||||
time_start = None
|
|
||||||
server_start = None
|
|
||||||
_server_command_signal = threading.Event()
|
|
||||||
_server_command: Optional[str] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def need_restart(self) -> bool:
|
|
||||||
# Compatibility getter for need_restart.
|
|
||||||
return self.server_command == "restart"
|
|
||||||
|
|
||||||
@need_restart.setter
|
|
||||||
def need_restart(self, value: bool) -> None:
|
|
||||||
# Compatibility setter for need_restart.
|
|
||||||
if value:
|
|
||||||
self.server_command = "restart"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def server_command(self):
|
|
||||||
return self._server_command
|
|
||||||
|
|
||||||
@server_command.setter
|
|
||||||
def server_command(self, value: Optional[str]) -> None:
|
|
||||||
"""
|
|
||||||
Set the server command to `value` and signal that it's been set.
|
|
||||||
"""
|
|
||||||
self._server_command = value
|
|
||||||
self._server_command_signal.set()
|
|
||||||
|
|
||||||
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Wait for server command to get set; return and clear the value and signal.
|
|
||||||
"""
|
|
||||||
if self._server_command_signal.wait(timeout):
|
|
||||||
self._server_command_signal.clear()
|
|
||||||
req = self._server_command
|
|
||||||
self._server_command = None
|
|
||||||
return req
|
|
||||||
return None
|
|
||||||
|
|
||||||
def request_restart(self) -> None:
|
|
||||||
self.interrupt()
|
|
||||||
self.server_command = "restart"
|
|
||||||
log.info("Received restart request")
|
|
||||||
|
|
||||||
def skip(self):
|
|
||||||
self.skipped = True
|
|
||||||
log.info("Received skip request")
|
|
||||||
|
|
||||||
def interrupt(self):
|
|
||||||
self.interrupted = True
|
|
||||||
log.info("Received interrupt request")
|
|
||||||
|
|
||||||
def nextjob(self):
|
|
||||||
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
|
|
||||||
self.do_set_current_image()
|
|
||||||
|
|
||||||
self.job_no += 1
|
|
||||||
self.sampling_step = 0
|
|
||||||
self.current_image_sampling_step = 0
|
|
||||||
|
|
||||||
def dict(self):
|
|
||||||
obj = {
|
|
||||||
"skipped": self.skipped,
|
|
||||||
"interrupted": self.interrupted,
|
|
||||||
"job": self.job,
|
|
||||||
"job_count": self.job_count,
|
|
||||||
"job_timestamp": self.job_timestamp,
|
|
||||||
"job_no": self.job_no,
|
|
||||||
"sampling_step": self.sampling_step,
|
|
||||||
"sampling_steps": self.sampling_steps,
|
|
||||||
}
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def begin(self, job: str = "(unknown)"):
|
|
||||||
self.sampling_step = 0
|
|
||||||
self.job_count = -1
|
|
||||||
self.processing_has_refined_job_count = False
|
|
||||||
self.job_no = 0
|
|
||||||
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
self.current_latent = None
|
|
||||||
self.current_image = None
|
|
||||||
self.current_image_sampling_step = 0
|
|
||||||
self.id_live_preview = 0
|
|
||||||
self.skipped = False
|
|
||||||
self.interrupted = False
|
|
||||||
self.textinfo = None
|
|
||||||
self.time_start = time.time()
|
|
||||||
self.job = job
|
|
||||||
devices.torch_gc()
|
|
||||||
log.info("Starting job %s", job)
|
|
||||||
|
|
||||||
def end(self):
|
|
||||||
duration = time.time() - self.time_start
|
|
||||||
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
|
||||||
self.job = ""
|
|
||||||
self.job_count = 0
|
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
def set_current_image(self):
|
|
||||||
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
|
|
||||||
if not parallel_processing_allowed:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1:
|
|
||||||
self.do_set_current_image()
|
|
||||||
|
|
||||||
def do_set_current_image(self):
|
|
||||||
if self.current_latent is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
import modules.sd_samplers
|
|
||||||
if opts.show_progress_grid:
|
|
||||||
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
|
||||||
else:
|
|
||||||
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
|
||||||
|
|
||||||
self.current_image_sampling_step = self.sampling_step
|
|
||||||
|
|
||||||
def assign_current_image(self, image):
|
|
||||||
self.current_image = image
|
|
||||||
self.id_live_preview += 1
|
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
|
||||||
state.server_start = time.time()
|
|
||||||
|
|
||||||
styles_filename = cmd_opts.styles_file
|
|
||||||
prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|
||||||
|
|
||||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
options_templates = None
|
||||||
|
opts = None
|
||||||
|
restricted_opts = None
|
||||||
|
|
||||||
class OptionInfo:
|
sd_model: sd_models_types.WebuiSdModel = None
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
|
||||||
self.default = default
|
|
||||||
self.label = label
|
|
||||||
self.component = component
|
|
||||||
self.component_args = component_args
|
|
||||||
self.onchange = onchange
|
|
||||||
self.section = section
|
|
||||||
self.refresh = refresh
|
|
||||||
|
|
||||||
self.comment_before = comment_before
|
|
||||||
"""HTML text that will be added after label in UI"""
|
|
||||||
|
|
||||||
self.comment_after = comment_after
|
|
||||||
"""HTML text that will be added before label in UI"""
|
|
||||||
|
|
||||||
def link(self, label, url):
|
|
||||||
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def js(self, label, js_func):
|
|
||||||
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def info(self, info):
|
|
||||||
self.comment_after += f"<span class='info'>({info})</span>"
|
|
||||||
return self
|
|
||||||
|
|
||||||
def html(self, html):
|
|
||||||
self.comment_after += html
|
|
||||||
return self
|
|
||||||
|
|
||||||
def needs_restart(self):
|
|
||||||
self.comment_after += " <span class='info'>(requires restart)</span>"
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
|
||||||
for v in options_dict.values():
|
|
||||||
v.section = section_identifier
|
|
||||||
|
|
||||||
return options_dict
|
|
||||||
|
|
||||||
|
|
||||||
def list_checkpoint_tiles():
|
|
||||||
import modules.sd_models
|
|
||||||
return modules.sd_models.checkpoint_tiles()
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_checkpoints():
|
|
||||||
import modules.sd_models
|
|
||||||
return modules.sd_models.list_models()
|
|
||||||
|
|
||||||
|
|
||||||
def list_samplers():
|
|
||||||
import modules.sd_samplers
|
|
||||||
return modules.sd_samplers.all_samplers
|
|
||||||
|
|
||||||
|
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
|
||||||
tab_names = []
|
|
||||||
|
|
||||||
options_templates = {}
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
|
||||||
|
|
||||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
|
||||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
|
||||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
|
||||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
|
||||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
|
||||||
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
|
||||||
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
|
||||||
|
|
||||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
|
||||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
|
||||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
|
||||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
|
||||||
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
|
||||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
|
||||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
|
||||||
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
|
||||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
|
||||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
|
||||||
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
|
||||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
|
||||||
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
|
||||||
|
|
||||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
|
||||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
|
||||||
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|
||||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
|
||||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
|
||||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
|
||||||
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
|
||||||
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
|
||||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
|
||||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
|
||||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
|
||||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
|
||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
|
||||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
|
||||||
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
|
||||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
|
||||||
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
|
||||||
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
|
||||||
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
|
||||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
|
||||||
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
|
||||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
|
||||||
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
|
||||||
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
|
||||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
|
||||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
|
||||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
|
||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
|
||||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
|
||||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
|
||||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
|
||||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
|
||||||
"experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
|
||||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
|
||||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
|
||||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
|
||||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
|
||||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
|
||||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
|
||||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
|
||||||
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
|
||||||
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
|
||||||
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
|
||||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|
||||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
|
||||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
|
||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
|
||||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
|
||||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
|
||||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
|
||||||
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
|
||||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
|
||||||
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
|
||||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
|
||||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
|
||||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
|
||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
|
||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
|
||||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
|
|
||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
|
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
|
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
|
||||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
|
||||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
|
||||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
|
||||||
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
|
||||||
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
|
||||||
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
|
||||||
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
|
||||||
</ul>"""),
|
|
||||||
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
|
||||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
|
||||||
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
|
||||||
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
|
||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
|
||||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
|
||||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
|
||||||
'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"),
|
|
||||||
'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"),
|
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
|
||||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
|
||||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
|
||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
|
||||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
|
||||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
|
||||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
|
||||||
}))
|
|
||||||
|
|
||||||
options_templates.update(options_section((None, "Hidden options"), {
|
|
||||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
|
||||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
|
||||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
|
||||||
}))
|
|
||||||
|
|
||||||
|
|
||||||
options_templates.update()
|
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
|
||||||
data = None
|
|
||||||
data_labels = options_templates
|
|
||||||
typemap = {int: float}
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
|
||||||
if self.data is not None:
|
|
||||||
if key in self.data or key in self.data_labels:
|
|
||||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
|
||||||
|
|
||||||
info = opts.data_labels.get(key, None)
|
|
||||||
comp_args = info.component_args if info else None
|
|
||||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
|
||||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
|
||||||
|
|
||||||
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
|
||||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
|
||||||
|
|
||||||
self.data[key] = value
|
|
||||||
return
|
|
||||||
|
|
||||||
return super(Options, self).__setattr__(key, value)
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
if self.data is not None:
|
|
||||||
if item in self.data:
|
|
||||||
return self.data[item]
|
|
||||||
|
|
||||||
if item in self.data_labels:
|
|
||||||
return self.data_labels[item].default
|
|
||||||
|
|
||||||
return super(Options, self).__getattribute__(item)
|
|
||||||
|
|
||||||
def set(self, key, value):
|
|
||||||
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
|
|
||||||
|
|
||||||
oldval = self.data.get(key, None)
|
|
||||||
if oldval == value:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
setattr(self, key, value)
|
|
||||||
except RuntimeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.data_labels[key].onchange is not None:
|
|
||||||
try:
|
|
||||||
self.data_labels[key].onchange()
|
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, f"changing setting {key} to {value}")
|
|
||||||
setattr(self, key, oldval)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_default(self, key):
|
|
||||||
"""returns the default value for the key"""
|
|
||||||
|
|
||||||
data_label = self.data_labels.get(key)
|
|
||||||
if data_label is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return data_label.default
|
|
||||||
|
|
||||||
def save(self, filename):
|
|
||||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
|
||||||
json.dump(self.data, file, indent=4)
|
|
||||||
|
|
||||||
def same_type(self, x, y):
|
|
||||||
if x is None or y is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
type_x = self.typemap.get(type(x), type(x))
|
|
||||||
type_y = self.typemap.get(type(y), type(y))
|
|
||||||
|
|
||||||
return type_x == type_y
|
|
||||||
|
|
||||||
def load(self, filename):
|
|
||||||
with open(filename, "r", encoding="utf8") as file:
|
|
||||||
self.data = json.load(file)
|
|
||||||
|
|
||||||
# 1.1.1 quicksettings list migration
|
|
||||||
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
|
|
||||||
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
|
|
||||||
|
|
||||||
# 1.4.0 ui_reorder
|
|
||||||
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
|
|
||||||
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
|
|
||||||
|
|
||||||
bad_settings = 0
|
|
||||||
for k, v in self.data.items():
|
|
||||||
info = self.data_labels.get(k, None)
|
|
||||||
if info is not None and not self.same_type(info.default, v):
|
|
||||||
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
|
|
||||||
bad_settings += 1
|
|
||||||
|
|
||||||
if bad_settings > 0:
|
|
||||||
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
|
|
||||||
|
|
||||||
def onchange(self, key, func, call=True):
|
|
||||||
item = self.data_labels.get(key)
|
|
||||||
item.onchange = func
|
|
||||||
|
|
||||||
if call:
|
|
||||||
func()
|
|
||||||
|
|
||||||
def dumpjson(self):
|
|
||||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
|
||||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
|
||||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
|
||||||
return json.dumps(d)
|
|
||||||
|
|
||||||
def add_option(self, key, info):
|
|
||||||
self.data_labels[key] = info
|
|
||||||
|
|
||||||
def reorder(self):
|
|
||||||
"""reorder settings so that all items related to section always go together"""
|
|
||||||
|
|
||||||
section_ids = {}
|
|
||||||
settings_items = self.data_labels.items()
|
|
||||||
for _, item in settings_items:
|
|
||||||
if item.section not in section_ids:
|
|
||||||
section_ids[item.section] = len(section_ids)
|
|
||||||
|
|
||||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
|
||||||
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
default_value = self.data_labels[key].default
|
|
||||||
if default_value is None:
|
|
||||||
default_value = getattr(self, key, None)
|
|
||||||
if default_value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
expected_type = type(default_value)
|
|
||||||
if expected_type == bool and value == "False":
|
|
||||||
value = False
|
|
||||||
else:
|
|
||||||
value = expected_type(value)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
|
||||||
if os.path.exists(config_filename):
|
|
||||||
opts.load(config_filename)
|
|
||||||
|
|
||||||
|
|
||||||
class Shared(sys.modules[__name__].__class__):
|
|
||||||
"""
|
|
||||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
|
||||||
at program startup.
|
|
||||||
"""
|
|
||||||
|
|
||||||
sd_model_val = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sd_model(self):
|
|
||||||
import modules.sd_models
|
|
||||||
|
|
||||||
return modules.sd_models.model_data.get_sd_model()
|
|
||||||
|
|
||||||
@sd_model.setter
|
|
||||||
def sd_model(self, value):
|
|
||||||
import modules.sd_models
|
|
||||||
|
|
||||||
modules.sd_models.model_data.set_sd_model(value)
|
|
||||||
|
|
||||||
|
|
||||||
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
|
||||||
sys.modules[__name__].__class__ = Shared
|
|
||||||
|
|
||||||
settings_components = None
|
settings_components = None
|
||||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||||
|
|
||||||
|
tab_names = []
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
"Latent": {"mode": "bilinear", "antialias": False},
|
"Latent": {"mode": "bilinear", "antialias": False},
|
||||||
@ -784,108 +64,24 @@ progress_print_out = sys.stdout
|
|||||||
|
|
||||||
gradio_theme = gr.themes.Base()
|
gradio_theme = gr.themes.Base()
|
||||||
|
|
||||||
|
total_tqdm = None
|
||||||
|
|
||||||
def reload_gradio_theme(theme_name=None):
|
mem_mon = None
|
||||||
global gradio_theme
|
|
||||||
if not theme_name:
|
|
||||||
theme_name = opts.gradio_theme
|
|
||||||
|
|
||||||
default_theme_args = dict(
|
options_section = options.options_section
|
||||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
OptionInfo = options.OptionInfo
|
||||||
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
OptionHTML = options.OptionHTML
|
||||||
)
|
|
||||||
|
|
||||||
if theme_name == "Default":
|
natural_sort_key = util.natural_sort_key
|
||||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
listfiles = util.listfiles
|
||||||
else:
|
html_path = util.html_path
|
||||||
try:
|
html = util.html
|
||||||
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
walk_files = util.walk_files
|
||||||
except Exception as e:
|
ldm_print = util.ldm_print
|
||||||
errors.display(e, "changing gradio theme")
|
|
||||||
gradio_theme = gr.themes.Default(**default_theme_args)
|
|
||||||
|
|
||||||
|
reload_gradio_theme = shared_gradio_themes.reload_gradio_theme
|
||||||
|
|
||||||
|
list_checkpoint_tiles = shared_items.list_checkpoint_tiles
|
||||||
class TotalTQDM:
|
refresh_checkpoints = shared_items.refresh_checkpoints
|
||||||
def __init__(self):
|
list_samplers = shared_items.list_samplers
|
||||||
self._tqdm = None
|
reload_hypernetworks = shared_items.reload_hypernetworks
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self._tqdm = tqdm.tqdm(
|
|
||||||
desc="Total progress",
|
|
||||||
total=state.job_count * state.sampling_steps,
|
|
||||||
position=1,
|
|
||||||
file=progress_print_out
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self):
|
|
||||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
|
||||||
return
|
|
||||||
if self._tqdm is None:
|
|
||||||
self.reset()
|
|
||||||
self._tqdm.update()
|
|
||||||
|
|
||||||
def updateTotal(self, new_total):
|
|
||||||
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
|
||||||
return
|
|
||||||
if self._tqdm is None:
|
|
||||||
self.reset()
|
|
||||||
self._tqdm.total = new_total
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
if self._tqdm is not None:
|
|
||||||
self._tqdm.refresh()
|
|
||||||
self._tqdm.close()
|
|
||||||
self._tqdm = None
|
|
||||||
|
|
||||||
|
|
||||||
total_tqdm = TotalTQDM()
|
|
||||||
|
|
||||||
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
|
|
||||||
mem_mon.start()
|
|
||||||
|
|
||||||
|
|
||||||
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
|
|
||||||
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
|
|
||||||
|
|
||||||
|
|
||||||
def listfiles(dirname):
|
|
||||||
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
|
|
||||||
return [file for file in filenames if os.path.isfile(file)]
|
|
||||||
|
|
||||||
|
|
||||||
def html_path(filename):
|
|
||||||
return os.path.join(script_path, "html", filename)
|
|
||||||
|
|
||||||
|
|
||||||
def html(filename):
|
|
||||||
path = html_path(filename)
|
|
||||||
|
|
||||||
if os.path.exists(path):
|
|
||||||
with open(path, encoding="utf8") as file:
|
|
||||||
return file.read()
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def walk_files(path, allowed_extensions=None):
|
|
||||||
if not os.path.exists(path):
|
|
||||||
return
|
|
||||||
|
|
||||||
if allowed_extensions is not None:
|
|
||||||
allowed_extensions = set(allowed_extensions)
|
|
||||||
|
|
||||||
items = list(os.walk(path, followlinks=True))
|
|
||||||
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
|
|
||||||
|
|
||||||
for root, _, files in items:
|
|
||||||
for filename in sorted(files, key=natural_sort_key):
|
|
||||||
if allowed_extensions is not None:
|
|
||||||
_, ext = os.path.splitext(filename)
|
|
||||||
if ext not in allowed_extensions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not opts.list_hidden_files and ("/." in root or "\\." in root):
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield os.path.join(root, filename)
|
|
||||||
|
18
modules/shared_cmd_options.py
Normal file
18
modules/shared_cmd_options.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import launch
|
||||||
|
from modules import cmd_args, script_loading
|
||||||
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
|
||||||
|
parser = cmd_args.parser
|
||||||
|
|
||||||
|
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
|
||||||
|
script_loading.preload_extensions(extensions_builtin_dir, parser)
|
||||||
|
|
||||||
|
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||||
|
cmd_opts = parser.parse_args()
|
||||||
|
else:
|
||||||
|
cmd_opts, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
|
||||||
|
cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
|
67
modules/shared_gradio_themes.py
Normal file
67
modules/shared_gradio_themes.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import errors, shared
|
||||||
|
from modules.paths_internal import script_path
|
||||||
|
|
||||||
|
|
||||||
|
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||||
|
gradio_hf_hub_themes = [
|
||||||
|
"gradio/base",
|
||||||
|
"gradio/glass",
|
||||||
|
"gradio/monochrome",
|
||||||
|
"gradio/seafoam",
|
||||||
|
"gradio/soft",
|
||||||
|
"gradio/dracula_test",
|
||||||
|
"abidlabs/dracula_test",
|
||||||
|
"abidlabs/Lime",
|
||||||
|
"abidlabs/pakistan",
|
||||||
|
"Ama434/neutral-barlow",
|
||||||
|
"dawood/microsoft_windows",
|
||||||
|
"finlaymacklon/smooth_slate",
|
||||||
|
"Franklisi/darkmode",
|
||||||
|
"freddyaboulton/dracula_revamped",
|
||||||
|
"freddyaboulton/test-blue",
|
||||||
|
"gstaff/xkcd",
|
||||||
|
"Insuz/Mocha",
|
||||||
|
"Insuz/SimpleIndigo",
|
||||||
|
"JohnSmith9982/small_and_pretty",
|
||||||
|
"nota-ai/theme",
|
||||||
|
"nuttea/Softblue",
|
||||||
|
"ParityError/Anime",
|
||||||
|
"reilnuud/polite",
|
||||||
|
"remilia/Ghostly",
|
||||||
|
"rottenlittlecreature/Moon_Goblin",
|
||||||
|
"step-3-profit/Midnight-Deep",
|
||||||
|
"Taithrah/Minimal",
|
||||||
|
"ysharma/huggingface",
|
||||||
|
"ysharma/steampunk",
|
||||||
|
"NoCrypt/miku"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def reload_gradio_theme(theme_name=None):
|
||||||
|
if not theme_name:
|
||||||
|
theme_name = shared.opts.gradio_theme
|
||||||
|
|
||||||
|
default_theme_args = dict(
|
||||||
|
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||||
|
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if theme_name == "Default":
|
||||||
|
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
|
||||||
|
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
|
||||||
|
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
|
||||||
|
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
|
||||||
|
else:
|
||||||
|
os.makedirs(theme_cache_dir, exist_ok=True)
|
||||||
|
shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||||
|
shared.gradio_theme.dump(theme_cache_path)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "changing gradio theme")
|
||||||
|
shared.gradio_theme = gr.themes.Default(**default_theme_args)
|
49
modules/shared_init.py
Normal file
49
modules/shared_init.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
"""Initializes fields inside the shared module in a controlled manner.
|
||||||
|
|
||||||
|
Should be called early because some other modules you can import mingt need these fields to be already set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
|
|
||||||
|
from modules import options, shared_options
|
||||||
|
shared.options_templates = shared_options.options_templates
|
||||||
|
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
|
||||||
|
shared.restricted_opts = shared_options.restricted_opts
|
||||||
|
if os.path.exists(shared.config_filename):
|
||||||
|
shared.opts.load(shared.config_filename)
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
|
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||||
|
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
|
||||||
|
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
|
||||||
|
|
||||||
|
shared.device = devices.device
|
||||||
|
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||||
|
|
||||||
|
from modules import shared_state
|
||||||
|
shared.state = shared_state.State()
|
||||||
|
|
||||||
|
from modules import styles
|
||||||
|
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
|
||||||
|
|
||||||
|
from modules import interrogate
|
||||||
|
shared.interrogator = interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
|
from modules import shared_total_tqdm
|
||||||
|
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
|
||||||
|
|
||||||
|
from modules import memmon, devices
|
||||||
|
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
|
||||||
|
shared.mem_mon.start()
|
||||||
|
|
@ -1,3 +1,6 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
@ -41,13 +44,36 @@ def refresh_unet_list():
|
|||||||
modules.sd_unet.list_unets()
|
modules.sd_unet.list_unets()
|
||||||
|
|
||||||
|
|
||||||
|
def list_checkpoint_tiles():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.checkpoint_tiles()
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_checkpoints():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.list_models()
|
||||||
|
|
||||||
|
|
||||||
|
def list_samplers():
|
||||||
|
import modules.sd_samplers
|
||||||
|
return modules.sd_samplers.all_samplers
|
||||||
|
|
||||||
|
|
||||||
|
def reload_hypernetworks():
|
||||||
|
from modules.hypernetworks import hypernetwork
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
|
|
||||||
|
|
||||||
ui_reorder_categories_builtin_items = [
|
ui_reorder_categories_builtin_items = [
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
|
"accordions",
|
||||||
"checkboxes",
|
"checkboxes",
|
||||||
"hires_fix",
|
|
||||||
"dimensions",
|
"dimensions",
|
||||||
"cfg",
|
"cfg",
|
||||||
|
"denoising",
|
||||||
"seed",
|
"seed",
|
||||||
"batch",
|
"batch",
|
||||||
"override_settings",
|
"override_settings",
|
||||||
@ -61,9 +87,33 @@ def ui_reorder_categories():
|
|||||||
|
|
||||||
sections = {}
|
sections = {}
|
||||||
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||||
if isinstance(script.section, str):
|
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
|
||||||
sections[script.section] = 1
|
sections[script.section] = 1
|
||||||
|
|
||||||
yield from sections
|
yield from sections
|
||||||
|
|
||||||
yield "scripts"
|
yield "scripts"
|
||||||
|
|
||||||
|
|
||||||
|
class Shared(sys.modules[__name__].__class__):
|
||||||
|
"""
|
||||||
|
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||||
|
at program startup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sd_model_val = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sd_model(self):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
return modules.sd_models.model_data.get_sd_model()
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
modules.sd_models.model_data.set_sd_model(value)
|
||||||
|
|
||||||
|
|
||||||
|
sys.modules['modules.shared'].__class__ = Shared
|
||||||
|
332
modules/shared_options.py
Normal file
332
modules/shared_options.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||||
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
from modules.options import options_section, OptionInfo, OptionHTML
|
||||||
|
|
||||||
|
options_templates = {}
|
||||||
|
hide_dirs = shared.hide_dirs
|
||||||
|
|
||||||
|
restricted_opts = {
|
||||||
|
"samples_filename_pattern",
|
||||||
|
"directories_filename_pattern",
|
||||||
|
"outdir_samples",
|
||||||
|
"outdir_txt2img_samples",
|
||||||
|
"outdir_img2img_samples",
|
||||||
|
"outdir_extras_samples",
|
||||||
|
"outdir_grids",
|
||||||
|
"outdir_txt2img_grids",
|
||||||
|
"outdir_save",
|
||||||
|
"outdir_init_images"
|
||||||
|
}
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||||
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||||
|
|
||||||
|
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||||
|
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||||
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
|
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||||
|
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||||
|
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||||
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
|
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
|
||||||
|
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||||
|
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||||
|
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||||
|
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||||
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
|
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
|
||||||
|
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||||
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
|
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||||
|
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||||
|
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||||
|
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||||
|
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||||
|
|
||||||
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
|
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||||
|
|
||||||
|
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||||
|
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||||
|
|
||||||
|
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||||
|
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||||
|
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||||
|
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||||
|
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
|
||||||
|
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||||
|
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||||
|
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||||
|
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||||
|
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
|
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
|
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||||
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||||
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('system', "System"), {
|
||||||
|
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||||
|
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||||
|
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||||
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||||
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('API', "API"), {
|
||||||
|
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||||
|
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||||
|
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
|
"save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."),
|
||||||
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
|
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||||
|
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
|
||||||
|
"training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
|
||||||
|
"training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
|
||||||
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||||
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||||
|
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||||
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(),
|
||||||
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
|
||||||
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}, infotext="RNG").info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||||
|
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
||||||
|
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||||
|
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||||
|
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||||
|
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||||
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('vae', "VAE"), {
|
||||||
|
"sd_vae_explanation": OptionHTML("""
|
||||||
|
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||||
|
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||||
|
(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.
|
||||||
|
For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.
|
||||||
|
"""),
|
||||||
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
|
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||||
|
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||||
|
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||||
|
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('img2img', "img2img"), {
|
||||||
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||||
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||||
|
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||||
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||||
|
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
|
||||||
|
"img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(),
|
||||||
|
"img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(),
|
||||||
|
"img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(),
|
||||||
|
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||||
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
|
||||||
|
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||||
|
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||||
|
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
|
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||||
|
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||||
|
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||||
|
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||||
|
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||||
|
"use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||||
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||||
|
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||||
|
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
|
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
|
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
|
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||||
|
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types),
|
||||||
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
|
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||||
|
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||||
|
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||||
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||||
|
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||||
|
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||||
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||||
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||||
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
|
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||||
|
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||||
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
|
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||||
|
"textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
|
||||||
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||||
|
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||||
|
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
|
||||||
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
||||||
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||||
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
|
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||||
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||||
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||||
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||||
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||||
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
||||||
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
|
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||||
|
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||||
|
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
||||||
|
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
||||||
|
<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>
|
||||||
|
<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li>
|
||||||
|
</ul>"""),
|
||||||
|
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
|
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||||
|
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||||
|
"live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"),
|
||||||
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
|
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||||
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||||
|
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||||
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
||||||
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||||
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
|
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||||
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
|
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||||
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
|
'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"),
|
||||||
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
|
||||||
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
||||||
|
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
||||||
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||||
|
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
|
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
|
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||||
|
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||||
|
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||||
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
|
}))
|
||||||
|
|
159
modules/shared_state.py
Normal file
159
modules/shared_state.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from modules import errors, shared, devices
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class State:
|
||||||
|
skipped = False
|
||||||
|
interrupted = False
|
||||||
|
job = ""
|
||||||
|
job_no = 0
|
||||||
|
job_count = 0
|
||||||
|
processing_has_refined_job_count = False
|
||||||
|
job_timestamp = '0'
|
||||||
|
sampling_step = 0
|
||||||
|
sampling_steps = 0
|
||||||
|
current_latent = None
|
||||||
|
current_image = None
|
||||||
|
current_image_sampling_step = 0
|
||||||
|
id_live_preview = 0
|
||||||
|
textinfo = None
|
||||||
|
time_start = None
|
||||||
|
server_start = None
|
||||||
|
_server_command_signal = threading.Event()
|
||||||
|
_server_command: Optional[str] = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.server_start = time.time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_restart(self) -> bool:
|
||||||
|
# Compatibility getter for need_restart.
|
||||||
|
return self.server_command == "restart"
|
||||||
|
|
||||||
|
@need_restart.setter
|
||||||
|
def need_restart(self, value: bool) -> None:
|
||||||
|
# Compatibility setter for need_restart.
|
||||||
|
if value:
|
||||||
|
self.server_command = "restart"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_command(self):
|
||||||
|
return self._server_command
|
||||||
|
|
||||||
|
@server_command.setter
|
||||||
|
def server_command(self, value: Optional[str]) -> None:
|
||||||
|
"""
|
||||||
|
Set the server command to `value` and signal that it's been set.
|
||||||
|
"""
|
||||||
|
self._server_command = value
|
||||||
|
self._server_command_signal.set()
|
||||||
|
|
||||||
|
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Wait for server command to get set; return and clear the value and signal.
|
||||||
|
"""
|
||||||
|
if self._server_command_signal.wait(timeout):
|
||||||
|
self._server_command_signal.clear()
|
||||||
|
req = self._server_command
|
||||||
|
self._server_command = None
|
||||||
|
return req
|
||||||
|
return None
|
||||||
|
|
||||||
|
def request_restart(self) -> None:
|
||||||
|
self.interrupt()
|
||||||
|
self.server_command = "restart"
|
||||||
|
log.info("Received restart request")
|
||||||
|
|
||||||
|
def skip(self):
|
||||||
|
self.skipped = True
|
||||||
|
log.info("Received skip request")
|
||||||
|
|
||||||
|
def interrupt(self):
|
||||||
|
self.interrupted = True
|
||||||
|
log.info("Received interrupt request")
|
||||||
|
|
||||||
|
def nextjob(self):
|
||||||
|
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
|
||||||
|
self.do_set_current_image()
|
||||||
|
|
||||||
|
self.job_no += 1
|
||||||
|
self.sampling_step = 0
|
||||||
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
|
def dict(self):
|
||||||
|
obj = {
|
||||||
|
"skipped": self.skipped,
|
||||||
|
"interrupted": self.interrupted,
|
||||||
|
"job": self.job,
|
||||||
|
"job_count": self.job_count,
|
||||||
|
"job_timestamp": self.job_timestamp,
|
||||||
|
"job_no": self.job_no,
|
||||||
|
"sampling_step": self.sampling_step,
|
||||||
|
"sampling_steps": self.sampling_steps,
|
||||||
|
}
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def begin(self, job: str = "(unknown)"):
|
||||||
|
self.sampling_step = 0
|
||||||
|
self.job_count = -1
|
||||||
|
self.processing_has_refined_job_count = False
|
||||||
|
self.job_no = 0
|
||||||
|
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
self.current_latent = None
|
||||||
|
self.current_image = None
|
||||||
|
self.current_image_sampling_step = 0
|
||||||
|
self.id_live_preview = 0
|
||||||
|
self.skipped = False
|
||||||
|
self.interrupted = False
|
||||||
|
self.textinfo = None
|
||||||
|
self.time_start = time.time()
|
||||||
|
self.job = job
|
||||||
|
devices.torch_gc()
|
||||||
|
log.info("Starting job %s", job)
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
duration = time.time() - self.time_start
|
||||||
|
log.info("Ending job %s (%.2f seconds)", self.job, duration)
|
||||||
|
self.job = ""
|
||||||
|
self.job_count = 0
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
def set_current_image(self):
|
||||||
|
"""if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
|
||||||
|
if not shared.parallel_processing_allowed:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
|
||||||
|
self.do_set_current_image()
|
||||||
|
|
||||||
|
def do_set_current_image(self):
|
||||||
|
if self.current_latent is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
import modules.sd_samplers
|
||||||
|
|
||||||
|
try:
|
||||||
|
if shared.opts.show_progress_grid:
|
||||||
|
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
|
||||||
|
else:
|
||||||
|
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
|
||||||
|
|
||||||
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
||||||
|
# we silently ignore this error
|
||||||
|
errors.record_exception()
|
||||||
|
|
||||||
|
def assign_current_image(self, image):
|
||||||
|
self.current_image = image
|
||||||
|
self.id_live_preview += 1
|
37
modules/shared_total_tqdm.py
Normal file
37
modules/shared_total_tqdm.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import tqdm
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
|
class TotalTQDM:
|
||||||
|
def __init__(self):
|
||||||
|
self._tqdm = None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._tqdm = tqdm.tqdm(
|
||||||
|
desc="Total progress",
|
||||||
|
total=shared.state.job_count * shared.state.sampling_steps,
|
||||||
|
position=1,
|
||||||
|
file=shared.progress_print_out
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||||
|
return
|
||||||
|
if self._tqdm is None:
|
||||||
|
self.reset()
|
||||||
|
self._tqdm.update()
|
||||||
|
|
||||||
|
def updateTotal(self, new_total):
|
||||||
|
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
|
||||||
|
return
|
||||||
|
if self._tqdm is None:
|
||||||
|
self.reset()
|
||||||
|
self._tqdm.total = new_total
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
if self._tqdm is not None:
|
||||||
|
self._tqdm.refresh()
|
||||||
|
self._tqdm.close()
|
||||||
|
self._tqdm = None
|
||||||
|
|
@ -106,10 +106,7 @@ class StyleDatabase:
|
|||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
shutil.copy(path, f"{path}.bak")
|
shutil.copy(path, f"{path}.bak")
|
||||||
|
|
||||||
fd = os.open(path, os.O_RDWR | os.O_CREAT)
|
with open(path, "w", encoding="utf-8-sig", newline='') as file:
|
||||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
|
||||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
|
||||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
|
||||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
writer.writerows(style._asdict() for k, style in self.styles.items())
|
||||||
|
@ -58,7 +58,7 @@ def _summarize_chunk(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> AttnChunk:
|
) -> AttnChunk:
|
||||||
attn_weights = torch.baddbmm(
|
attn_weights = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_scores = torch.baddbmm(
|
attn_scores = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user