mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
Merge branch 'dev' into master
This commit is contained in:
commit
ea3aae9c39
@ -78,6 +78,8 @@ module.exports = {
|
||||
//extraNetworks.js
|
||||
requestGet: "readonly",
|
||||
popup: "readonly",
|
||||
// profilerVisualization.js
|
||||
createVisualizationTable: "readonly",
|
||||
// from python
|
||||
localization: "readonly",
|
||||
// progrssbar.js
|
||||
@ -86,8 +88,6 @@ module.exports = {
|
||||
// imageviewer.js
|
||||
modalPrevImage: "readonly",
|
||||
modalNextImage: "readonly",
|
||||
// token-counters.js
|
||||
setupTokenCounters: "readonly",
|
||||
// localStorage.js
|
||||
localSet: "readonly",
|
||||
localGet: "readonly",
|
||||
|
10
.github/workflows/run_tests.yaml
vendored
10
.github/workflows/run_tests.yaml
vendored
@ -20,6 +20,12 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
**/requirements*txt
|
||||
launch.py
|
||||
- name: Cache models
|
||||
id: cache-models
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: models
|
||||
key: "2023-12-30"
|
||||
- name: Install test dependencies
|
||||
run: pip install wait-for-it -r requirements-test.txt
|
||||
env:
|
||||
@ -33,6 +39,8 @@ jobs:
|
||||
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
|
||||
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
|
||||
PYTHONUNBUFFERED: "1"
|
||||
- name: Print installed packages
|
||||
run: pip freeze
|
||||
- name: Start test server
|
||||
run: >
|
||||
python -m coverage run
|
||||
@ -49,7 +57,7 @@ jobs:
|
||||
2>&1 | tee output.txt &
|
||||
- name: Run tests
|
||||
run: |
|
||||
wait-for-it --service 127.0.0.1:7860 -t 600
|
||||
wait-for-it --service 127.0.0.1:7860 -t 20
|
||||
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
||||
- name: Kill test server
|
||||
if: always()
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -37,3 +37,4 @@ notification.mp3
|
||||
/node_modules
|
||||
/package-lock.json
|
||||
/.coverage*
|
||||
/test/test_outputs
|
||||
|
144
CHANGELOG.md
144
CHANGELOG.md
@ -1,3 +1,134 @@
|
||||
## 1.8.0-RC
|
||||
|
||||
### Features:
|
||||
* Update torch to version 2.1.2
|
||||
* Soft Inpainting ([#14208](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14208))
|
||||
* FP8 support ([#14031](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14031), [#14327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14327))
|
||||
* Support for SDXL-Inpaint Model ([#14390](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14390))
|
||||
* Use Spandrel for upscaling and face restoration architectures ([#14425](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14425), [#14467](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14467), [#14473](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14473), [#14474](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14474), [#14477](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14477), [#14476](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14476), [#14484](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14484), [#14500](https://github.com/AUTOMATIC1111/stable-difusion-webui/pull/14500), [#14501](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14501), [#14504](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14504), [#14524](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14524), [#14809](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14809))
|
||||
* Automatic backwards version compatibility (when loading infotexts from old images with program version specified, will add compatibility settings)
|
||||
* Implement zero terminal SNR noise schedule option (**[SEED BREAKING CHANGE](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Seed-breaking-changes#180-dev-170-225-2024-01-01---zero-terminal-snr-noise-schedule-option)**, [#14145](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14145), [#14979](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14979))
|
||||
* Add a [✨] button to run hires fix on selected image in the gallery (with help from [#14598](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14598), [#14626](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14626), [#14728](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14728))
|
||||
* [Separate assets repository](https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets); serve fonts locally rather than from google's servers
|
||||
* Official LCM Sampler Support ([#14583](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14583))
|
||||
* Add support for DAT upscaler models ([#14690](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14690), [#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039))
|
||||
* Extra Networks Tree View ([#14588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14588), [#14900](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14900))
|
||||
* NPU Support ([#14801](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14801))
|
||||
* Prompt comments support
|
||||
|
||||
### Minor:
|
||||
* Allow pasting in WIDTHxHEIGHT strings into the width/height fields ([#14296](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14296))
|
||||
* add option: Live preview in full page image viewer ([#14230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14230), [#14307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14307))
|
||||
* Add keyboard shortcuts for generate/skip/interrupt ([#14269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14269))
|
||||
* Better TCMALLOC support on different platforms ([#14227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14227), [#14883](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14883), [#14910](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14910))
|
||||
* Lora not found warning ([#14464](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14464))
|
||||
* Adding negative prompts to Loras in extra networks ([#14475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14475))
|
||||
* xyz_grid: allow varying the seed along an axis separate from axis options ([#12180](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12180))
|
||||
* option to convert VAE to bfloat16 (implementation of [#9295](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9295))
|
||||
* Better IPEX support ([#14229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14229), [#14353](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14353), [#14559](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14559), [#14562](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14562), [#14597](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14597))
|
||||
* Option to interrupt after current generation rather than immediately ([#13653](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13653), [#14659](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14659))
|
||||
* Fullscreen Preview control fading/disable ([#14291](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14291))
|
||||
* Finer settings freezing control ([#13789](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13789))
|
||||
* Increase Upscaler Limits ([#14589](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14589))
|
||||
* Adjust brush size with hotkeys ([#14638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14638))
|
||||
* Add checkpoint info to csv log file when saving images ([#14663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14663))
|
||||
* Make more columns resizable ([#14740](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14740), [#14884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14884))
|
||||
* Add an option to not overlay original image for inpainting for #14727
|
||||
* Add Pad conds v0 option to support same generation with DDIM as before 1.6.0
|
||||
* Add "Interrupting..." placeholder.
|
||||
* Button for refresh extensions list ([#14857](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14857))
|
||||
* Add an option to disable normalization after calculating emphasis. ([#14874](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14874))
|
||||
* When counting tokens, also include enabled styles (can be disabled in settings to revert to previous behavior)
|
||||
* Configuration for the [📂] button for image gallery ([#14947](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14947))
|
||||
* Support inference with LyCORIS BOFT networks ([#14871](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14871), [#14973](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14973))
|
||||
* support resizable columns for touch (tablets) ([#15002](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15002))
|
||||
|
||||
### Extensions and API:
|
||||
* Removed packages from requirements: basicsr, gfpgan, realesrgan; as well as their dependencies: absl-py, addict, beautifulsoup4, future, gdown, grpcio, importlib-metadata, lmdb, lpips, Markdown, platformdirs, PySocks, soupsieve, tb-nightly, tensorboard-data-server, tomli, Werkzeug, yapf, zipp, soupsieve
|
||||
* Enable task ids for API ([#14314](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14314))
|
||||
* add override_settings support for infotext API
|
||||
* rename generation_parameters_copypaste module to infotext_utils
|
||||
* prevent crash due to Script __init__ exception ([#14407](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14407))
|
||||
* Bump numpy to 1.26.2 ([#14471](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14471))
|
||||
* Add utility to inspect a model's dtype/device ([#14478](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14478))
|
||||
* Implement general forward method for all method in built-in lora ext ([#14547](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14547))
|
||||
* Execute model_loaded_callback after moving to target device ([#14563](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14563))
|
||||
* Add self to CFGDenoiserParams ([#14573](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14573))
|
||||
* Allow TLS with API only mode (--nowebui) ([#14593](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14593))
|
||||
* New callback: postprocess_image_after_composite ([#14657](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14657))
|
||||
* modules/api/api.py: add api endpoint to refresh embeddings list ([#14715](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14715))
|
||||
* set_named_arg ([#14773](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14773))
|
||||
* add before_token_counter callback and use it for prompt comments
|
||||
* ResizeHandleRow - allow overridden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004))
|
||||
|
||||
### Performance
|
||||
* Massive performance improvement for extra networks directories with a huge number of files in them in an attempt to tackle #14507 ([#14528](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14528))
|
||||
* Reduce unnecessary re-indexing extra networks directory ([#14512](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14512))
|
||||
* Avoid unnecessary `isfile`/`exists` calls ([#14527](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14527))
|
||||
|
||||
### Bug Fixes:
|
||||
* fix multiple bugs related to styles multi-file support ([#14203](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14203), [#14276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14276), [#14707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14707))
|
||||
* Lora fixes ([#14300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14300), [#14237](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14237), [#14546](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14546), [#14726](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14726))
|
||||
* Re-add setting lost as part of e294e46 ([#14266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14266))
|
||||
* fix extras caption BLIP ([#14330](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14330))
|
||||
* include infotext into saved init image for img2img ([#14452](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14452))
|
||||
* xyz grid handle axis_type is None ([#14394](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14394))
|
||||
* Update Added (Fixed) IPV6 Functionality When there is No Webui Argument Passed webui.py ([#14354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14354))
|
||||
* fix API thread safe issues of txt2img and img2img ([#14421](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14421))
|
||||
* handle selectable script_index is None ([#14487](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14487))
|
||||
* handle config.json failed to load ([#14525](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14525), [#14767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14767))
|
||||
* paste infotext cast int as float ([#14523](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14523))
|
||||
* Ensure GRADIO_ANALYTICS_ENABLED is set early enough ([#14537](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14537))
|
||||
* Fix logging configuration again ([#14538](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14538))
|
||||
* Handle CondFunc exception when resolving attributes ([#14560](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14560))
|
||||
* Fix extras big batch crashes ([#14699](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14699))
|
||||
* Fix using wrong model caused by alias ([#14655](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14655))
|
||||
* Add # to the invalid_filename_chars list ([#14640](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14640))
|
||||
* Fix extension check for requirements ([#14639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14639))
|
||||
* Fix tab indexes are reset after restart UI ([#14637](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14637))
|
||||
* Fix nested manual cast ([#14689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14689))
|
||||
* Keep postprocessing upscale selected tab after restart ([#14702](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14702))
|
||||
* XYZ grid: filter out blank vals when axis is int or float type (like int axis seed) ([#14754](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14754))
|
||||
* fix CLIP Interrogator topN regex ([#14775](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14775))
|
||||
* Fix dtype error in MHA layer/change dtype checking mechanism for manual cast ([#14791](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14791))
|
||||
* catch load style.csv error ([#14814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14814))
|
||||
* fix error when editing extra networks card
|
||||
* fix extra networks metadata failing to work properly when you create the .json file with metadata for the first time.
|
||||
* util.walk_files extensions case insensitive ([#14879](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14879))
|
||||
* if extensions page not loaded, prevent apply ([#14873](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14873))
|
||||
* call the right function for token counter in img2img
|
||||
* Fix the bugs that search/reload will disappear when using other ExtraNetworks extensions ([#14939](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14939))
|
||||
* Gracefully handle mtime read exception from cache ([#14933](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14933))
|
||||
* Only trigger interrupt on `Esc` when interrupt button visible ([#14932](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14932))
|
||||
* Disable prompt token counters option actually disables token counting rather than just hiding results.
|
||||
* avoid double upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966))
|
||||
* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995))
|
||||
* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006))
|
||||
* Fix resize-handle for mobile ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010), [#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065))
|
||||
|
||||
### Other:
|
||||
* Assign id for "extra_options". Replace numeric field with slider. ([#14270](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14270))
|
||||
* change state dict comparison to ref compare ([#14216](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14216))
|
||||
* Bump torch-rocm to 5.6/5.7 ([#14293](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14293))
|
||||
* Base output path off data path ([#14446](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14446))
|
||||
* reorder training preprocessing modules in extras tab ([#14367](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14367))
|
||||
* Remove `cleanup_models` code ([#14472](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14472))
|
||||
* only rewrite ui-config when there is change ([#14352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14352))
|
||||
* Fix lint issue from 501993eb ([#14495](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14495))
|
||||
* Update README.md ([#14548](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14548))
|
||||
* hires button, fix seeds ()
|
||||
* Logging: set formatter correctly for fallback logger too ([#14618](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14618))
|
||||
* Read generation info from infotexts rather than json for internal needs (save, extract seed from generated pic) ([#14645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14645))
|
||||
* improve get_crop_region ([#14709](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14709))
|
||||
* Bump safetensors' version to 0.4.2 ([#14782](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14782))
|
||||
* add tooltip create_submit_box ([#14803](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14803))
|
||||
* extensions tab table row hover highlight ([#14885](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14885))
|
||||
* Always add timestamp to displayed image ([#14890](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14890))
|
||||
* Added core.filemode=false so doesn't track changes in file permission… ([#14930](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14930))
|
||||
* Normalize command-line argument paths ([#14934](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14934), [#15035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15035))
|
||||
* Use original App Title in progress bar ([#14916](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14916))
|
||||
* register_tmp_file also for mtime ([#15012](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15012))
|
||||
|
||||
## 1.7.0
|
||||
|
||||
### Features:
|
||||
@ -40,7 +171,8 @@
|
||||
* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
|
||||
* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
|
||||
* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
|
||||
* allow use of mutiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
|
||||
* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
|
||||
* make extra network card description plaintext by default, with an option (Treat card description as HTML) to re-enable HTML as it was (originally by [#13241](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13241))
|
||||
|
||||
### Extensions and API:
|
||||
* update gradio to 3.41.2
|
||||
@ -176,7 +308,7 @@
|
||||
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
||||
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
||||
* makes all of them work with img2img
|
||||
* makes prompt composition posssible (AND)
|
||||
* makes prompt composition possible (AND)
|
||||
* makes them available for SDXL
|
||||
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
||||
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
||||
@ -352,7 +484,7 @@
|
||||
* user metadata system for custom networks
|
||||
* extended Lora metadata editor: set activation text, default weight, view tags, training info
|
||||
* Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
|
||||
* show github stars for extenstions
|
||||
* show github stars for extensions
|
||||
* img2img batch mode can read extra stuff from png info
|
||||
* img2img batch works with subdirectories
|
||||
* hotkeys to move prompt elements: alt+left/right
|
||||
@ -571,7 +703,7 @@
|
||||
* do not wait for Stable Diffusion model to load at startup
|
||||
* add filename patterns: `[denoising]`
|
||||
* directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for
|
||||
* LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metdata of the file, if present, instead of filename (both can be used to activate LoRA)
|
||||
* LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metadata of the file, if present, instead of filename (both can be used to activate LoRA)
|
||||
* LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
|
||||
* LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer)
|
||||
* LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss)
|
||||
@ -601,7 +733,7 @@
|
||||
* fix gamepad navigation
|
||||
* make the lightbox fullscreen image function properly
|
||||
* fix squished thumbnails in extras tab
|
||||
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
|
||||
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed)
|
||||
* fix webui showing the same image if you configure the generation to always save results into same file
|
||||
* fix bug with upscalers not working properly
|
||||
* fix MPS on PyTorch 2.0.1, Intel Macs
|
||||
@ -619,7 +751,7 @@
|
||||
* switch to PyTorch 2.0.0 (except for AMD GPUs)
|
||||
* visual improvements to custom code scripts
|
||||
* add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]`
|
||||
* add support for saving init images in img2img, and record their hashes in infotext for reproducability
|
||||
* add support for saving init images in img2img, and record their hashes in infotext for reproducibility
|
||||
* automatically select current word when adjusting weight with ctrl+up/down
|
||||
* add dropdowns for X/Y/Z plot
|
||||
* add setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
|
||||
|
13
README.md
13
README.md
@ -1,5 +1,5 @@
|
||||
# Stable Diffusion web UI
|
||||
A browser interface based on Gradio library for Stable Diffusion.
|
||||
A web interface for Stable Diffusion, implemented using Gradio library.
|
||||
|
||||
![](screenshot.png)
|
||||
|
||||
@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
||||
|
||||
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||
- ESRGAN - https://github.com/xinntao/ESRGAN
|
||||
- SwinIR - https://github.com/JingyunLiang/SwinIR
|
||||
- Swin2SR - https://github.com/mv-lab/swin2sr
|
||||
- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
|
||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||
- ESRGAN - https://github.com/xinntao/ESRGAN
|
||||
- SwinIR - https://github.com/JingyunLiang/SwinIR
|
||||
- Swin2SR - https://github.com/mv-lab/swin2sr
|
||||
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
||||
- MiDaS - https://github.com/isl-org/MiDaS
|
||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||
|
5
_typos.toml
Normal file
5
_typos.toml
Normal file
@ -0,0 +1,5 @@
|
||||
[default.extend-words]
|
||||
# Part of "RGBa" (Pillow's pre-multiplied alpha RGB mode)
|
||||
Ba = "Ba"
|
||||
# HSA is something AMD uses for their GPUs
|
||||
HSA = "HSA"
|
98
configs/sd_xl_inpaint.yaml
Normal file
98
configs/sd_xl_inpaint.yaml
Normal file
@ -0,0 +1,98 @@
|
||||
model:
|
||||
target: sgm.models.diffusion.DiffusionEngine
|
||||
params:
|
||||
scale_factor: 0.13025
|
||||
disable_first_stage_autocast: True
|
||||
|
||||
denoiser_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
|
||||
params:
|
||||
num_idx: 1000
|
||||
|
||||
weighting_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
|
||||
scaling_config:
|
||||
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
|
||||
discretization_config:
|
||||
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
|
||||
|
||||
network_config:
|
||||
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
adm_in_channels: 2816
|
||||
num_classes: sequential
|
||||
use_checkpoint: True
|
||||
in_channels: 9
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [4, 2]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [1, 2, 4]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
|
||||
context_dim: 2048
|
||||
spatial_transformer_attn_type: softmax-xformers
|
||||
legacy: False
|
||||
|
||||
conditioner_config:
|
||||
target: sgm.modules.GeneralConditioner
|
||||
params:
|
||||
emb_models:
|
||||
# crossattn cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
params:
|
||||
layer: hidden
|
||||
layer_idx: 11
|
||||
# crossattn and vector cond
|
||||
- is_trainable: False
|
||||
input_key: txt
|
||||
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
|
||||
params:
|
||||
arch: ViT-bigG-14
|
||||
version: laion2b_s39b_b160k
|
||||
freeze: True
|
||||
layer: penultimate
|
||||
always_return_pooled: True
|
||||
legacy: False
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: original_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: crop_coords_top_left
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
# vector cond
|
||||
- is_trainable: False
|
||||
input_key: target_size_as_tuple
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
params:
|
||||
outdim: 256 # multiplied by two
|
||||
|
||||
first_stage_config:
|
||||
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
attn_type: vanilla-xformers
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult: [1, 2, 4, 4]
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
@ -301,7 +301,7 @@ class DDPMV1(pl.LightningModule):
|
||||
elif self.parameterization == "x0":
|
||||
target = x_start
|
||||
else:
|
||||
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
||||
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
||||
|
||||
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
||||
|
||||
@ -880,7 +880,7 @@ class LatentDiffusionV1(DDPMV1):
|
||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
# hybrid case, cond is exptected to be a dict
|
||||
# hybrid case, cond is expected to be a dict
|
||||
pass
|
||||
else:
|
||||
if not isinstance(cond, list):
|
||||
@ -916,7 +916,7 @@ class LatentDiffusionV1(DDPMV1):
|
||||
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
||||
|
||||
elif self.cond_stage_key == 'coordinates_bbox':
|
||||
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
||||
assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
|
||||
|
||||
# assuming padding of unfold is always 0 and its dilation is always 1
|
||||
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
||||
@ -926,7 +926,7 @@ class LatentDiffusionV1(DDPMV1):
|
||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||
rescale_latent = 2 ** (num_downs)
|
||||
|
||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
||||
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
||||
|
@ -30,7 +30,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
||||
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||
secon value is a value for weight.
|
||||
|
||||
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||
Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||
|
||||
examples)
|
||||
factor
|
||||
|
@ -3,6 +3,9 @@ import os
|
||||
from collections import namedtuple
|
||||
import enum
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import sd_models, cache, errors, hashes, shared
|
||||
|
||||
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
|
||||
@ -115,6 +118,29 @@ class NetworkModule:
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.ops = None
|
||||
self.extra_kwargs = {}
|
||||
if isinstance(self.sd_module, nn.Conv2d):
|
||||
self.ops = F.conv2d
|
||||
self.extra_kwargs = {
|
||||
'stride': self.sd_module.stride,
|
||||
'padding': self.sd_module.padding
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.Linear):
|
||||
self.ops = F.linear
|
||||
elif isinstance(self.sd_module, nn.LayerNorm):
|
||||
self.ops = F.layer_norm
|
||||
self.extra_kwargs = {
|
||||
'normalized_shape': self.sd_module.normalized_shape,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.GroupNorm):
|
||||
self.ops = F.group_norm
|
||||
self.extra_kwargs = {
|
||||
'num_groups': self.sd_module.num_groups,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
|
||||
self.dim = None
|
||||
self.bias = weights.w.get("bias")
|
||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||
@ -137,7 +163,7 @@ class NetworkModule:
|
||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||
if self.bias is not None:
|
||||
updown = updown.reshape(self.bias.shape)
|
||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
|
||||
updown = updown.reshape(output_shape)
|
||||
|
||||
if len(output_shape) == 4:
|
||||
@ -155,5 +181,10 @@ class NetworkModule:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, y):
|
||||
"""A general forward implementation for all modules"""
|
||||
if self.ops is None:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
updown, ex_bias = self.calc_updown(self.sd_module.weight)
|
||||
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
|
||||
|
||||
|
@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.weight.shape
|
||||
updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = self.weight.to(orig_weight.device)
|
||||
if self.ex_bias is not None:
|
||||
ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
ex_bias = self.ex_bias.to(orig_weight.device)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
|
@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
|
||||
self.w2b = weights.w["b2.weight"]
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||
updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
|
||||
self.t2 = weights.w.get("hada_t2")
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
|
||||
if self.t1 is not None:
|
||||
output_shape = [w1a.size(1), w1b.size(1)]
|
||||
t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t1 = self.t1.to(orig_weight.device)
|
||||
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
|
||||
output_shape += t1.shape[2:]
|
||||
else:
|
||||
@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
|
||||
updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
|
||||
|
||||
if self.t2 is not None:
|
||||
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t2 = self.t2.to(orig_weight.device)
|
||||
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
else:
|
||||
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
|
||||
|
@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
|
||||
self.on_input = weights.w["on_input"].item()
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w = self.w.to(orig_weight.device)
|
||||
|
||||
output_shape = [w.size(0), orig_weight.size(1)]
|
||||
if self.on_input:
|
||||
|
@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
if self.w1 is not None:
|
||||
w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1 = self.w1.to(orig_weight.device)
|
||||
else:
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1a = self.w1a.to(orig_weight.device)
|
||||
w1b = self.w1b.to(orig_weight.device)
|
||||
w1 = w1a @ w1b
|
||||
|
||||
if self.w2 is not None:
|
||||
w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2 = self.w2.to(orig_weight.device)
|
||||
elif self.t2 is None:
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
w2 = w2a @ w2b
|
||||
else:
|
||||
t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
t2 = self.t2.to(orig_weight.device)
|
||||
w2a = self.w2a.to(orig_weight.device)
|
||||
w2b = self.w2b.to(orig_weight.device)
|
||||
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
|
||||
|
||||
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
|
||||
|
@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
|
||||
return module
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
up = self.up_model.weight.to(orig_weight.device)
|
||||
down = self.down_model.weight.to(orig_weight.device)
|
||||
|
||||
output_shape = [up.size(0), down.size(1)]
|
||||
if self.mid_model is not None:
|
||||
# cp-decomposition
|
||||
mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
mid = self.mid_model.weight.to(orig_weight.device)
|
||||
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
|
||||
output_shape += mid.shape[2:]
|
||||
else:
|
||||
|
@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
output_shape = self.w_norm.shape
|
||||
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
updown = self.w_norm.to(orig_weight.device)
|
||||
|
||||
if self.b_norm is not None:
|
||||
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
ex_bias = self.b_norm.to(orig_weight.device)
|
||||
else:
|
||||
ex_bias = None
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
import torch
|
||||
import network
|
||||
from lyco_helpers import factorization
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
@ -22,20 +21,28 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
|
||||
self.scale = 1.0
|
||||
self.is_R = False
|
||||
self.is_boft = False
|
||||
|
||||
# kohya-ss
|
||||
# kohya-ss/New LyCORIS OFT/BOFT
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
self.is_kohya = True
|
||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||
self.alpha = weights.w.get("alpha", None) # alpha is constraint
|
||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||
# LyCORIS
|
||||
# Old LyCORIS OFT
|
||||
elif "oft_diag" in weights.w.keys():
|
||||
self.is_kohya = False
|
||||
self.is_R = True
|
||||
self.oft_blocks = weights.w["oft_diag"]
|
||||
# self.alpha is unused
|
||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||
|
||||
# LyCORIS BOFT
|
||||
if self.oft_blocks.dim() == 4:
|
||||
self.is_boft = True
|
||||
self.rescale = weights.w.get('rescale', None)
|
||||
if self.rescale is not None:
|
||||
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
||||
|
||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||
@ -47,27 +54,34 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
elif is_other_linear:
|
||||
self.out_dim = self.sd_module.embed_dim
|
||||
|
||||
if self.is_kohya:
|
||||
self.constraint = self.alpha * self.out_dim
|
||||
self.num_blocks = self.dim
|
||||
self.block_size = self.out_dim // self.dim
|
||||
else:
|
||||
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
||||
if self.is_R:
|
||||
self.constraint = None
|
||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||
self.block_size = self.dim
|
||||
self.num_blocks = self.out_dim // self.dim
|
||||
elif self.is_boft:
|
||||
self.boft_m = self.oft_blocks.shape[0]
|
||||
self.num_blocks = self.oft_blocks.shape[1]
|
||||
self.block_size = self.oft_blocks.shape[2]
|
||||
self.boft_b = self.block_size
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||
oft_blocks = self.oft_blocks.to(orig_weight.device)
|
||||
eye = torch.eye(self.block_size, device=oft_blocks.device)
|
||||
|
||||
if self.is_kohya:
|
||||
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||
if not self.is_R:
|
||||
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
|
||||
if self.constraint != 0:
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||
|
||||
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
R = oft_blocks.to(orig_weight.device)
|
||||
|
||||
if not self.is_boft:
|
||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
merged_weight = torch.einsum(
|
||||
@ -76,7 +90,29 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
merged_weight
|
||||
)
|
||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||
else:
|
||||
# TODO: determine correct value for scale
|
||||
scale = 1.0
|
||||
m = self.boft_m
|
||||
b = self.boft_b
|
||||
r_b = b // 2
|
||||
inp = orig_weight
|
||||
for i in range(m):
|
||||
bi = R[i] # b_num, b_size, b_size
|
||||
if i == 0:
|
||||
# Apply multiplier/scale and rescale into first weight
|
||||
bi = bi * scale + (1 - scale) * eye
|
||||
inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
|
||||
inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
|
||||
inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
|
||||
inp = rearrange(inp, "d b ... -> (d b) ...")
|
||||
inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
|
||||
merged_weight = inp
|
||||
|
||||
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||
# Rescale mechanism
|
||||
if self.rescale is not None:
|
||||
merged_weight = self.rescale.to(merged_weight) * merged_weight
|
||||
|
||||
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
|
||||
output_shape = orig_weight.shape
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import gradio as gr
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@ -259,11 +260,11 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
|
||||
loaded_networks.clear()
|
||||
|
||||
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
|
||||
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
|
||||
if any(x is None for x in networks_on_disk):
|
||||
list_available_networks()
|
||||
|
||||
networks_on_disk = [available_network_aliases.get(name, None) for name in names]
|
||||
networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
|
||||
|
||||
failed_to_load_networks = []
|
||||
|
||||
@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
emb_db.skipped_embeddings[name] = embedding
|
||||
|
||||
if failed_to_load_networks:
|
||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||
lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
|
||||
sd_hijack.model_hijack.comments.append(lora_not_found_message)
|
||||
if shared.opts.lora_not_found_warning_console:
|
||||
print(f'\n{lora_not_found_message}\n')
|
||||
if shared.opts.lora_not_found_gradio_warning:
|
||||
gr.Warning(lora_not_found_message)
|
||||
|
||||
purge_networks_from_memory()
|
||||
|
||||
@ -349,7 +355,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
"""
|
||||
Applies the currently selected set of networks to the weights of torch layer self.
|
||||
If weights already have this particular set of networks applied, does nothing.
|
||||
If not, restores orginal weights from backup and alters weights according to networks.
|
||||
If not, restores original weights from backup and alters weights according to networks.
|
||||
"""
|
||||
|
||||
network_layer_name = getattr(self, 'network_layer_name', None)
|
||||
@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown, ex_bias = module.calc_updown(self.weight)
|
||||
if getattr(self, 'fp16_weight', None) is None:
|
||||
weight = self.weight
|
||||
bias = self.bias
|
||||
else:
|
||||
weight = self.fp16_weight.clone().to(self.weight.device)
|
||||
bias = getattr(self, 'fp16_bias', None)
|
||||
if bias is not None:
|
||||
bias = bias.clone().to(self.bias.device)
|
||||
updown, ex_bias = module.calc_updown(weight)
|
||||
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
if len(weight.shape) == 4 and weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
|
||||
self.weight += updown
|
||||
self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
|
||||
if ex_bias is not None and hasattr(self, 'bias'):
|
||||
if self.bias is None:
|
||||
self.bias = torch.nn.Parameter(ex_bias)
|
||||
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
||||
else:
|
||||
self.bias += ex_bias
|
||||
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
self.network_current_names = wanted_names
|
||||
|
||||
|
||||
def network_forward(module, input, original_forward):
|
||||
def network_forward(org_module, input, original_forward):
|
||||
"""
|
||||
Old way of applying Lora by executing operations during layer's forward.
|
||||
Stacking many loras this way results in big performance degradation.
|
||||
"""
|
||||
|
||||
if len(loaded_networks) == 0:
|
||||
return original_forward(module, input)
|
||||
return original_forward(org_module, input)
|
||||
|
||||
input = devices.cond_cast_unet(input)
|
||||
|
||||
network_restore_weights_from_backup(module)
|
||||
network_reset_cached_weight(module)
|
||||
network_restore_weights_from_backup(org_module)
|
||||
network_reset_cached_weight(org_module)
|
||||
|
||||
y = original_forward(module, input)
|
||||
y = original_forward(org_module, input)
|
||||
|
||||
network_layer_name = getattr(module, 'network_layer_name', None)
|
||||
network_layer_name = getattr(org_module, 'network_layer_name', None)
|
||||
for lora in loaded_networks:
|
||||
module = lora.modules.get(network_layer_name, None)
|
||||
if module is None:
|
||||
|
@ -1,7 +1,8 @@
|
||||
import os
|
||||
from modules import paths
|
||||
from modules.paths_internal import normalized_filepath
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
|
||||
parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
|
||||
parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
|
||||
parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
|
||||
|
@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
|
||||
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
|
||||
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
|
||||
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
|
||||
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
|
||||
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
|
||||
}))
|
||||
|
||||
|
||||
|
@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
self.slider_preferred_weight = None
|
||||
self.edit_notes = None
|
||||
|
||||
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
|
||||
def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
|
||||
user_metadata = self.get_user_metadata(name)
|
||||
user_metadata["description"] = desc
|
||||
user_metadata["sd version"] = sd_version
|
||||
user_metadata["activation text"] = activation_text
|
||||
user_metadata["preferred weight"] = preferred_weight
|
||||
user_metadata["negative text"] = negative_text
|
||||
user_metadata["notes"] = notes
|
||||
|
||||
self.write_user_metadata(name, user_metadata)
|
||||
@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
|
||||
user_metadata.get('activation text', ''),
|
||||
float(user_metadata.get('preferred weight', 0.0)),
|
||||
user_metadata.get('negative text', ''),
|
||||
gr.update(visible=True if tags else False),
|
||||
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
|
||||
]
|
||||
@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
self.taginfo = gr.HighlightedText(label="Training dataset tags")
|
||||
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
|
||||
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
|
||||
|
||||
self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
|
||||
with gr.Row() as row_random_prompt:
|
||||
with gr.Column(scale=8):
|
||||
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
|
||||
@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
self.taginfo,
|
||||
self.edit_activation_text,
|
||||
self.slider_preferred_weight,
|
||||
self.edit_negative_text,
|
||||
row_random_prompt,
|
||||
random_prompt,
|
||||
]
|
||||
@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
||||
self.select_sd_version,
|
||||
self.edit_activation_text,
|
||||
self.slider_preferred_weight,
|
||||
self.edit_negative_text,
|
||||
self.edit_notes,
|
||||
]
|
||||
|
||||
|
||||
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
|
||||
|
@ -24,13 +24,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
|
||||
alias = lora_on_disk.get_alias()
|
||||
|
||||
search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
|
||||
if lora_on_disk.hash:
|
||||
search_terms.append(lora_on_disk.hash)
|
||||
item = {
|
||||
"name": name,
|
||||
"filename": lora_on_disk.filename,
|
||||
"shorthash": lora_on_disk.shorthash,
|
||||
"preview": self.find_preview(path),
|
||||
"description": self.find_description(path),
|
||||
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||
"search_terms": search_terms,
|
||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||
"metadata": lora_on_disk.metadata,
|
||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||
@ -45,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
||||
if activation_text:
|
||||
item["prompt"] += " + " + quote_js(" " + activation_text)
|
||||
|
||||
negative_prompt = item["user_metadata"].get("negative text")
|
||||
item["negative_prompt"] = quote_js("")
|
||||
if negative_prompt:
|
||||
item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
|
||||
|
||||
sd_version = item["user_metadata"].get("sd version")
|
||||
if sd_version in network.SdVersion.__members__:
|
||||
item["sd_version"] = sd_version
|
||||
|
@ -1,16 +1,9 @@
|
||||
import sys
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader, script_callbacks, errors
|
||||
from scunet_model_arch import SCUNet
|
||||
|
||||
from modules.modelloader import load_file_from_url
|
||||
from modules.shared import opts
|
||||
from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
|
||||
|
||||
|
||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
scalers.append(scaler_data2)
|
||||
self.scalers = scalers
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def tiled_inference(img, model):
|
||||
# test the image tile by tile
|
||||
h, w = img.shape[2:]
|
||||
tile = opts.SCUNET_tile
|
||||
tile_overlap = opts.SCUNET_tile_overlap
|
||||
if tile == 0:
|
||||
return model(img)
|
||||
|
||||
device = devices.get_device_for('scunet')
|
||||
assert tile % 8 == 0, "tile size should be a multiple of window_size"
|
||||
sf = 1
|
||||
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
|
||||
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
|
||||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
|
||||
for w_idx in w_idx_list:
|
||||
|
||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
E[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch)
|
||||
W[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch_mask)
|
||||
pbar.update(1)
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
|
||||
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
try:
|
||||
model = self.load_model(selected_file)
|
||||
except Exception as e:
|
||||
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
|
||||
device = devices.get_device_for('scunet')
|
||||
tile = opts.SCUNET_tile
|
||||
h, w = img.height, img.width
|
||||
np_img = np.array(img)
|
||||
np_img = np_img[:, :, ::-1] # RGB to BGR
|
||||
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
|
||||
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
|
||||
|
||||
if tile > h or tile > w:
|
||||
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
|
||||
_img[:, :, :h, :w] = torch_img # pad image
|
||||
torch_img = _img
|
||||
|
||||
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
|
||||
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
||||
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
||||
del torch_img, torch_output
|
||||
img = upscaler_utils.upscale_2(
|
||||
img,
|
||||
model,
|
||||
tile_size=shared.opts.SCUNET_tile,
|
||||
tile_overlap=shared.opts.SCUNET_tile_overlap,
|
||||
scale=1, # ScuNET is a denoising model, not an upscaler
|
||||
desc='ScuNET',
|
||||
)
|
||||
devices.torch_gc()
|
||||
|
||||
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
||||
output = output[:, :, ::-1] # BGR to RGB
|
||||
return PIL.Image.fromarray((output * 255).astype(np.uint8))
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
device = devices.get_device_for('scunet')
|
||||
if path.startswith("http"):
|
||||
# TODO: this doesn't use `path` at all?
|
||||
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||
else:
|
||||
filename = path
|
||||
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for _, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
model = model.to(device)
|
||||
|
||||
return model
|
||||
return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
from modules import shared
|
||||
|
||||
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
||||
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
||||
|
@ -1,268 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
from timm.models.layers import trunc_normal_, DropPath
|
||||
|
||||
|
||||
class WMSA(nn.Module):
|
||||
""" Self-attention module in Swin Transformer
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
||||
super(WMSA, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.head_dim = head_dim
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.n_heads = input_dim // head_dim
|
||||
self.window_size = window_size
|
||||
self.type = type
|
||||
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
||||
|
||||
self.relative_position_params = nn.Parameter(
|
||||
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
||||
|
||||
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
||||
|
||||
trunc_normal_(self.relative_position_params, std=.02)
|
||||
self.relative_position_params = torch.nn.Parameter(
|
||||
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
||||
2).transpose(
|
||||
0, 1))
|
||||
|
||||
def generate_mask(self, h, w, p, shift):
|
||||
""" generating the mask of SW-MSA
|
||||
Args:
|
||||
shift: shift parameters in CyclicShift.
|
||||
Returns:
|
||||
attn_mask: should be (1 1 w p p),
|
||||
"""
|
||||
# supporting square.
|
||||
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
||||
if self.type == 'W':
|
||||
return attn_mask
|
||||
|
||||
s = p - shift
|
||||
attn_mask[-1, :, :s, :, s:, :] = True
|
||||
attn_mask[-1, :, s:, :, :s, :] = True
|
||||
attn_mask[:, -1, :, :s, :, s:] = True
|
||||
attn_mask[:, -1, :, s:, :, :s] = True
|
||||
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
||||
return attn_mask
|
||||
|
||||
def forward(self, x):
|
||||
""" Forward pass of Window Multi-head Self-attention module.
|
||||
Args:
|
||||
x: input tensor with shape of [b h w c];
|
||||
attn_mask: attention mask, fill -inf where the value is True;
|
||||
Returns:
|
||||
output: tensor shape [b h w c]
|
||||
"""
|
||||
if self.type != 'W':
|
||||
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||
|
||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||
h_windows = x.size(1)
|
||||
w_windows = x.size(2)
|
||||
# square validation
|
||||
# assert h_windows == w_windows
|
||||
|
||||
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
||||
qkv = self.embedding_layer(x)
|
||||
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
||||
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
||||
# Adding learnable relative embedding
|
||||
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
||||
# Using Attn Mask to distinguish different subwindows.
|
||||
if self.type != 'W':
|
||||
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
||||
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
||||
|
||||
probs = nn.functional.softmax(sim, dim=-1)
|
||||
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
||||
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
||||
output = self.linear(output)
|
||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||
|
||||
if self.type != 'W':
|
||||
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
|
||||
|
||||
return output
|
||||
|
||||
def relative_embedding(self):
|
||||
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
||||
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
||||
# negative is allowed
|
||||
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||
""" SwinTransformer Block
|
||||
"""
|
||||
super(Block, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
assert type in ['W', 'SW']
|
||||
self.type = type
|
||||
if input_resolution <= window_size:
|
||||
self.type = 'W'
|
||||
|
||||
self.ln1 = nn.LayerNorm(input_dim)
|
||||
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.ln2 = nn.LayerNorm(input_dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, 4 * input_dim),
|
||||
nn.GELU(),
|
||||
nn.Linear(4 * input_dim, output_dim),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.msa(self.ln1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ConvTransBlock(nn.Module):
|
||||
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||
""" SwinTransformer and Conv Block
|
||||
"""
|
||||
super(ConvTransBlock, self).__init__()
|
||||
self.conv_dim = conv_dim
|
||||
self.trans_dim = trans_dim
|
||||
self.head_dim = head_dim
|
||||
self.window_size = window_size
|
||||
self.drop_path = drop_path
|
||||
self.type = type
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
assert self.type in ['W', 'SW']
|
||||
if self.input_resolution <= self.window_size:
|
||||
self.type = 'W'
|
||||
|
||||
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
||||
self.type, self.input_resolution)
|
||||
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
||||
conv_x = self.conv_block(conv_x) + conv_x
|
||||
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
||||
trans_x = self.trans_block(trans_x)
|
||||
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
||||
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
||||
x = x + res
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SCUNet(nn.Module):
|
||||
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
|
||||
super(SCUNet, self).__init__()
|
||||
if config is None:
|
||||
config = [2, 2, 2, 2, 2, 2, 2]
|
||||
self.config = config
|
||||
self.dim = dim
|
||||
self.head_dim = 32
|
||||
self.window_size = 8
|
||||
|
||||
# drop path rate for each layer
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
|
||||
|
||||
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
|
||||
|
||||
begin = 0
|
||||
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution)
|
||||
for i in range(config[0])] + \
|
||||
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[0]
|
||||
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||
for i in range(config[1])] + \
|
||||
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[1]
|
||||
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||
for i in range(config[2])] + \
|
||||
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
|
||||
|
||||
begin += config[2]
|
||||
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 8)
|
||||
for i in range(config[3])]
|
||||
|
||||
begin += config[3]
|
||||
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 4)
|
||||
for i in range(config[4])]
|
||||
|
||||
begin += config[4]
|
||||
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution // 2)
|
||||
for i in range(config[5])]
|
||||
|
||||
begin += config[5]
|
||||
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
|
||||
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
|
||||
'W' if not i % 2 else 'SW', input_resolution)
|
||||
for i in range(config[6])]
|
||||
|
||||
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
|
||||
|
||||
self.m_head = nn.Sequential(*self.m_head)
|
||||
self.m_down1 = nn.Sequential(*self.m_down1)
|
||||
self.m_down2 = nn.Sequential(*self.m_down2)
|
||||
self.m_down3 = nn.Sequential(*self.m_down3)
|
||||
self.m_body = nn.Sequential(*self.m_body)
|
||||
self.m_up3 = nn.Sequential(*self.m_up3)
|
||||
self.m_up2 = nn.Sequential(*self.m_up2)
|
||||
self.m_up1 = nn.Sequential(*self.m_up1)
|
||||
self.m_tail = nn.Sequential(*self.m_tail)
|
||||
# self.apply(self._init_weights)
|
||||
|
||||
def forward(self, x0):
|
||||
|
||||
h, w = x0.size()[-2:]
|
||||
paddingBottom = int(np.ceil(h / 64) * 64 - h)
|
||||
paddingRight = int(np.ceil(w / 64) * 64 - w)
|
||||
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
|
||||
|
||||
x1 = self.m_head(x0)
|
||||
x2 = self.m_down1(x1)
|
||||
x3 = self.m_down2(x2)
|
||||
x4 = self.m_down3(x3)
|
||||
x = self.m_body(x4)
|
||||
x = self.m_up3(x + x4)
|
||||
x = self.m_up2(x + x3)
|
||||
x = self.m_up1(x + x2)
|
||||
x = self.m_tail(x + x1)
|
||||
|
||||
x = x[..., :h, :w]
|
||||
|
||||
return x
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
@ -1,20 +1,15 @@
|
||||
import logging
|
||||
import sys
|
||||
import platform
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import modelloader, devices, script_callbacks, shared
|
||||
from modules.shared import opts, state
|
||||
from swinir_model_arch import SwinIR
|
||||
from swinir_model_arch_v2 import Swin2SR
|
||||
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||
|
||||
device_swinir = devices.get_device_for('swinir')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
|
||||
scalers.append(model_data)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img, model_file):
|
||||
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
|
||||
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
|
||||
current_config = (model_file, opts.SWIN_tile)
|
||||
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
|
||||
current_config = (model_file, shared.opts.SWIN_tile)
|
||||
|
||||
if use_compile and self._cached_model_config == current_config:
|
||||
if self._cached_model_config == current_config:
|
||||
model = self._cached_model
|
||||
else:
|
||||
self._cached_model = None
|
||||
try:
|
||||
model = self.load_model(model_file)
|
||||
except Exception as e:
|
||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
if use_compile:
|
||||
model = torch.compile(model)
|
||||
self._cached_model = model
|
||||
self._cached_model_config = current_config
|
||||
img = upscale(img, model)
|
||||
|
||||
img = upscaler_utils.upscale_2(
|
||||
img,
|
||||
model,
|
||||
tile_size=shared.opts.SWIN_tile,
|
||||
tile_overlap=shared.opts.SWIN_tile_overlap,
|
||||
scale=model.scale,
|
||||
desc="SwinIR",
|
||||
)
|
||||
devices.torch_gc()
|
||||
return img
|
||||
|
||||
@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
if filename.endswith(".v2.pth"):
|
||||
model = Swin2SR(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6],
|
||||
embed_dim=180,
|
||||
num_heads=[6, 6, 6, 6, 6, 6],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="1conv",
|
||||
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=self._get_device(),
|
||||
prefer_half=(devices.dtype == torch.float16),
|
||||
expected_architecture="SwinIR",
|
||||
)
|
||||
params = None
|
||||
else:
|
||||
model = SwinIR(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||
embed_dim=240,
|
||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="3conv",
|
||||
)
|
||||
params = "params_ema"
|
||||
if getattr(shared.opts, 'SWIN_torch_compile', False):
|
||||
try:
|
||||
model_descriptor.model.compile()
|
||||
except Exception:
|
||||
logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
|
||||
return model_descriptor
|
||||
|
||||
pretrained_model = torch.load(filename)
|
||||
if params is not None:
|
||||
model.load_state_dict(pretrained_model[params], strict=True)
|
||||
else:
|
||||
model.load_state_dict(pretrained_model, strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def upscale(
|
||||
img,
|
||||
model,
|
||||
tile=None,
|
||||
tile_overlap=None,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
):
|
||||
tile = tile or opts.SWIN_tile
|
||||
tile_overlap = tile_overlap or opts.SWIN_tile_overlap
|
||||
|
||||
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
|
||||
with torch.no_grad(), devices.autocast():
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
||||
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
||||
output = output[..., : h_old * scale, : w_old * scale]
|
||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(
|
||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
||||
) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||
return Image.fromarray(output, "RGB")
|
||||
|
||||
|
||||
def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
# test the image tile by tile
|
||||
b, c, h, w = img.size()
|
||||
tile = min(tile, h, w)
|
||||
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
||||
sf = scale
|
||||
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
|
||||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
for w_idx in w_idx_list:
|
||||
if state.interrupted or state.skipped:
|
||||
break
|
||||
|
||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
E[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch)
|
||||
W[
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch_mask)
|
||||
pbar.update(1)
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
def _get_device(self):
|
||||
return devices.get_device_for('swinir')
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
@ -185,7 +89,6 @@ def on_ui_settings():
|
||||
|
||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
|
||||
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
|
||||
|
||||
|
||||
|
@ -1,867 +0,0 @@
|
||||
# -----------------------------------------------------------------------------------
|
||||
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
||||
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
||||
# -----------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
"""
|
||||
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
||||
|
||||
def flops(self, N):
|
||||
# calculate flops for 1 window with token length of N
|
||||
flops = 0
|
||||
# qkv = self.qkv(x)
|
||||
flops += N * self.dim * 3 * self.dim
|
||||
# attn = (q @ k.transpose(-2, -1))
|
||||
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
||||
# x = (attn @ v)
|
||||
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
||||
# x = self.proj(x)
|
||||
flops += N * self.dim * self.dim
|
||||
return flops
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
if min(self.input_resolution) <= self.window_size:
|
||||
# if window size is larger than input resolution, we don't partition windows
|
||||
self.shift_size = 0
|
||||
self.window_size = min(self.input_resolution)
|
||||
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
if self.shift_size > 0:
|
||||
attn_mask = self.calculate_mask(self.input_resolution)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
self.register_buffer("attn_mask", attn_mask)
|
||||
|
||||
def calculate_mask(self, x_size):
|
||||
# calculate attention mask for SW-MSA
|
||||
H, W = x_size
|
||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
return attn_mask
|
||||
|
||||
def forward(self, x, x_size):
|
||||
H, W = x_size
|
||||
B, L, C = x.shape
|
||||
# assert L == H * W, "input feature has wrong size"
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
shifted_x = x
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
||||
|
||||
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
||||
if self.input_resolution == x_size:
|
||||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
||||
else:
|
||||
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
# FFN
|
||||
x = shortcut + self.drop_path(x)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
||||
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
H, W = self.input_resolution
|
||||
# norm1
|
||||
flops += self.dim * H * W
|
||||
# W-MSA/SW-MSA
|
||||
nW = H * W / self.window_size / self.window_size
|
||||
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
||||
# mlp
|
||||
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
||||
# norm2
|
||||
flops += self.dim * H * W
|
||||
return flops
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
r""" Patch Merging Layer.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int]): Resolution of input feature.
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
||||
self.norm = norm_layer(4 * dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: B, H*W, C
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
||||
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
||||
|
||||
def flops(self):
|
||||
H, W = self.input_resolution
|
||||
flops = H * W * self.dim
|
||||
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||
return flops
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
||||
num_heads=num_heads, window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, x_size):
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x, x_size)
|
||||
else:
|
||||
x = blk(x, x_size)
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
for blk in self.blocks:
|
||||
flops += blk.flops()
|
||||
if self.downsample is not None:
|
||||
flops += self.downsample.flops()
|
||||
return flops
|
||||
|
||||
|
||||
class RSTB(nn.Module):
|
||||
"""Residual Swin Transformer Block (RSTB).
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
resi_connection: The convolutional block before residual connection.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
||||
img_size=224, patch_size=4, resi_connection='1conv'):
|
||||
super(RSTB, self).__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
self.residual_group = BasicLayer(dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
depth=depth,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop, attn_drop=attn_drop,
|
||||
drop_path=drop_path,
|
||||
norm_layer=norm_layer,
|
||||
downsample=downsample,
|
||||
use_checkpoint=use_checkpoint)
|
||||
|
||||
if resi_connection == '1conv':
|
||||
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
||||
elif resi_connection == '3conv':
|
||||
# to save parameters and memory
|
||||
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
||||
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
||||
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
||||
norm_layer=None)
|
||||
|
||||
self.patch_unembed = PatchUnEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
||||
norm_layer=None)
|
||||
|
||||
def forward(self, x, x_size):
|
||||
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
flops += self.residual_group.flops()
|
||||
H, W = self.input_resolution
|
||||
flops += H * W * self.dim * self.dim * 9
|
||||
flops += self.patch_embed.flops()
|
||||
flops += self.patch_unembed.flops()
|
||||
|
||||
return flops
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
r""" Image to Patch Embedding
|
||||
|
||||
Args:
|
||||
img_size (int): Image size. Default: 224.
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.patches_resolution = patches_resolution
|
||||
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
H, W = self.img_size
|
||||
if self.norm is not None:
|
||||
flops += H * W * self.embed_dim
|
||||
return flops
|
||||
|
||||
|
||||
class PatchUnEmbed(nn.Module):
|
||||
r""" Image to Patch Unembedding
|
||||
|
||||
Args:
|
||||
img_size (int): Image size. Default: 224.
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input image channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.patches_resolution = patches_resolution
|
||||
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def forward(self, x, x_size):
|
||||
B, HW, C = x.shape
|
||||
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
||||
return x
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
return flops
|
||||
|
||||
|
||||
class Upsample(nn.Sequential):
|
||||
"""Upsample module.
|
||||
|
||||
Args:
|
||||
scale (int): Scale factor. Supported scales: 2^n and 3.
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
"""
|
||||
|
||||
def __init__(self, scale, num_feat):
|
||||
m = []
|
||||
if (scale & (scale - 1)) == 0: # scale = 2^n
|
||||
for _ in range(int(math.log(scale, 2))):
|
||||
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
||||
m.append(nn.PixelShuffle(2))
|
||||
elif scale == 3:
|
||||
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
||||
m.append(nn.PixelShuffle(3))
|
||||
else:
|
||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||
super(Upsample, self).__init__(*m)
|
||||
|
||||
|
||||
class UpsampleOneStep(nn.Sequential):
|
||||
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
||||
Used in lightweight SR to save parameters.
|
||||
|
||||
Args:
|
||||
scale (int): Scale factor. Supported scales: 2^n and 3.
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
||||
self.num_feat = num_feat
|
||||
self.input_resolution = input_resolution
|
||||
m = []
|
||||
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
||||
m.append(nn.PixelShuffle(scale))
|
||||
super(UpsampleOneStep, self).__init__(*m)
|
||||
|
||||
def flops(self):
|
||||
H, W = self.input_resolution
|
||||
flops = H * W * self.num_feat * 3 * 9
|
||||
return flops
|
||||
|
||||
|
||||
class SwinIR(nn.Module):
|
||||
r""" SwinIR
|
||||
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple(int)): Input image size. Default 64
|
||||
patch_size (int | tuple(int)): Patch size. Default: 1
|
||||
in_chans (int): Number of input image channels. Default: 3
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
||||
window_size (int): Window size. Default: 7
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
||||
drop_rate (float): Dropout rate. Default: 0
|
||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
||||
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
||||
img_range: Image range. 1. or 255.
|
||||
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
||||
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||
**kwargs):
|
||||
super(SwinIR, self).__init__()
|
||||
num_in_ch = in_chans
|
||||
num_out_ch = in_chans
|
||||
num_feat = 64
|
||||
self.img_range = img_range
|
||||
if in_chans == 3:
|
||||
rgb_mean = (0.4488, 0.4371, 0.4040)
|
||||
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
||||
else:
|
||||
self.mean = torch.zeros(1, 1, 1, 1)
|
||||
self.upscale = upscale
|
||||
self.upsampler = upsampler
|
||||
self.window_size = window_size
|
||||
|
||||
#####################################################################################################
|
||||
################################### 1, shallow feature extraction ###################################
|
||||
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
||||
|
||||
#####################################################################################################
|
||||
################################### 2, deep feature extraction ######################################
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.ape = ape
|
||||
self.patch_norm = patch_norm
|
||||
self.num_features = embed_dim
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
|
||||
# merge non-overlapping patches into image
|
||||
self.patch_unembed = PatchUnEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
# absolute position embedding
|
||||
if self.ape:
|
||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
# build Residual Swin Transformer blocks (RSTB)
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = RSTB(dim=embed_dim,
|
||||
input_resolution=(patches_resolution[0],
|
||||
patches_resolution[1]),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||
norm_layer=norm_layer,
|
||||
downsample=None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
resi_connection=resi_connection
|
||||
|
||||
)
|
||||
self.layers.append(layer)
|
||||
self.norm = norm_layer(self.num_features)
|
||||
|
||||
# build the last conv layer in deep feature extraction
|
||||
if resi_connection == '1conv':
|
||||
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
||||
elif resi_connection == '3conv':
|
||||
# to save parameters and memory
|
||||
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
||||
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
||||
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
||||
|
||||
#####################################################################################################
|
||||
################################ 3, high quality image reconstruction ################################
|
||||
if self.upsampler == 'pixelshuffle':
|
||||
# for classical SR
|
||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True))
|
||||
self.upsample = Upsample(upscale, num_feat)
|
||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||
elif self.upsampler == 'pixelshuffledirect':
|
||||
# for lightweight SR (to save parameters)
|
||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
||||
(patches_resolution[0], patches_resolution[1]))
|
||||
elif self.upsampler == 'nearest+conv':
|
||||
# for real-world SR (less artifacts)
|
||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True))
|
||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
if self.upscale == 4:
|
||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
else:
|
||||
# for image denoising and JPEG compression artifact reduction
|
||||
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'absolute_pos_embed'}
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'relative_position_bias_table'}
|
||||
|
||||
def check_image_size(self, x):
|
||||
_, _, h, w = x.size()
|
||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
||||
return x
|
||||
|
||||
def forward_features(self, x):
|
||||
x_size = (x.shape[2], x.shape[3])
|
||||
x = self.patch_embed(x)
|
||||
if self.ape:
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, x_size)
|
||||
|
||||
x = self.norm(x) # B L C
|
||||
x = self.patch_unembed(x, x_size)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.shape[2:]
|
||||
x = self.check_image_size(x)
|
||||
|
||||
self.mean = self.mean.type_as(x)
|
||||
x = (x - self.mean) * self.img_range
|
||||
|
||||
if self.upsampler == 'pixelshuffle':
|
||||
# for classical SR
|
||||
x = self.conv_first(x)
|
||||
x = self.conv_after_body(self.forward_features(x)) + x
|
||||
x = self.conv_before_upsample(x)
|
||||
x = self.conv_last(self.upsample(x))
|
||||
elif self.upsampler == 'pixelshuffledirect':
|
||||
# for lightweight SR
|
||||
x = self.conv_first(x)
|
||||
x = self.conv_after_body(self.forward_features(x)) + x
|
||||
x = self.upsample(x)
|
||||
elif self.upsampler == 'nearest+conv':
|
||||
# for real-world SR
|
||||
x = self.conv_first(x)
|
||||
x = self.conv_after_body(self.forward_features(x)) + x
|
||||
x = self.conv_before_upsample(x)
|
||||
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
||||
if self.upscale == 4:
|
||||
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
||||
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
||||
else:
|
||||
# for image denoising and JPEG compression artifact reduction
|
||||
x_first = self.conv_first(x)
|
||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
||||
x = x + self.conv_last(res)
|
||||
|
||||
x = x / self.img_range + self.mean
|
||||
|
||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
||||
|
||||
def flops(self):
|
||||
flops = 0
|
||||
H, W = self.patches_resolution
|
||||
flops += H * W * 3 * self.embed_dim * 9
|
||||
flops += self.patch_embed.flops()
|
||||
for layer in self.layers:
|
||||
flops += layer.flops()
|
||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||
flops += self.upsample.flops()
|
||||
return flops
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
upscale = 4
|
||||
window_size = 8
|
||||
height = (1024 // upscale // window_size + 1) * window_size
|
||||
width = (720 // upscale // window_size + 1) * window_size
|
||||
model = SwinIR(upscale=2, img_size=(height, width),
|
||||
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
||||
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
||||
print(model)
|
||||
print(height, width, model.flops() / 1e9)
|
||||
|
||||
x = torch.randn((1, 3, height, width))
|
||||
x = model(x)
|
||||
print(x.shape)
|
File diff suppressed because it is too large
Load Diff
@ -218,6 +218,8 @@ onUiLoaded(async() => {
|
||||
canvas_hotkey_fullscreen: "KeyS",
|
||||
canvas_hotkey_move: "KeyF",
|
||||
canvas_hotkey_overlap: "KeyO",
|
||||
canvas_hotkey_shrink_brush: "KeyQ",
|
||||
canvas_hotkey_grow_brush: "KeyW",
|
||||
canvas_disabled_functions: [],
|
||||
canvas_show_tooltip: true,
|
||||
canvas_auto_expand: true,
|
||||
@ -227,6 +229,8 @@ onUiLoaded(async() => {
|
||||
const functionMap = {
|
||||
"Zoom": "canvas_hotkey_zoom",
|
||||
"Adjust brush size": "canvas_hotkey_adjust",
|
||||
"Hotkey shrink brush": "canvas_hotkey_shrink_brush",
|
||||
"Hotkey enlarge brush": "canvas_hotkey_grow_brush",
|
||||
"Moving canvas": "canvas_hotkey_move",
|
||||
"Fullscreen": "canvas_hotkey_fullscreen",
|
||||
"Reset Zoom": "canvas_hotkey_reset",
|
||||
@ -288,7 +292,7 @@ onUiLoaded(async() => {
|
||||
|
||||
// Create tooltip
|
||||
function createTooltip() {
|
||||
const toolTipElemnt =
|
||||
const toolTipElement =
|
||||
targetElement.querySelector(".image-container");
|
||||
const tooltip = document.createElement("div");
|
||||
tooltip.className = "canvas-tooltip";
|
||||
@ -351,7 +355,7 @@ onUiLoaded(async() => {
|
||||
tooltip.appendChild(tooltipContent);
|
||||
|
||||
// Add a hint element to the target element
|
||||
toolTipElemnt.appendChild(tooltip);
|
||||
toolTipElement.appendChild(tooltip);
|
||||
}
|
||||
|
||||
//Show tool tip if setting enable
|
||||
@ -686,7 +690,9 @@ onUiLoaded(async() => {
|
||||
const hotkeyActions = {
|
||||
[hotkeysConfig.canvas_hotkey_reset]: resetZoom,
|
||||
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
|
||||
[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
|
||||
[hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
|
||||
[hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
|
||||
[hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
|
||||
};
|
||||
|
||||
const action = hotkeyActions[event.code];
|
||||
|
@ -4,12 +4,14 @@ from modules import shared
|
||||
shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
|
||||
"canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
||||
"canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
|
||||
"canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
|
||||
"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
|
||||
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
||||
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"),
|
||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"),
|
||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
|
||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||
}))
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
|
||||
import gradio as gr
|
||||
from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
|
||||
from modules import scripts, shared, ui_components, ui_settings, infotext_utils, errors
|
||||
from modules.ui_components import FormColumn
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script):
|
||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
|
||||
|
||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||
mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
|
||||
|
||||
with gr.Blocks() as interface:
|
||||
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
|
||||
@ -42,7 +42,11 @@ class ExtraOptionsSection(scripts.Script):
|
||||
setting_name = extra_options[index]
|
||||
|
||||
with FormColumn():
|
||||
try:
|
||||
comp = ui_settings.create_setting_component(setting_name)
|
||||
except KeyError:
|
||||
errors.report(f"Can't add extra options for {setting_name} in ui")
|
||||
continue
|
||||
|
||||
self.comps.append(comp)
|
||||
self.setting_names.append(setting_name)
|
||||
|
761
extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
Normal file
761
extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
Normal file
@ -0,0 +1,761 @@
|
||||
import numpy as np
|
||||
import gradio as gr
|
||||
import math
|
||||
from modules.ui_components import InputAccordion
|
||||
import modules.scripts as scripts
|
||||
|
||||
|
||||
class SoftInpaintingSettings:
|
||||
def __init__(self,
|
||||
mask_blend_power,
|
||||
mask_blend_scale,
|
||||
inpaint_detail_preservation,
|
||||
composite_mask_influence,
|
||||
composite_difference_threshold,
|
||||
composite_difference_contrast):
|
||||
self.mask_blend_power = mask_blend_power
|
||||
self.mask_blend_scale = mask_blend_scale
|
||||
self.inpaint_detail_preservation = inpaint_detail_preservation
|
||||
self.composite_mask_influence = composite_mask_influence
|
||||
self.composite_difference_threshold = composite_difference_threshold
|
||||
self.composite_difference_contrast = composite_difference_contrast
|
||||
|
||||
def add_generation_params(self, dest):
|
||||
dest[enabled_gen_param_label] = True
|
||||
dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
|
||||
dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
|
||||
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
|
||||
dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
|
||||
dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
|
||||
dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
|
||||
|
||||
|
||||
# ------------------- Methods -------------------
|
||||
|
||||
def processing_uses_inpainting(p):
|
||||
# TODO: Figure out a better way to determine if inpainting is being used by p
|
||||
if getattr(p, "image_mask", None) is not None:
|
||||
return True
|
||||
|
||||
if getattr(p, "mask", None) is not None:
|
||||
return True
|
||||
|
||||
if getattr(p, "nmask", None) is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def latent_blend(settings, a, b, t):
|
||||
"""
|
||||
Interpolates two latent image representations according to the parameter t,
|
||||
where the interpolated vectors' magnitudes are also interpolated separately.
|
||||
The "detail_preservation" factor biases the magnitude interpolation towards
|
||||
the larger of the two magnitudes.
|
||||
"""
|
||||
import torch
|
||||
|
||||
# NOTE: We use inplace operations wherever possible.
|
||||
|
||||
if len(t.shape) == 3:
|
||||
# [4][w][h] to [1][4][w][h]
|
||||
t2 = t.unsqueeze(0)
|
||||
# [4][w][h] to [1][1][w][h] - the [4] seem redundant.
|
||||
t3 = t[0].unsqueeze(0).unsqueeze(0)
|
||||
else:
|
||||
t2 = t
|
||||
t3 = t[:, 0][:, None]
|
||||
|
||||
one_minus_t2 = 1 - t2
|
||||
one_minus_t3 = 1 - t3
|
||||
|
||||
# Linearly interpolate the image vectors.
|
||||
a_scaled = a * one_minus_t2
|
||||
b_scaled = b * t2
|
||||
image_interp = a_scaled
|
||||
image_interp.add_(b_scaled)
|
||||
result_type = image_interp.dtype
|
||||
del a_scaled, b_scaled, t2, one_minus_t2
|
||||
|
||||
# Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
|
||||
# 64-bit operations are used here to allow large exponents.
|
||||
current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
|
||||
|
||||
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
|
||||
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
|
||||
settings.inpaint_detail_preservation) * one_minus_t3
|
||||
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
|
||||
settings.inpaint_detail_preservation) * t3
|
||||
desired_magnitude = a_magnitude
|
||||
desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
|
||||
del a_magnitude, b_magnitude, t3, one_minus_t3
|
||||
|
||||
# Change the linearly interpolated image vectors' magnitudes to the value we want.
|
||||
# This is the last 64-bit operation.
|
||||
image_interp_scaling_factor = desired_magnitude
|
||||
image_interp_scaling_factor.div_(current_magnitude)
|
||||
image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
|
||||
image_interp_scaled = image_interp
|
||||
image_interp_scaled.mul_(image_interp_scaling_factor)
|
||||
del current_magnitude
|
||||
del desired_magnitude
|
||||
del image_interp
|
||||
del image_interp_scaling_factor
|
||||
del result_type
|
||||
|
||||
return image_interp_scaled
|
||||
|
||||
|
||||
def get_modified_nmask(settings, nmask, sigma):
|
||||
"""
|
||||
Converts a negative mask representing the transparency of the original latent vectors being overlaid
|
||||
to a mask that is scaled according to the denoising strength for this step.
|
||||
|
||||
Where:
|
||||
0 = fully opaque, infinite density, fully masked
|
||||
1 = fully transparent, zero density, fully unmasked
|
||||
|
||||
We bring this transparency to a power, as this allows one to simulate N number of blending operations
|
||||
where N can be any positive real value. Using this one can control the balance of influence between
|
||||
the denoiser and the original latents according to the sigma value.
|
||||
|
||||
NOTE: "mask" is not used
|
||||
"""
|
||||
import torch
|
||||
return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
|
||||
|
||||
|
||||
def apply_adaptive_masks(
|
||||
settings: SoftInpaintingSettings,
|
||||
nmask,
|
||||
latent_orig,
|
||||
latent_processed,
|
||||
overlay_images,
|
||||
width, height,
|
||||
paste_to):
|
||||
import torch
|
||||
import modules.processing as proc
|
||||
import modules.images as images
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
|
||||
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
|
||||
if len(nmask.shape) == 3:
|
||||
latent_mask = nmask[0].float()
|
||||
else:
|
||||
latent_mask = nmask[:, 0].float()
|
||||
# convert the original mask into a form we use to scale distances for thresholding
|
||||
mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
|
||||
mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
|
||||
+ mask_scalar * settings.composite_mask_influence)
|
||||
mask_scalar = mask_scalar / (1.00001 - mask_scalar)
|
||||
mask_scalar = mask_scalar.cpu().numpy()
|
||||
|
||||
latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
|
||||
|
||||
kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
|
||||
|
||||
masks_for_overlay = []
|
||||
|
||||
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
|
||||
converted_mask = distance_map.float().cpu().numpy()
|
||||
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
|
||||
percentile_min=0.9, percentile_max=1, min_width=1)
|
||||
converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
|
||||
percentile_min=0.25, percentile_max=0.75, min_width=1)
|
||||
|
||||
# The distance at which opacity of original decreases to 50%
|
||||
if len(mask_scalar.shape) == 3:
|
||||
if mask_scalar.shape[0] > i:
|
||||
half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i]
|
||||
else:
|
||||
half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0]
|
||||
else:
|
||||
half_weighted_distance = settings.composite_difference_threshold * mask_scalar
|
||||
|
||||
converted_mask = converted_mask / half_weighted_distance
|
||||
|
||||
converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
|
||||
converted_mask = smootherstep(converted_mask)
|
||||
converted_mask = 1 - converted_mask
|
||||
converted_mask = 255. * converted_mask
|
||||
converted_mask = converted_mask.astype(np.uint8)
|
||||
converted_mask = Image.fromarray(converted_mask)
|
||||
converted_mask = images.resize_image(2, converted_mask, width, height)
|
||||
converted_mask = proc.create_binary_mask(converted_mask, round=False)
|
||||
|
||||
# Remove aliasing artifacts using a gaussian blur.
|
||||
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
|
||||
|
||||
# Expand the mask to fit the whole image if needed.
|
||||
if paste_to is not None:
|
||||
converted_mask = proc.uncrop(converted_mask,
|
||||
(overlay_image.width, overlay_image.height),
|
||||
paste_to)
|
||||
|
||||
masks_for_overlay.append(converted_mask)
|
||||
|
||||
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
|
||||
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
|
||||
mask=ImageOps.invert(converted_mask.convert('L')))
|
||||
|
||||
overlay_images[i] = image_masked.convert('RGBA')
|
||||
|
||||
return masks_for_overlay
|
||||
|
||||
|
||||
def apply_masks(
|
||||
settings,
|
||||
nmask,
|
||||
overlay_images,
|
||||
width, height,
|
||||
paste_to):
|
||||
import torch
|
||||
import modules.processing as proc
|
||||
import modules.images as images
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
|
||||
converted_mask = nmask[0].float()
|
||||
converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
|
||||
converted_mask = 255. * converted_mask
|
||||
converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
|
||||
converted_mask = Image.fromarray(converted_mask)
|
||||
converted_mask = images.resize_image(2, converted_mask, width, height)
|
||||
converted_mask = proc.create_binary_mask(converted_mask, round=False)
|
||||
|
||||
# Remove aliasing artifacts using a gaussian blur.
|
||||
converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
|
||||
|
||||
# Expand the mask to fit the whole image if needed.
|
||||
if paste_to is not None:
|
||||
converted_mask = proc.uncrop(converted_mask,
|
||||
(width, height),
|
||||
paste_to)
|
||||
|
||||
masks_for_overlay = []
|
||||
|
||||
for i, overlay_image in enumerate(overlay_images):
|
||||
masks_for_overlay[i] = converted_mask
|
||||
|
||||
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
|
||||
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
|
||||
mask=ImageOps.invert(converted_mask.convert('L')))
|
||||
|
||||
overlay_images[i] = image_masked.convert('RGBA')
|
||||
|
||||
return masks_for_overlay
|
||||
|
||||
|
||||
def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
|
||||
"""
|
||||
Generalization convolution filter capable of applying
|
||||
weighted mean, median, maximum, and minimum filters
|
||||
parametrically using an arbitrary kernel.
|
||||
|
||||
Args:
|
||||
img (nparray):
|
||||
The image, a 2-D array of floats, to which the filter is being applied.
|
||||
kernel (nparray):
|
||||
The kernel, a 2-D array of floats.
|
||||
kernel_center (nparray):
|
||||
The kernel center coordinate, a 1-D array with two elements.
|
||||
percentile_min (float):
|
||||
The lower bound of the histogram window used by the filter,
|
||||
from 0 to 1.
|
||||
percentile_max (float):
|
||||
The upper bound of the histogram window used by the filter,
|
||||
from 0 to 1.
|
||||
min_width (float):
|
||||
The minimum size of the histogram window bounds, in weight units.
|
||||
Must be greater than 0.
|
||||
|
||||
Returns:
|
||||
(nparray): A filtered copy of the input image "img", a 2-D array of floats.
|
||||
"""
|
||||
|
||||
# Converts an index tuple into a vector.
|
||||
def vec(x):
|
||||
return np.array(x)
|
||||
|
||||
kernel_min = -kernel_center
|
||||
kernel_max = vec(kernel.shape) - kernel_center
|
||||
|
||||
def weighted_histogram_filter_single(idx):
|
||||
idx = vec(idx)
|
||||
min_index = np.maximum(0, idx + kernel_min)
|
||||
max_index = np.minimum(vec(img.shape), idx + kernel_max)
|
||||
window_shape = max_index - min_index
|
||||
|
||||
class WeightedElement:
|
||||
"""
|
||||
An element of the histogram, its weight
|
||||
and bounds.
|
||||
"""
|
||||
|
||||
def __init__(self, value, weight):
|
||||
self.value: float = value
|
||||
self.weight: float = weight
|
||||
self.window_min: float = 0.0
|
||||
self.window_max: float = 1.0
|
||||
|
||||
# Collect the values in the image as WeightedElements,
|
||||
# weighted by their corresponding kernel values.
|
||||
values = []
|
||||
for window_tup in np.ndindex(tuple(window_shape)):
|
||||
window_index = vec(window_tup)
|
||||
image_index = window_index + min_index
|
||||
centered_kernel_index = image_index - idx
|
||||
kernel_index = centered_kernel_index + kernel_center
|
||||
element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
|
||||
values.append(element)
|
||||
|
||||
def sort_key(x: WeightedElement):
|
||||
return x.value
|
||||
|
||||
values.sort(key=sort_key)
|
||||
|
||||
# Calculate the height of the stack (sum)
|
||||
# and each sample's range they occupy in the stack
|
||||
sum = 0
|
||||
for i in range(len(values)):
|
||||
values[i].window_min = sum
|
||||
sum += values[i].weight
|
||||
values[i].window_max = sum
|
||||
|
||||
# Calculate what range of this stack ("window")
|
||||
# we want to get the weighted average across.
|
||||
window_min = sum * percentile_min
|
||||
window_max = sum * percentile_max
|
||||
window_width = window_max - window_min
|
||||
|
||||
# Ensure the window is within the stack and at least a certain size.
|
||||
if window_width < min_width:
|
||||
window_center = (window_min + window_max) / 2
|
||||
window_min = window_center - min_width / 2
|
||||
window_max = window_center + min_width / 2
|
||||
|
||||
if window_max > sum:
|
||||
window_max = sum
|
||||
window_min = sum - min_width
|
||||
|
||||
if window_min < 0:
|
||||
window_min = 0
|
||||
window_max = min_width
|
||||
|
||||
value = 0
|
||||
value_weight = 0
|
||||
|
||||
# Get the weighted average of all the samples
|
||||
# that overlap with the window, weighted
|
||||
# by the size of their overlap.
|
||||
for i in range(len(values)):
|
||||
if window_min >= values[i].window_max:
|
||||
continue
|
||||
if window_max <= values[i].window_min:
|
||||
break
|
||||
|
||||
s = max(window_min, values[i].window_min)
|
||||
e = min(window_max, values[i].window_max)
|
||||
w = e - s
|
||||
|
||||
value += values[i].value * w
|
||||
value_weight += w
|
||||
|
||||
return value / value_weight if value_weight != 0 else 0
|
||||
|
||||
img_out = img.copy()
|
||||
|
||||
# Apply the kernel operation over each pixel.
|
||||
for index in np.ndindex(img.shape):
|
||||
img_out[index] = weighted_histogram_filter_single(index)
|
||||
|
||||
return img_out
|
||||
|
||||
|
||||
def smoothstep(x):
|
||||
"""
|
||||
The smoothstep function, input should be clamped to 0-1 range.
|
||||
Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
|
||||
"""
|
||||
return x * x * (3 - 2 * x)
|
||||
|
||||
|
||||
def smootherstep(x):
|
||||
"""
|
||||
The smootherstep function, input should be clamped to 0-1 range.
|
||||
Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
|
||||
"""
|
||||
return x * x * x * (x * (6 * x - 15) + 10)
|
||||
|
||||
|
||||
def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
|
||||
"""
|
||||
Creates a Gaussian kernel with thresholded edges.
|
||||
|
||||
Args:
|
||||
stddev_radius (float):
|
||||
Standard deviation of the gaussian kernel, in pixels.
|
||||
max_radius (int):
|
||||
The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
|
||||
The kernel is thresholded so that any values one pixel beyond this radius
|
||||
is weighted at 0.
|
||||
|
||||
Returns:
|
||||
(nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
|
||||
"""
|
||||
|
||||
# Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
|
||||
def gaussian(sqr_mag):
|
||||
return math.exp(-sqr_mag / (stddev_radius * stddev_radius))
|
||||
|
||||
# Helper function for converting a tuple to an array.
|
||||
def vec(x):
|
||||
return np.array(x)
|
||||
|
||||
"""
|
||||
Since a gaussian is unbounded, we need to limit ourselves
|
||||
to a finite range.
|
||||
We taper the ends off at the end of that range so they equal zero
|
||||
while preserving the maximum value of 1 at the mean.
|
||||
"""
|
||||
zero_radius = max_radius + 1.0
|
||||
gauss_zero = gaussian(zero_radius * zero_radius)
|
||||
gauss_kernel_scale = 1 / (1 - gauss_zero)
|
||||
|
||||
def gaussian_kernel_func(coordinate):
|
||||
x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
|
||||
x = gaussian(x)
|
||||
x -= gauss_zero
|
||||
x *= gauss_kernel_scale
|
||||
x = max(0.0, x)
|
||||
return x
|
||||
|
||||
size = max_radius * 2 + 1
|
||||
kernel_center = max_radius
|
||||
kernel = np.zeros((size, size))
|
||||
|
||||
for index in np.ndindex(kernel.shape):
|
||||
kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)
|
||||
|
||||
return kernel, kernel_center
|
||||
|
||||
|
||||
# ------------------- Constants -------------------
|
||||
|
||||
|
||||
default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
|
||||
|
||||
enabled_ui_label = "Soft inpainting"
|
||||
enabled_gen_param_label = "Soft inpainting enabled"
|
||||
enabled_el_id = "soft_inpainting_enabled"
|
||||
|
||||
ui_labels = SoftInpaintingSettings(
|
||||
"Schedule bias",
|
||||
"Preservation strength",
|
||||
"Transition contrast boost",
|
||||
"Mask influence",
|
||||
"Difference threshold",
|
||||
"Difference contrast")
|
||||
|
||||
ui_info = SoftInpaintingSettings(
|
||||
"Shifts when preservation of original content occurs during denoising.",
|
||||
"How strongly partially masked content should be preserved.",
|
||||
"Amplifies the contrast that may be lost in partially masked regions.",
|
||||
"How strongly the original mask should bias the difference threshold.",
|
||||
"How much an image region can change before the original pixels are not blended in anymore.",
|
||||
"How sharp the transition should be between blended and not blended.")
|
||||
|
||||
gen_param_labels = SoftInpaintingSettings(
|
||||
"Soft inpainting schedule bias",
|
||||
"Soft inpainting preservation strength",
|
||||
"Soft inpainting transition contrast boost",
|
||||
"Soft inpainting mask influence",
|
||||
"Soft inpainting difference threshold",
|
||||
"Soft inpainting difference contrast")
|
||||
|
||||
el_ids = SoftInpaintingSettings(
|
||||
"mask_blend_power",
|
||||
"mask_blend_scale",
|
||||
"inpaint_detail_preservation",
|
||||
"composite_mask_influence",
|
||||
"composite_difference_threshold",
|
||||
"composite_difference_contrast")
|
||||
|
||||
|
||||
# ------------------- Script -------------------
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
self.section = "inpaint"
|
||||
self.masks_for_overlay = None
|
||||
self.overlay_images = None
|
||||
|
||||
def title(self):
|
||||
return "Soft Inpainting"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible if is_img2img else False
|
||||
|
||||
def ui(self, is_img2img):
|
||||
if not is_img2img:
|
||||
return
|
||||
|
||||
with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
|
||||
with gr.Group():
|
||||
gr.Markdown(
|
||||
"""
|
||||
Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
|
||||
**High _Mask blur_** values are recommended!
|
||||
""")
|
||||
|
||||
power = \
|
||||
gr.Slider(label=ui_labels.mask_blend_power,
|
||||
info=ui_info.mask_blend_power,
|
||||
minimum=0,
|
||||
maximum=8,
|
||||
step=0.1,
|
||||
value=default.mask_blend_power,
|
||||
elem_id=el_ids.mask_blend_power)
|
||||
scale = \
|
||||
gr.Slider(label=ui_labels.mask_blend_scale,
|
||||
info=ui_info.mask_blend_scale,
|
||||
minimum=0,
|
||||
maximum=8,
|
||||
step=0.05,
|
||||
value=default.mask_blend_scale,
|
||||
elem_id=el_ids.mask_blend_scale)
|
||||
detail = \
|
||||
gr.Slider(label=ui_labels.inpaint_detail_preservation,
|
||||
info=ui_info.inpaint_detail_preservation,
|
||||
minimum=1,
|
||||
maximum=32,
|
||||
step=0.5,
|
||||
value=default.inpaint_detail_preservation,
|
||||
elem_id=el_ids.inpaint_detail_preservation)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
### Pixel Composite Settings
|
||||
""")
|
||||
|
||||
mask_inf = \
|
||||
gr.Slider(label=ui_labels.composite_mask_influence,
|
||||
info=ui_info.composite_mask_influence,
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.05,
|
||||
value=default.composite_mask_influence,
|
||||
elem_id=el_ids.composite_mask_influence)
|
||||
|
||||
dif_thresh = \
|
||||
gr.Slider(label=ui_labels.composite_difference_threshold,
|
||||
info=ui_info.composite_difference_threshold,
|
||||
minimum=0,
|
||||
maximum=8,
|
||||
step=0.25,
|
||||
value=default.composite_difference_threshold,
|
||||
elem_id=el_ids.composite_difference_threshold)
|
||||
|
||||
dif_contr = \
|
||||
gr.Slider(label=ui_labels.composite_difference_contrast,
|
||||
info=ui_info.composite_difference_contrast,
|
||||
minimum=0,
|
||||
maximum=8,
|
||||
step=0.25,
|
||||
value=default.composite_difference_contrast,
|
||||
elem_id=el_ids.composite_difference_contrast)
|
||||
|
||||
with gr.Accordion("Help", open=False):
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.mask_blend_power}
|
||||
|
||||
The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
|
||||
This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
|
||||
This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.
|
||||
|
||||
- **Below 1**: Stronger preservation near the end (with low sigma)
|
||||
- **1**: Balanced (proportional to sigma)
|
||||
- **Above 1**: Stronger preservation in the beginning (with high sigma)
|
||||
""")
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.mask_blend_scale}
|
||||
|
||||
Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
|
||||
This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.
|
||||
|
||||
- **Low values**: Favors generated content.
|
||||
- **High values**: Favors original content.
|
||||
""")
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.inpaint_detail_preservation}
|
||||
|
||||
This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
|
||||
With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
|
||||
This can prevent the loss of contrast that occurs with linear interpolation.
|
||||
|
||||
- **Low values**: Softer blending, details may fade.
|
||||
- **High values**: Stronger contrast, may over-saturate colors.
|
||||
""")
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
## Pixel Composite Settings
|
||||
|
||||
Masks are generated based on how much a part of the image changed after denoising.
|
||||
These masks are used to blend the original and final images together.
|
||||
If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
|
||||
""")
|
||||
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.composite_mask_influence}
|
||||
|
||||
This parameter controls how much the mask should bias this sensitivity to difference.
|
||||
|
||||
- **0**: Ignore the mask, only consider differences in image content.
|
||||
- **1**: Follow the mask closely despite image content changes.
|
||||
""")
|
||||
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.composite_difference_threshold}
|
||||
|
||||
This value represents the difference at which the original pixels will have less than 50% opacity.
|
||||
|
||||
- **Low values**: Two images patches must be almost the same in order to retain original pixels.
|
||||
- **High values**: Two images patches can be very different and still retain original pixels.
|
||||
""")
|
||||
|
||||
gr.Markdown(
|
||||
f"""
|
||||
### {ui_labels.composite_difference_contrast}
|
||||
|
||||
This value represents the contrast between the opacity of the original and inpainted content.
|
||||
|
||||
- **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
|
||||
- **High values**: Ghosting will be less common, but transitions may be very sudden.
|
||||
""")
|
||||
|
||||
self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
|
||||
(power, gen_param_labels.mask_blend_power),
|
||||
(scale, gen_param_labels.mask_blend_scale),
|
||||
(detail, gen_param_labels.inpaint_detail_preservation),
|
||||
(mask_inf, gen_param_labels.composite_mask_influence),
|
||||
(dif_thresh, gen_param_labels.composite_difference_threshold),
|
||||
(dif_contr, gen_param_labels.composite_difference_contrast)]
|
||||
|
||||
self.paste_field_names = []
|
||||
for _, field_name in self.infotext_fields:
|
||||
self.paste_field_names.append(field_name)
|
||||
|
||||
return [soft_inpainting_enabled,
|
||||
power,
|
||||
scale,
|
||||
detail,
|
||||
mask_inf,
|
||||
dif_thresh,
|
||||
dif_contr]
|
||||
|
||||
def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
if not processing_uses_inpainting(p):
|
||||
return
|
||||
|
||||
# Shut off the rounding it normally does.
|
||||
p.mask_round = False
|
||||
|
||||
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
|
||||
|
||||
# p.extra_generation_params["Mask rounding"] = False
|
||||
settings.add_generation_params(p.extra_generation_params)
|
||||
|
||||
def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
|
||||
dif_thresh, dif_contr):
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
if not processing_uses_inpainting(p):
|
||||
return
|
||||
|
||||
if mba.is_final_blend:
|
||||
mba.blended_latent = mba.current_latent
|
||||
return
|
||||
|
||||
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
|
||||
|
||||
# todo: Why is sigma 2D? Both values are the same.
|
||||
mba.blended_latent = latent_blend(settings,
|
||||
mba.init_latent,
|
||||
mba.current_latent,
|
||||
get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
|
||||
|
||||
def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
|
||||
dif_thresh, dif_contr):
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
if not processing_uses_inpainting(p):
|
||||
return
|
||||
|
||||
nmask = getattr(p, "nmask", None)
|
||||
if nmask is None:
|
||||
return
|
||||
|
||||
from modules import images
|
||||
from modules.shared import opts
|
||||
|
||||
settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
|
||||
|
||||
# since the original code puts holes in the existing overlay images,
|
||||
# we have to rebuild them.
|
||||
self.overlay_images = []
|
||||
for img in p.init_images:
|
||||
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
if p.paste_to is None and p.resize_mode != 3:
|
||||
image = images.resize_image(p.resize_mode, image, p.width, p.height)
|
||||
|
||||
self.overlay_images.append(image.convert('RGBA'))
|
||||
|
||||
if len(p.init_images) == 1:
|
||||
self.overlay_images = self.overlay_images * p.batch_size
|
||||
|
||||
if getattr(ps.samples, 'already_decoded', False):
|
||||
self.masks_for_overlay = apply_masks(settings=settings,
|
||||
nmask=nmask,
|
||||
overlay_images=self.overlay_images,
|
||||
width=p.width,
|
||||
height=p.height,
|
||||
paste_to=p.paste_to)
|
||||
else:
|
||||
self.masks_for_overlay = apply_adaptive_masks(settings=settings,
|
||||
nmask=nmask,
|
||||
latent_orig=p.init_latent,
|
||||
latent_processed=ps.samples,
|
||||
overlay_images=self.overlay_images,
|
||||
width=p.width,
|
||||
height=p.height,
|
||||
paste_to=p.paste_to)
|
||||
|
||||
def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
|
||||
detail_preservation, mask_inf, dif_thresh, dif_contr):
|
||||
if not enabled:
|
||||
return
|
||||
|
||||
if not processing_uses_inpainting(p):
|
||||
return
|
||||
|
||||
if self.masks_for_overlay is None:
|
||||
return
|
||||
|
||||
if self.overlay_images is None:
|
||||
return
|
||||
|
||||
ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
|
||||
ppmo.overlay_image = self.overlay_images[ppmo.index]
|
@ -1,14 +1,9 @@
|
||||
<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
|
||||
<div class="card" style="{style}" onclick="{card_clicked}" data-name="{name}" {sort_keys}>
|
||||
{background_image}
|
||||
<div class="button-row">
|
||||
{metadata_button}
|
||||
{edit_button}
|
||||
</div>
|
||||
<div class='actions'>
|
||||
<div class='additional'>
|
||||
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
|
||||
</div>
|
||||
<span class='name'>{name}</span>
|
||||
<span class='description'>{description}</span>
|
||||
<div class="button-row">{copy_path_button}{metadata_button}{edit_button}</div>
|
||||
<div class="actions">
|
||||
<div class="additional">{search_terms}</div>
|
||||
<span class="name">{name}</span>
|
||||
<span class="description">{description}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
5
html/extra-networks-copy-path-button.html
Normal file
5
html/extra-networks-copy-path-button.html
Normal file
@ -0,0 +1,5 @@
|
||||
<div class="copy-path-button card-button"
|
||||
title="Copy path to clipboard"
|
||||
onclick="extraNetworksCopyCardPath(event, '{filename}')"
|
||||
data-clipboard-text="{filename}">
|
||||
</div>
|
4
html/extra-networks-edit-item-button.html
Normal file
4
html/extra-networks-edit-item-button.html
Normal file
@ -0,0 +1,4 @@
|
||||
<div class="edit-button card-button"
|
||||
title="Edit metadata"
|
||||
onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
|
||||
</div>
|
4
html/extra-networks-metadata-button.html
Normal file
4
html/extra-networks-metadata-button.html
Normal file
@ -0,0 +1,4 @@
|
||||
<div class="metadata-button card-button"
|
||||
title="Show internal metadata"
|
||||
onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}', '{name}')">
|
||||
</div>
|
55
html/extra-networks-pane.html
Normal file
55
html/extra-networks-pane.html
Normal file
@ -0,0 +1,55 @@
|
||||
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane'>
|
||||
<div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
|
||||
<div class="extra-network-control--search">
|
||||
<input
|
||||
id="{tabname}_{extra_networks_tabname}_extra_search"
|
||||
class="extra-network-control--search-text"
|
||||
type="search"
|
||||
placeholder="Filter files"
|
||||
>
|
||||
</div>
|
||||
<div
|
||||
id="{tabname}_{extra_networks_tabname}_extra_sort"
|
||||
class="extra-network-control--sort"
|
||||
data-sortmode="{data_sortmode}"
|
||||
data-sortkey="{data_sortkey}"
|
||||
title="Sort by path"
|
||||
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||
>
|
||||
<i class="extra-network-control--sort-icon"></i>
|
||||
</div>
|
||||
<div
|
||||
id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
|
||||
class="extra-network-control--sort-dir"
|
||||
data-sortdir="{data_sortdir}"
|
||||
title="Sort ascending"
|
||||
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||
>
|
||||
<i class="extra-network-control--sort-dir-icon"></i>
|
||||
</div>
|
||||
<div
|
||||
id="{tabname}_{extra_networks_tabname}_extra_tree_view"
|
||||
class="extra-network-control--tree-view {tree_view_btn_extra_class}"
|
||||
title="Enable Tree View"
|
||||
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||
>
|
||||
<i class="extra-network-control--tree-view-icon"></i>
|
||||
</div>
|
||||
<div
|
||||
id="{tabname}_{extra_networks_tabname}_extra_refresh"
|
||||
class="extra-network-control--refresh"
|
||||
title="Refresh page"
|
||||
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||
>
|
||||
<i class="extra-network-control--refresh-icon"></i>
|
||||
</div>
|
||||
</div>
|
||||
<div class="extra-network-pane-content resize-handle-row" style="display: {extra_network_pane_content_default_display};">
|
||||
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree {tree_view_div_extra_class}' style='flex-basis: {extra_networks_tree_view_default_width}px; display: {tree_view_div_default_display};'>
|
||||
{tree_html}
|
||||
</div>
|
||||
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>
|
||||
{items_html}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
23
html/extra-networks-tree-button.html
Normal file
23
html/extra-networks-tree-button.html
Normal file
@ -0,0 +1,23 @@
|
||||
<span data-filterable-item-text hidden>{search_terms}</span>
|
||||
<div class="tree-list-content {subclass}"
|
||||
type="button"
|
||||
onclick="extraNetworksTreeOnClick(event, '{tabname}', '{extra_networks_tabname}');{onclick_extra}"
|
||||
data-path="{data_path}"
|
||||
data-hash="{data_hash}"
|
||||
>
|
||||
<span class='tree-list-item-action tree-list-item-action--leading'>
|
||||
{action_list_item_action_leading}
|
||||
</span>
|
||||
<span class="tree-list-item-visual tree-list-item-visual--leading">
|
||||
{action_list_item_visual_leading}
|
||||
</span>
|
||||
<span class="tree-list-item-label tree-list-item-label--truncate">
|
||||
{action_list_item_label}
|
||||
</span>
|
||||
<span class="tree-list-item-visual tree-list-item-visual--trailing">
|
||||
{action_list_item_visual_trailing}
|
||||
</span>
|
||||
<span class="tree-list-item-action tree-list-item-action--trailing">
|
||||
{action_list_item_action_trailing}
|
||||
</span>
|
||||
</div>
|
@ -4,107 +4,6 @@
|
||||
#licenses pre { margin: 1em 0 2em 0;}
|
||||
</style>
|
||||
|
||||
<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
|
||||
<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
|
||||
<pre>
|
||||
S-Lab License 1.0
|
||||
|
||||
Copyright 2022 S-Lab
|
||||
|
||||
Redistribution and use for non-commercial purpose in source and
|
||||
binary forms, with or without modification, are permitted provided
|
||||
that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in
|
||||
the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
In the event that redistribution and/or use for commercial purpose in
|
||||
source or binary forms, with or without modification is required,
|
||||
please contact the contributor(s) of the work.
|
||||
</pre>
|
||||
|
||||
|
||||
<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
|
||||
<small>Code for architecture and reading models copied.</small>
|
||||
<pre>
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021 victorca25
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
|
||||
<small>Some code is copied to support ESRGAN models.</small>
|
||||
<pre>
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
|
||||
<small>Some code for compatibility with OSX is taken from lstein's repository.</small>
|
||||
<pre>
|
||||
@ -183,213 +82,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
|
||||
<small>Code added by contributors, most likely copied from this repository.</small>
|
||||
|
||||
<pre>
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [2021] [SwinIR Authors]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
</pre>
|
||||
|
||||
<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
|
||||
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
|
||||
<pre>
|
||||
|
@ -50,17 +50,17 @@ function dimensionChange(e, is_width, is_height) {
|
||||
var scaledx = targetElement.naturalWidth * viewportscale;
|
||||
var scaledy = targetElement.naturalHeight * viewportscale;
|
||||
|
||||
var cleintRectTop = (viewportOffset.top + window.scrollY);
|
||||
var cleintRectLeft = (viewportOffset.left + window.scrollX);
|
||||
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
|
||||
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
|
||||
var clientRectTop = (viewportOffset.top + window.scrollY);
|
||||
var clientRectLeft = (viewportOffset.left + window.scrollX);
|
||||
var clientRectCentreY = clientRectTop + (targetElement.clientHeight / 2);
|
||||
var clientRectCentreX = clientRectLeft + (targetElement.clientWidth / 2);
|
||||
|
||||
var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
|
||||
var arscaledx = currentWidth * arscale;
|
||||
var arscaledy = currentHeight * arscale;
|
||||
|
||||
var arRectTop = cleintRectCentreY - (arscaledy / 2);
|
||||
var arRectLeft = cleintRectCentreX - (arscaledx / 2);
|
||||
var arRectTop = clientRectCentreY - (arscaledy / 2);
|
||||
var arRectLeft = clientRectCentreX - (arscaledx / 2);
|
||||
var arRectWidth = arscaledx;
|
||||
var arRectHeight = arscaledy;
|
||||
|
||||
|
@ -2,8 +2,11 @@
|
||||
function extensions_apply(_disabled_list, _update_list, disable_all) {
|
||||
var disable = [];
|
||||
var update = [];
|
||||
|
||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
|
||||
const extensions_input = gradioApp().querySelectorAll('#extensions input[type="checkbox"]');
|
||||
if (extensions_input.length == 0) {
|
||||
throw Error("Extensions page not yet loaded.");
|
||||
}
|
||||
extensions_input.forEach(function(x) {
|
||||
if (x.name.startsWith("enable_") && !x.checked) {
|
||||
disable.push(x.name.substring(7));
|
||||
}
|
||||
|
@ -16,56 +16,71 @@ function toggleCss(key, css, enable) {
|
||||
}
|
||||
|
||||
function setupExtraNetworksForTab(tabname) {
|
||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||
function registerPrompt(tabname, id) {
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
||||
var search = searchDiv.querySelector('textarea');
|
||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
|
||||
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
|
||||
if (!activePromptTextarea[tabname]) {
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
}
|
||||
|
||||
tabs.appendChild(searchDiv);
|
||||
tabs.appendChild(sort);
|
||||
tabs.appendChild(sortOrder);
|
||||
tabs.appendChild(refresh);
|
||||
tabs.appendChild(showDirsDiv);
|
||||
textarea.addEventListener("focus", function() {
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
});
|
||||
}
|
||||
|
||||
var applyFilter = function() {
|
||||
var tabnav = gradioApp().querySelector('#' + tabname + '_extra_tabs > div.tab-nav');
|
||||
var controlsDiv = document.createElement('DIV');
|
||||
controlsDiv.classList.add('extra-networks-controls-div');
|
||||
tabnav.appendChild(controlsDiv);
|
||||
tabnav.insertBefore(controlsDiv, null);
|
||||
|
||||
var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs');
|
||||
this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) {
|
||||
// tabname_full = {tabname}_{extra_networks_tabname}
|
||||
var tabname_full = elem.id;
|
||||
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
|
||||
var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort");
|
||||
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
|
||||
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
|
||||
|
||||
// If any of the buttons above don't exist, we want to skip this iteration of the loop.
|
||||
if (!search || !sort_mode || !sort_dir || !refresh) {
|
||||
return; // `return` is equivalent of `continue` but for forEach loops.
|
||||
}
|
||||
|
||||
var applyFilter = function(force) {
|
||||
var searchTerm = search.value.toLowerCase();
|
||||
|
||||
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
||||
var searchOnly = elem.querySelector('.search_only');
|
||||
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase();
|
||||
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) {
|
||||
return t.textContent.toLowerCase();
|
||||
}).join(" ");
|
||||
|
||||
var visible = text.indexOf(searchTerm) != -1;
|
||||
|
||||
if (searchOnly && searchTerm.length < 4) {
|
||||
visible = false;
|
||||
}
|
||||
|
||||
elem.style.display = visible ? "" : "none";
|
||||
if (visible) {
|
||||
elem.classList.remove("hidden");
|
||||
} else {
|
||||
elem.classList.add("hidden");
|
||||
}
|
||||
});
|
||||
|
||||
applySort();
|
||||
applySort(force);
|
||||
};
|
||||
|
||||
var applySort = function() {
|
||||
var applySort = function(force) {
|
||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
||||
|
||||
var reverse = sortOrder.classList.contains("sortReverse");
|
||||
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
||||
var reverse = sort_dir.dataset.sortdir == "Descending";
|
||||
var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
||||
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
||||
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
|
||||
|
||||
if (sortKeyStore == sort.dataset.sortkey) {
|
||||
if (sortKeyStore == sort_mode.dataset.sortkey && !force) {
|
||||
return;
|
||||
}
|
||||
sort.dataset.sortkey = sortKeyStore;
|
||||
sort_mode.dataset.sortkey = sortKeyStore;
|
||||
|
||||
cards.forEach(function(card) {
|
||||
card.originalParentElement = card.parentElement;
|
||||
@ -92,23 +107,21 @@ function setupExtraNetworksForTab(tabname) {
|
||||
};
|
||||
|
||||
search.addEventListener("input", applyFilter);
|
||||
sortOrder.addEventListener("click", function() {
|
||||
sortOrder.classList.toggle("sortReverse");
|
||||
applySort();
|
||||
});
|
||||
applyFilter();
|
||||
extraNetworksApplySort[tabname_full] = applySort;
|
||||
extraNetworksApplyFilter[tabname_full] = applyFilter;
|
||||
|
||||
extraNetworksApplySort[tabname] = applySort;
|
||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||
var controls = gradioApp().querySelector("#" + tabname_full + "_controls");
|
||||
controlsDiv.insertBefore(controls, null);
|
||||
|
||||
var showDirsUpdate = function() {
|
||||
var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
|
||||
toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
|
||||
localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
|
||||
};
|
||||
showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
|
||||
showDirs.addEventListener("change", showDirsUpdate);
|
||||
showDirsUpdate();
|
||||
if (elem.style.display != "none") {
|
||||
extraNetworksShowControlsForPage(tabname, tabname_full);
|
||||
}
|
||||
});
|
||||
|
||||
registerPrompt(tabname, tabname + "_prompt");
|
||||
registerPrompt(tabname, tabname + "_neg_prompt");
|
||||
}
|
||||
|
||||
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
|
||||
@ -137,21 +150,42 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp
|
||||
}
|
||||
|
||||
|
||||
function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
|
||||
extraNetworksMovePromptToTab(tabname, '', false, false);
|
||||
function extraNetworksShowControlsForPage(tabname, tabname_full) {
|
||||
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs .extra-networks-controls-div > div').forEach(function(elem) {
|
||||
var targetId = tabname_full + "_controls";
|
||||
elem.style.display = elem.id == targetId ? "" : "none";
|
||||
});
|
||||
}
|
||||
|
||||
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
|
||||
|
||||
function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
|
||||
extraNetworksMovePromptToTab(tabname, '', false, false);
|
||||
|
||||
extraNetworksShowControlsForPage(tabname, null);
|
||||
}
|
||||
|
||||
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt, tabname_full) { // called from python when user selects an extra networks tab
|
||||
extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
|
||||
|
||||
extraNetworksShowControlsForPage(tabname, tabname_full);
|
||||
}
|
||||
|
||||
function applyExtraNetworkFilter(tabname) {
|
||||
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||
function applyExtraNetworkFilter(tabname_full) {
|
||||
var doFilter = function() {
|
||||
var applyFunction = extraNetworksApplyFilter[tabname_full];
|
||||
|
||||
if (applyFunction) {
|
||||
applyFunction(true);
|
||||
}
|
||||
};
|
||||
setTimeout(doFilter, 1);
|
||||
}
|
||||
|
||||
function applyExtraNetworkSort(tabname) {
|
||||
setTimeout(extraNetworksApplySort[tabname], 1);
|
||||
function applyExtraNetworkSort(tabname_full) {
|
||||
var doSort = function() {
|
||||
extraNetworksApplySort[tabname_full](true);
|
||||
};
|
||||
setTimeout(doSort, 1);
|
||||
}
|
||||
|
||||
var extraNetworksApplyFilter = {};
|
||||
@ -161,41 +195,24 @@ var activePromptTextarea = {};
|
||||
function setupExtraNetworks() {
|
||||
setupExtraNetworksForTab('txt2img');
|
||||
setupExtraNetworksForTab('img2img');
|
||||
|
||||
function registerPrompt(tabname, id) {
|
||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||
|
||||
if (!activePromptTextarea[tabname]) {
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
}
|
||||
|
||||
textarea.addEventListener("focus", function() {
|
||||
activePromptTextarea[tabname] = textarea;
|
||||
});
|
||||
}
|
||||
|
||||
registerPrompt('txt2img', 'txt2img_prompt');
|
||||
registerPrompt('txt2img', 'txt2img_neg_prompt');
|
||||
registerPrompt('img2img', 'img2img_prompt');
|
||||
registerPrompt('img2img', 'img2img_neg_prompt');
|
||||
}
|
||||
|
||||
onUiLoaded(setupExtraNetworks);
|
||||
|
||||
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
|
||||
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
|
||||
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
var m = text.match(re_extranet);
|
||||
var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
|
||||
var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
|
||||
var m = text.match(isNeg ? re_extranet_neg : re_extranet);
|
||||
var replaced = false;
|
||||
var newTextareaText;
|
||||
if (m) {
|
||||
var extraTextBeforeNet = opts.extra_networks_add_text_separator;
|
||||
if (m) {
|
||||
var extraTextAfterNet = m[2];
|
||||
var partToSearch = m[1];
|
||||
var foundAtPosition = -1;
|
||||
newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
|
||||
m = found.match(re_extranet);
|
||||
newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
|
||||
m = found.match(isNeg ? re_extranet_neg : re_extranet);
|
||||
if (m[1] == partToSearch) {
|
||||
replaced = true;
|
||||
foundAtPosition = pos;
|
||||
@ -203,9 +220,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
}
|
||||
return found;
|
||||
});
|
||||
|
||||
if (foundAtPosition >= 0) {
|
||||
if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
||||
if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
||||
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
|
||||
}
|
||||
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
|
||||
@ -213,13 +229,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
|
||||
if (found == text) {
|
||||
replaced = true;
|
||||
return "";
|
||||
}
|
||||
return found;
|
||||
});
|
||||
newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, "g"), "");
|
||||
replaced = (newTextareaText != textarea.value);
|
||||
}
|
||||
|
||||
if (replaced) {
|
||||
@ -230,14 +241,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
return false;
|
||||
}
|
||||
|
||||
function cardClicked(tabname, textToAdd, allowNegativePrompt) {
|
||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
|
||||
|
||||
if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
|
||||
textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
|
||||
function updatePromptArea(text, textArea, isNeg) {
|
||||
if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
|
||||
textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
|
||||
}
|
||||
|
||||
updateInput(textarea);
|
||||
updateInput(textArea);
|
||||
}
|
||||
|
||||
function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
|
||||
if (textToAddNegative.length > 0) {
|
||||
updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
|
||||
updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
|
||||
} else {
|
||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
|
||||
updatePromptArea(textToAdd, textarea);
|
||||
}
|
||||
}
|
||||
|
||||
function saveCardPreview(event, tabname, filename) {
|
||||
@ -253,13 +272,219 @@ function saveCardPreview(event, tabname, filename) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
function extraNetworksSearchButton(tabs_id, event) {
|
||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
|
||||
var button = event.target;
|
||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||
function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Processes `onclick` events when user clicks on files in tree.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param btn The clicked `tree-list-item` button.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
// NOTE: Currently unused.
|
||||
return;
|
||||
}
|
||||
|
||||
searchTextarea.value = text;
|
||||
updateInput(searchTextarea);
|
||||
function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Processes `onclick` events when user clicks on directories in tree.
|
||||
*
|
||||
* Here is how the tree reacts to clicks for various states:
|
||||
* unselected unopened directory: Directory is selected and expanded.
|
||||
* unselected opened directory: Directory is selected.
|
||||
* selected opened directory: Directory is collapsed and deselected.
|
||||
* chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param btn The clicked `tree-list-item` button.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
var ul = btn.nextElementSibling;
|
||||
// This is the actual target that the user clicked on within the target button.
|
||||
// We use this to detect if the chevron was clicked.
|
||||
var true_targ = event.target;
|
||||
|
||||
function _expand_or_collapse(_ul, _btn) {
|
||||
// Expands <ul> if it is collapsed, collapses otherwise. Updates button attributes.
|
||||
if (_ul.hasAttribute("hidden")) {
|
||||
_ul.removeAttribute("hidden");
|
||||
_btn.dataset.expanded = "";
|
||||
} else {
|
||||
_ul.setAttribute("hidden", "");
|
||||
delete _btn.dataset.expanded;
|
||||
}
|
||||
}
|
||||
|
||||
function _remove_selected_from_all() {
|
||||
// Removes the `selected` attribute from all buttons.
|
||||
var sels = document.querySelectorAll("div.tree-list-content");
|
||||
[...sels].forEach(el => {
|
||||
delete el.dataset.selected;
|
||||
});
|
||||
}
|
||||
|
||||
function _select_button(_btn) {
|
||||
// Removes `data-selected` attribute from all buttons then adds to passed button.
|
||||
_remove_selected_from_all();
|
||||
_btn.dataset.selected = "";
|
||||
}
|
||||
|
||||
function _update_search(_tabname, _extra_networks_tabname, _search_text) {
|
||||
// Update search input with select button's path.
|
||||
var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search");
|
||||
search_input_elem.value = _search_text;
|
||||
updateInput(search_input_elem);
|
||||
}
|
||||
|
||||
|
||||
// If user clicks on the chevron, then we do not select the folder.
|
||||
if (true_targ.matches(".tree-list-item-action--leading, .tree-list-item-action-chevron")) {
|
||||
_expand_or_collapse(ul, btn);
|
||||
} else {
|
||||
// User clicked anywhere else on the button.
|
||||
if ("selected" in btn.dataset && !(ul.hasAttribute("hidden"))) {
|
||||
// If folder is select and open, collapse and deselect button.
|
||||
_expand_or_collapse(ul, btn);
|
||||
delete btn.dataset.selected;
|
||||
_update_search(tabname, extra_networks_tabname, "");
|
||||
} else if (!(!("selected" in btn.dataset) && !(ul.hasAttribute("hidden")))) {
|
||||
// If folder is open and not selected, then we don't collapse; just select.
|
||||
// NOTE: Double inversion sucks but it is the clearest way to show the branching here.
|
||||
_expand_or_collapse(ul, btn);
|
||||
_select_button(btn, tabname, extra_networks_tabname);
|
||||
_update_search(tabname, extra_networks_tabname, btn.dataset.path);
|
||||
} else {
|
||||
// All other cases, just select the button.
|
||||
_select_button(btn, tabname, extra_networks_tabname);
|
||||
_update_search(tabname, extra_networks_tabname, btn.dataset.path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`.
|
||||
*
|
||||
* Determines whether the clicked button in the tree is for a file entry or a directory
|
||||
* then calls the appropriate function.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
var btn = event.currentTarget;
|
||||
var par = btn.parentElement;
|
||||
if (par.dataset.treeEntryType === "file") {
|
||||
extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname);
|
||||
} else {
|
||||
extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname);
|
||||
}
|
||||
}
|
||||
|
||||
function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for the Sort Mode button.
|
||||
*
|
||||
* Modifies the data attributes of the Sort Mode button to cycle between
|
||||
* various sorting modes.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
var curr_mode = event.currentTarget.dataset.sortmode;
|
||||
var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir");
|
||||
var sort_dir = el_sort_dir.dataset.sortdir;
|
||||
if (curr_mode == "path") {
|
||||
event.currentTarget.dataset.sortmode = "name";
|
||||
event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640";
|
||||
event.currentTarget.setAttribute("title", "Sort by filename");
|
||||
} else if (curr_mode == "name") {
|
||||
event.currentTarget.dataset.sortmode = "date_created";
|
||||
event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640";
|
||||
event.currentTarget.setAttribute("title", "Sort by date created");
|
||||
} else if (curr_mode == "date_created") {
|
||||
event.currentTarget.dataset.sortmode = "date_modified";
|
||||
event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640";
|
||||
event.currentTarget.setAttribute("title", "Sort by date modified");
|
||||
} else {
|
||||
event.currentTarget.dataset.sortmode = "path";
|
||||
event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640";
|
||||
event.currentTarget.setAttribute("title", "Sort by path");
|
||||
}
|
||||
applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
|
||||
}
|
||||
|
||||
function extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for the Sort Direction button.
|
||||
*
|
||||
* Modifies the data attributes of the Sort Direction button to cycle between
|
||||
* ascending and descending sort directions.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
if (event.currentTarget.dataset.sortdir == "Ascending") {
|
||||
event.currentTarget.dataset.sortdir = "Descending";
|
||||
event.currentTarget.setAttribute("title", "Sort descending");
|
||||
} else {
|
||||
event.currentTarget.dataset.sortdir = "Ascending";
|
||||
event.currentTarget.setAttribute("title", "Sort ascending");
|
||||
}
|
||||
applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
|
||||
}
|
||||
|
||||
function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for the Tree View button.
|
||||
*
|
||||
* Toggles the tree view in the extra networks pane.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
const tree = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree");
|
||||
const parent = tree.parentElement;
|
||||
let resizeHandle = parent.querySelector('.resize-handle');
|
||||
tree.classList.toggle("hidden");
|
||||
|
||||
if (tree.classList.contains("hidden")) {
|
||||
tree.style.display = 'none';
|
||||
parent.style.display = 'flex';
|
||||
if (resizeHandle) {
|
||||
resizeHandle.style.display = 'none';
|
||||
}
|
||||
} else {
|
||||
tree.style.display = 'block';
|
||||
parent.style.display = 'grid';
|
||||
if (!resizeHandle) {
|
||||
setupResizeHandle(parent);
|
||||
resizeHandle = parent.querySelector('.resize-handle');
|
||||
}
|
||||
resizeHandle.style.display = 'block';
|
||||
}
|
||||
event.currentTarget.classList.toggle("extra-network-control--enabled");
|
||||
}
|
||||
|
||||
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
|
||||
/**
|
||||
* Handles `onclick` events for the Refresh Page button.
|
||||
*
|
||||
* In order to actually call the python functions in `ui_extra_networks.py`
|
||||
* to refresh the page, we created an empty gradio button in that file with an
|
||||
* event handler that refreshes the page. So what this function here does
|
||||
* is it manually raises a `click` event on that button.
|
||||
*
|
||||
* @param event The generated event.
|
||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||
*/
|
||||
var btn_refresh_internal = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_extra_refresh_internal");
|
||||
btn_refresh_internal.dispatchEvent(new Event("click"));
|
||||
}
|
||||
|
||||
var globalPopup = null;
|
||||
@ -303,12 +528,76 @@ function popupId(id) {
|
||||
popup(storedPopupIds[id]);
|
||||
}
|
||||
|
||||
function extraNetworksFlattenMetadata(obj) {
|
||||
const result = {};
|
||||
|
||||
// Convert any stringified JSON objects to actual objects
|
||||
for (const key of Object.keys(obj)) {
|
||||
if (typeof obj[key] === 'string') {
|
||||
try {
|
||||
const parsed = JSON.parse(obj[key]);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
obj[key] = parsed;
|
||||
}
|
||||
} catch (error) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flatten the object
|
||||
for (const key of Object.keys(obj)) {
|
||||
if (typeof obj[key] === 'object' && obj[key] !== null) {
|
||||
const nested = extraNetworksFlattenMetadata(obj[key]);
|
||||
for (const nestedKey of Object.keys(nested)) {
|
||||
result[`${key}/${nestedKey}`] = nested[nestedKey];
|
||||
}
|
||||
} else {
|
||||
result[key] = obj[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for handling modelspec keys
|
||||
for (const key of Object.keys(result)) {
|
||||
if (key.startsWith("modelspec.")) {
|
||||
result[key.replaceAll(".", "/")] = result[key];
|
||||
delete result[key];
|
||||
}
|
||||
}
|
||||
|
||||
// Add empty keys to designate hierarchy
|
||||
for (const key of Object.keys(result)) {
|
||||
const parts = key.split("/");
|
||||
for (let i = 1; i < parts.length; i++) {
|
||||
const parent = parts.slice(0, i).join("/");
|
||||
if (!result[parent]) {
|
||||
result[parent] = "";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function extraNetworksShowMetadata(text) {
|
||||
try {
|
||||
let parsed = JSON.parse(text);
|
||||
if (parsed && typeof parsed === 'object') {
|
||||
parsed = extraNetworksFlattenMetadata(parsed);
|
||||
const table = createVisualizationTable(parsed, 0);
|
||||
popup(table);
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.eror(error);
|
||||
}
|
||||
|
||||
var elem = document.createElement('pre');
|
||||
elem.classList.add('popup-metadata');
|
||||
elem.textContent = text;
|
||||
|
||||
popup(elem);
|
||||
return;
|
||||
}
|
||||
|
||||
function requestGet(url, data, handler, errorHandler) {
|
||||
@ -337,6 +626,11 @@ function requestGet(url, data, handler, errorHandler) {
|
||||
xhr.send(js);
|
||||
}
|
||||
|
||||
function extraNetworksCopyCardPath(event, path) {
|
||||
navigator.clipboard.writeText(path);
|
||||
event.stopPropagation();
|
||||
}
|
||||
|
||||
function extraNetworksRequestMetadata(event, extraPage, cardName) {
|
||||
var showError = function() {
|
||||
extraNetworksShowMetadata("there was an error getting metadata");
|
||||
@ -398,3 +692,39 @@ window.addEventListener("keydown", function(event) {
|
||||
closePopup();
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Setup custom loading for this script.
|
||||
* We need to wait for all of our HTML to be generated in the extra networks tabs
|
||||
* before we can actually run the `setupExtraNetworks` function.
|
||||
* The `onUiLoaded` function actually runs before all of our extra network tabs are
|
||||
* finished generating. Thus we needed this new method.
|
||||
*
|
||||
*/
|
||||
|
||||
var uiAfterScriptsCallbacks = [];
|
||||
var uiAfterScriptsTimeout = null;
|
||||
var executedAfterScripts = false;
|
||||
|
||||
function scheduleAfterScriptsCallbacks() {
|
||||
clearTimeout(uiAfterScriptsTimeout);
|
||||
uiAfterScriptsTimeout = setTimeout(function() {
|
||||
executeCallbacks(uiAfterScriptsCallbacks);
|
||||
}, 200);
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
var mutationObserver = new MutationObserver(function(m) {
|
||||
let existingSearchfields = gradioApp().querySelectorAll("[id$='_extra_search']").length;
|
||||
let neededSearchfields = gradioApp().querySelectorAll("[id$='_extra_tabs'] > .tab-nav > button").length - 2;
|
||||
|
||||
if (!executedAfterScripts && existingSearchfields >= neededSearchfields) {
|
||||
mutationObserver.disconnect();
|
||||
executedAfterScripts = true;
|
||||
scheduleAfterScriptsCallbacks();
|
||||
}
|
||||
});
|
||||
mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
|
||||
});
|
||||
|
||||
uiAfterScriptsCallbacks.push(setupExtraNetworks);
|
||||
|
@ -33,23 +33,33 @@ function createRow(table, cellName, items) {
|
||||
return res;
|
||||
}
|
||||
|
||||
function showProfile(path, cutoff = 0.05) {
|
||||
requestGet(path, {}, function(data) {
|
||||
function createVisualizationTable(data, cutoff = 0, sort = "") {
|
||||
var table = document.createElement('table');
|
||||
table.className = 'popup-table';
|
||||
|
||||
data.records['total'] = data.total;
|
||||
var keys = Object.keys(data.records).sort(function(a, b) {
|
||||
return data.records[b] - data.records[a];
|
||||
var keys = Object.keys(data);
|
||||
if (sort === "number") {
|
||||
keys = keys.sort(function(a, b) {
|
||||
return data[b] - data[a];
|
||||
});
|
||||
} else {
|
||||
keys = keys.sort();
|
||||
}
|
||||
var items = keys.map(function(x) {
|
||||
return {key: x, parts: x.split('/'), time: data.records[x]};
|
||||
return {key: x, parts: x.split('/'), value: data[x]};
|
||||
});
|
||||
var maxLength = items.reduce(function(a, b) {
|
||||
return Math.max(a, b.parts.length);
|
||||
}, 0);
|
||||
|
||||
var cols = createRow(table, 'th', ['record', 'seconds']);
|
||||
var cols = createRow(
|
||||
table,
|
||||
'th',
|
||||
[
|
||||
cutoff === 0 ? 'key' : 'record',
|
||||
cutoff === 0 ? 'value' : 'seconds'
|
||||
]
|
||||
);
|
||||
cols[0].colSpan = maxLength;
|
||||
|
||||
function arraysEqual(a, b) {
|
||||
@ -60,21 +70,25 @@ function showProfile(path, cutoff = 0.05) {
|
||||
var matching = items.filter(function(x) {
|
||||
return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
|
||||
});
|
||||
var sorted = matching.sort(function(a, b) {
|
||||
return b.time - a.time;
|
||||
if (sort === "number") {
|
||||
matching = matching.sort(function(a, b) {
|
||||
return b.value - a.value;
|
||||
});
|
||||
} else {
|
||||
matching = matching.sort();
|
||||
}
|
||||
var othersTime = 0;
|
||||
var othersList = [];
|
||||
var othersRows = [];
|
||||
var childrenRows = [];
|
||||
sorted.forEach(function(x) {
|
||||
var visible = x.time >= cutoff && !hide;
|
||||
matching.forEach(function(x) {
|
||||
var visible = (cutoff === 0 && !hide) || (x.value >= cutoff && !hide);
|
||||
|
||||
var cells = [];
|
||||
for (var i = 0; i < maxLength; i++) {
|
||||
cells.push(x.parts[i]);
|
||||
}
|
||||
cells.push(x.time.toFixed(3));
|
||||
cells.push(cutoff === 0 ? x.value : x.value.toFixed(3));
|
||||
var cols = createRow(table, 'td', cells);
|
||||
for (i = 0; i < level; i++) {
|
||||
cols[i].className = 'muted';
|
||||
@ -85,10 +99,10 @@ function showProfile(path, cutoff = 0.05) {
|
||||
tr.classList.add("hidden");
|
||||
}
|
||||
|
||||
if (x.time >= cutoff) {
|
||||
if (cutoff === 0 || x.value >= cutoff) {
|
||||
childrenRows.push(tr);
|
||||
} else {
|
||||
othersTime += x.time;
|
||||
othersTime += x.value;
|
||||
othersList.push(x.parts[level]);
|
||||
othersRows.push(tr);
|
||||
}
|
||||
@ -147,6 +161,13 @@ function showProfile(path, cutoff = 0.05) {
|
||||
|
||||
addLevel(0, []);
|
||||
|
||||
return table;
|
||||
}
|
||||
|
||||
function showProfile(path, cutoff = 0.05) {
|
||||
requestGet(path, {}, function(data) {
|
||||
data.records['total'] = data.total;
|
||||
const table = createVisualizationTable(data.records, cutoff, "number");
|
||||
popup(table);
|
||||
});
|
||||
}
|
||||
|
@ -45,8 +45,15 @@ function formatTime(secs) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
var originalAppTitle = undefined;
|
||||
|
||||
onUiLoaded(function() {
|
||||
originalAppTitle = document.title;
|
||||
});
|
||||
|
||||
function setTitle(progress) {
|
||||
var title = 'Stable Diffusion';
|
||||
var title = originalAppTitle;
|
||||
|
||||
if (opts.show_progress_in_title && progress) {
|
||||
title = '[' + progress.trim() + '] ' + title;
|
||||
|
@ -1,8 +1,8 @@
|
||||
(function() {
|
||||
const GRADIO_MIN_WIDTH = 320;
|
||||
const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr';
|
||||
const PAD = 16;
|
||||
const DEBOUNCE_TIME = 100;
|
||||
const DOUBLE_TAP_DELAY = 200; //ms
|
||||
|
||||
const R = {
|
||||
tracking: false,
|
||||
@ -11,6 +11,7 @@
|
||||
leftCol: null,
|
||||
leftColStartWidth: null,
|
||||
screenX: null,
|
||||
lastTapTime: null,
|
||||
};
|
||||
|
||||
let resizeTimer;
|
||||
@ -21,30 +22,29 @@
|
||||
}
|
||||
|
||||
function displayResizeHandle(parent) {
|
||||
if (!parent.needHideOnMoblie) {
|
||||
return true;
|
||||
}
|
||||
if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {
|
||||
parent.style.display = 'flex';
|
||||
if (R.handle != null) {
|
||||
R.handle.style.opacity = '0';
|
||||
}
|
||||
parent.resizeHandle.style.display = "none";
|
||||
return false;
|
||||
} else {
|
||||
parent.style.display = 'grid';
|
||||
if (R.handle != null) {
|
||||
R.handle.style.opacity = '100';
|
||||
}
|
||||
parent.resizeHandle.style.display = "block";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
function afterResize(parent) {
|
||||
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) {
|
||||
if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != parent.style.originalGridTemplateColumns) {
|
||||
const oldParentWidth = R.parentWidth;
|
||||
const newParentWidth = parent.offsetWidth;
|
||||
const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);
|
||||
|
||||
const ratio = newParentWidth / oldParentWidth;
|
||||
|
||||
const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH);
|
||||
const newWidthL = Math.max(Math.floor(ratio * widthL), parent.minLeftColWidth);
|
||||
setLeftColGridTemplate(parent, newWidthL);
|
||||
|
||||
R.parentWidth = newParentWidth;
|
||||
@ -52,6 +52,14 @@
|
||||
}
|
||||
|
||||
function setup(parent) {
|
||||
|
||||
function onDoubleClick(evt) {
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
parent.style.gridTemplateColumns = parent.style.originalGridTemplateColumns;
|
||||
}
|
||||
|
||||
const leftCol = parent.firstElementChild;
|
||||
const rightCol = parent.lastElementChild;
|
||||
|
||||
@ -59,14 +67,42 @@
|
||||
|
||||
parent.style.display = 'grid';
|
||||
parent.style.gap = '0';
|
||||
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||
let leftColTemplate = "";
|
||||
if (parent.children[0].style.flexGrow) {
|
||||
leftColTemplate = `${parent.children[0].style.flexGrow}fr`;
|
||||
parent.minLeftColWidth = GRADIO_MIN_WIDTH;
|
||||
parent.minRightColWidth = GRADIO_MIN_WIDTH;
|
||||
parent.needHideOnMoblie = true;
|
||||
} else {
|
||||
leftColTemplate = parent.children[0].style.flexBasis;
|
||||
parent.minLeftColWidth = parent.children[0].style.flexBasis.slice(0, -2) / 2;
|
||||
parent.minRightColWidth = 0;
|
||||
parent.needHideOnMoblie = false;
|
||||
}
|
||||
const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`;
|
||||
parent.style.gridTemplateColumns = gridTemplateColumns;
|
||||
parent.style.originalGridTemplateColumns = gridTemplateColumns;
|
||||
|
||||
const resizeHandle = document.createElement('div');
|
||||
resizeHandle.classList.add('resize-handle');
|
||||
parent.insertBefore(resizeHandle, rightCol);
|
||||
parent.resizeHandle = resizeHandle;
|
||||
|
||||
resizeHandle.addEventListener('mousedown', (evt) => {
|
||||
['mousedown', 'touchstart'].forEach((eventType) => {
|
||||
resizeHandle.addEventListener(eventType, (evt) => {
|
||||
if (eventType.startsWith('mouse')) {
|
||||
if (evt.button !== 0) return;
|
||||
} else {
|
||||
if (evt.changedTouches.length !== 1) return;
|
||||
|
||||
const currentTime = new Date().getTime();
|
||||
if (R.lastTapTime && currentTime - R.lastTapTime <= DOUBLE_TAP_DELAY) {
|
||||
onDoubleClick(evt);
|
||||
return;
|
||||
}
|
||||
|
||||
R.lastTapTime = currentTime;
|
||||
}
|
||||
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
@ -76,37 +112,54 @@
|
||||
R.tracking = true;
|
||||
R.parent = parent;
|
||||
R.parentWidth = parent.offsetWidth;
|
||||
R.handle = resizeHandle;
|
||||
R.leftCol = leftCol;
|
||||
R.leftColStartWidth = leftCol.offsetWidth;
|
||||
if (eventType.startsWith('mouse')) {
|
||||
R.screenX = evt.screenX;
|
||||
} else {
|
||||
R.screenX = evt.changedTouches[0].screenX;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
resizeHandle.addEventListener('dblclick', (evt) => {
|
||||
evt.preventDefault();
|
||||
evt.stopPropagation();
|
||||
|
||||
parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS;
|
||||
});
|
||||
resizeHandle.addEventListener('dblclick', onDoubleClick);
|
||||
|
||||
afterResize(parent);
|
||||
}
|
||||
|
||||
window.addEventListener('mousemove', (evt) => {
|
||||
['mousemove', 'touchmove'].forEach((eventType) => {
|
||||
window.addEventListener(eventType, (evt) => {
|
||||
if (eventType.startsWith('mouse')) {
|
||||
if (evt.button !== 0) return;
|
||||
} else {
|
||||
if (evt.changedTouches.length !== 1) return;
|
||||
}
|
||||
|
||||
if (R.tracking) {
|
||||
if (eventType.startsWith('mouse')) {
|
||||
evt.preventDefault();
|
||||
}
|
||||
evt.stopPropagation();
|
||||
|
||||
const delta = R.screenX - evt.screenX;
|
||||
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH);
|
||||
let delta = 0;
|
||||
if (eventType.startsWith('mouse')) {
|
||||
delta = R.screenX - evt.screenX;
|
||||
} else {
|
||||
delta = R.screenX - evt.changedTouches[0].screenX;
|
||||
}
|
||||
const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - R.parent.minRightColWidth - PAD), R.parent.minLeftColWidth);
|
||||
setLeftColGridTemplate(R.parent, leftColWidth);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
window.addEventListener('mouseup', (evt) => {
|
||||
['mouseup', 'touchend'].forEach((eventType) => {
|
||||
window.addEventListener(eventType, (evt) => {
|
||||
if (eventType.startsWith('mouse')) {
|
||||
if (evt.button !== 0) return;
|
||||
} else {
|
||||
if (evt.changedTouches.length !== 1) return;
|
||||
}
|
||||
|
||||
if (R.tracking) {
|
||||
evt.preventDefault();
|
||||
@ -117,6 +170,7 @@
|
||||
document.body.classList.remove('resizing');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
@ -132,10 +186,15 @@
|
||||
setupResizeHandle = setup;
|
||||
})();
|
||||
|
||||
onUiLoaded(function() {
|
||||
|
||||
function setupAllResizeHandles() {
|
||||
for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {
|
||||
if (!elem.querySelector('.resize-handle')) {
|
||||
if (!elem.querySelector('.resize-handle') && !elem.children[0].classList.contains("hidden")) {
|
||||
setupResizeHandle(elem);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
onUiLoaded(setupAllResizeHandles);
|
||||
|
||||
|
@ -55,8 +55,8 @@ onOptionsChanged(function() {
|
||||
});
|
||||
|
||||
opts._categories.forEach(function(x) {
|
||||
var section = x[0];
|
||||
var category = x[1];
|
||||
var section = localization[x[0]] ?? x[0];
|
||||
var category = localization[x[1]] ?? x[1];
|
||||
|
||||
var span = document.createElement('SPAN');
|
||||
span.textContent = category;
|
||||
|
@ -48,11 +48,6 @@ function setupTokenCounting(id, id_counter, id_button) {
|
||||
var counter = gradioApp().getElementById(id_counter);
|
||||
var textarea = gradioApp().querySelector(`#${id} > label > textarea`);
|
||||
|
||||
if (opts.disable_token_counters) {
|
||||
counter.style.display = "none";
|
||||
return;
|
||||
}
|
||||
|
||||
if (counter.parentElement == prompt.parentElement) {
|
||||
return;
|
||||
}
|
||||
@ -61,15 +56,32 @@ function setupTokenCounting(id, id_counter, id_button) {
|
||||
prompt.parentElement.style.position = "relative";
|
||||
|
||||
var func = onEdit(id, textarea, 800, function() {
|
||||
if (counter.classList.contains("token-counter-visible")) {
|
||||
gradioApp().getElementById(id_button)?.click();
|
||||
}
|
||||
});
|
||||
promptTokenCountUpdateFunctions[id] = func;
|
||||
promptTokenCountUpdateFunctions[id_button] = func;
|
||||
}
|
||||
|
||||
function setupTokenCounters() {
|
||||
setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
|
||||
setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
|
||||
setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
|
||||
setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
|
||||
function toggleTokenCountingVisibility(id, id_counter, id_button) {
|
||||
var counter = gradioApp().getElementById(id_counter);
|
||||
|
||||
counter.style.display = opts.disable_token_counters ? "none" : "block";
|
||||
counter.classList.toggle("token-counter-visible", !opts.disable_token_counters);
|
||||
}
|
||||
|
||||
function runCodeForTokenCounters(fun) {
|
||||
fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
|
||||
fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
|
||||
fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
|
||||
fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
runCodeForTokenCounters(setupTokenCounting);
|
||||
});
|
||||
|
||||
onOptionsChanged(function() {
|
||||
runCodeForTokenCounters(toggleTokenCountingVisibility);
|
||||
});
|
||||
|
@ -119,9 +119,18 @@ function create_submit_args(args) {
|
||||
return res;
|
||||
}
|
||||
|
||||
function setSubmitButtonsVisibility(tabname, showInterrupt, showSkip, showInterrupting) {
|
||||
gradioApp().getElementById(tabname + '_interrupt').style.display = showInterrupt ? "block" : "none";
|
||||
gradioApp().getElementById(tabname + '_skip').style.display = showSkip ? "block" : "none";
|
||||
gradioApp().getElementById(tabname + '_interrupting').style.display = showInterrupting ? "block" : "none";
|
||||
}
|
||||
|
||||
function showSubmitButtons(tabname, show) {
|
||||
gradioApp().getElementById(tabname + '_interrupt').style.display = show ? "none" : "block";
|
||||
gradioApp().getElementById(tabname + '_skip').style.display = show ? "none" : "block";
|
||||
setSubmitButtonsVisibility(tabname, !show, !show, false);
|
||||
}
|
||||
|
||||
function showSubmitInterruptingPlaceholder(tabname) {
|
||||
setSubmitButtonsVisibility(tabname, false, true, true);
|
||||
}
|
||||
|
||||
function showRestoreProgressButton(tabname, show) {
|
||||
@ -150,6 +159,14 @@ function submit() {
|
||||
return res;
|
||||
}
|
||||
|
||||
function submit_txt2img_upscale() {
|
||||
var res = submit(...arguments);
|
||||
|
||||
res[2] = selected_gallery_index();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
function submit_img2img() {
|
||||
showSubmitButtons('img2img', false);
|
||||
|
||||
@ -302,8 +319,6 @@ onAfterUiUpdate(function() {
|
||||
});
|
||||
|
||||
json_elem.parentElement.style.display = "none";
|
||||
|
||||
setupTokenCounters();
|
||||
});
|
||||
|
||||
onOptionsChanged(function() {
|
||||
@ -396,7 +411,7 @@ function switchWidthHeight(tabname) {
|
||||
|
||||
var onEditTimers = {};
|
||||
|
||||
// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
|
||||
// calls func after afterMs milliseconds has passed since the input elem has been edited by user
|
||||
function onEdit(editId, elem, afterMs, func) {
|
||||
var edited = function() {
|
||||
var existingTimer = onEditTimers[editId];
|
||||
|
@ -17,13 +17,13 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin, Image
|
||||
from PIL import PngImagePlugin
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
@ -31,7 +31,7 @@ from typing import Any
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from contextlib import closing
|
||||
|
||||
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
try:
|
||||
@ -85,7 +85,7 @@ def decode_base64_to_image(encoding):
|
||||
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||
response = requests.get(encoding, timeout=30, headers=headers)
|
||||
try:
|
||||
image = Image.open(BytesIO(response.content))
|
||||
image = images.read(BytesIO(response.content))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||
@ -93,7 +93,7 @@ def decode_base64_to_image(encoding):
|
||||
if encoding.startswith("data:image/"):
|
||||
encoding = encoding.split(";")[1].split(",")[1]
|
||||
try:
|
||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||
image = images.read(BytesIO(base64.b64decode(encoding)))
|
||||
return image
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||
@ -230,6 +230,7 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||
@ -251,6 +252,24 @@ class Api:
|
||||
self.default_script_arg_txt2img = []
|
||||
self.default_script_arg_img2img = []
|
||||
|
||||
txt2img_script_runner = scripts.scripts_txt2img
|
||||
img2img_script_runner = scripts.scripts_img2img
|
||||
|
||||
if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
|
||||
ui.create_ui()
|
||||
|
||||
if not txt2img_script_runner.scripts:
|
||||
txt2img_script_runner.initialize_scripts(False)
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)
|
||||
|
||||
if not img2img_script_runner.scripts:
|
||||
img2img_script_runner.initialize_scripts(True)
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
|
||||
|
||||
|
||||
|
||||
def add_api_route(self, path: str, endpoint, **kwargs):
|
||||
if shared.cmd_opts.api_auth:
|
||||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
|
||||
@ -312,8 +331,13 @@ class Api:
|
||||
script_args[script.args_from:script.args_to] = ui_default_values
|
||||
return script_args
|
||||
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
|
||||
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
|
||||
script_args = default_script_args.copy()
|
||||
|
||||
if input_script_args is not None:
|
||||
for index, value in input_script_args.items():
|
||||
script_args[index] = value
|
||||
|
||||
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
|
||||
if selectable_scripts:
|
||||
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
|
||||
@ -335,13 +359,83 @@ class Api:
|
||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||
return script_args
|
||||
|
||||
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
||||
"""Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
|
||||
|
||||
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
||||
|
||||
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
|
||||
"""
|
||||
|
||||
if not request.infotext:
|
||||
return {}
|
||||
|
||||
possible_fields = infotext_utils.paste_fields[tabname]["fields"]
|
||||
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this
|
||||
params = infotext_utils.parse_generation_parameters(request.infotext)
|
||||
|
||||
def get_field_value(field, params):
|
||||
value = field.function(params) if field.function else params.get(field.label)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if field.api in request.__fields__:
|
||||
target_type = request.__fields__[field.api].type_
|
||||
else:
|
||||
target_type = type(field.component.value)
|
||||
|
||||
if target_type == type(None):
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
|
||||
value = value.get('value')
|
||||
|
||||
if value is not None and not isinstance(value, target_type):
|
||||
value = target_type(value)
|
||||
|
||||
return value
|
||||
|
||||
for field in possible_fields:
|
||||
if not field.api:
|
||||
continue
|
||||
|
||||
if field.api in set_fields:
|
||||
continue
|
||||
|
||||
value = get_field_value(field, params)
|
||||
if value is not None:
|
||||
setattr(request, field.api, value)
|
||||
|
||||
if request.override_settings is None:
|
||||
request.override_settings = {}
|
||||
|
||||
overridden_settings = infotext_utils.get_override_settings(params)
|
||||
for _, setting_name, value in overridden_settings:
|
||||
if setting_name not in request.override_settings:
|
||||
request.override_settings[setting_name] = value
|
||||
|
||||
if script_runner is not None and mentioned_script_args is not None:
|
||||
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
|
||||
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
|
||||
|
||||
for field, index in script_fields:
|
||||
value = get_field_value(field, params)
|
||||
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
mentioned_script_args[index] = value
|
||||
|
||||
return params
|
||||
|
||||
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
|
||||
|
||||
script_runner = scripts.scripts_txt2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_txt2img:
|
||||
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
@ -356,12 +450,15 @@ class Api:
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.is_api = True
|
||||
@ -371,12 +468,14 @@ class Api:
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_txt2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
@ -386,6 +485,8 @@ class Api:
|
||||
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||
task_id = img2imgreq.force_task_id or create_task_id("img2img")
|
||||
|
||||
init_images = img2imgreq.init_images
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
@ -395,11 +496,10 @@ class Api:
|
||||
mask = decode_base64_to_image(mask)
|
||||
|
||||
script_runner = scripts.scripts_img2img
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(True)
|
||||
ui.create_ui()
|
||||
if not self.default_script_arg_img2img:
|
||||
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
|
||||
|
||||
infotext_script_args = {}
|
||||
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
|
||||
|
||||
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
|
||||
|
||||
populate = img2imgreq.copy(update={ # Override __init__ params
|
||||
@ -416,12 +516,15 @@ class Api:
|
||||
args.pop('script_name', None)
|
||||
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
|
||||
args.pop('alwayson_scripts', None)
|
||||
args.pop('infotext', None)
|
||||
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
|
||||
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
|
||||
|
||||
send_images = args.pop('send_images', True)
|
||||
args.pop('save_images', None)
|
||||
|
||||
add_task_to_queue(task_id)
|
||||
|
||||
with self.queue_lock:
|
||||
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
@ -432,12 +535,14 @@ class Api:
|
||||
|
||||
try:
|
||||
shared.state.begin(job="scripts_img2img")
|
||||
start_task(task_id)
|
||||
if selectable_scripts is not None:
|
||||
p.script_args = script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||
else:
|
||||
p.script_args = tuple(script_args) # Need to pass args as tuple here
|
||||
processed = process_images(p)
|
||||
finish_task(task_id)
|
||||
finally:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
@ -480,7 +585,7 @@ class Api:
|
||||
if geninfo is None:
|
||||
geninfo = ""
|
||||
|
||||
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||
params = infotext_utils.parse_generation_parameters(geninfo)
|
||||
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||
|
||||
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||
@ -511,7 +616,7 @@ class Api:
|
||||
if shared.state.current_image and not req.skip_current_image:
|
||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)
|
||||
|
||||
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||
image_b64 = interrogatereq.image
|
||||
@ -643,6 +748,10 @@ class Api:
|
||||
"skipped": convert_embeddings(db.skipped_embeddings),
|
||||
}
|
||||
|
||||
def refresh_embeddings(self):
|
||||
with self.queue_lock:
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||
|
||||
def refresh_checkpoints(self):
|
||||
with self.queue_lock:
|
||||
shared.refresh_checkpoints()
|
||||
@ -775,7 +884,15 @@ class Api:
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||
uvicorn.run(
|
||||
self.app,
|
||||
host=server_name,
|
||||
port=port,
|
||||
timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
|
||||
root_path=root_path,
|
||||
ssl_keyfile=shared.cmd_opts.tls_keyfile,
|
||||
ssl_certfile=shared.cmd_opts.tls_certfile
|
||||
)
|
||||
|
||||
def kill_webui(self):
|
||||
restart.stop_program()
|
||||
|
@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str, "default": None},
|
||||
{"key": "infotext", "type": str, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
{"key": "send_images", "type": bool, "default": True},
|
||||
{"key": "save_images", "type": bool, "default": False},
|
||||
{"key": "alwayson_scripts", "type": dict, "default": {}},
|
||||
{"key": "force_task_id", "type": str, "default": None},
|
||||
{"key": "infotext", "type": str, "default": None},
|
||||
]
|
||||
).generate_model()
|
||||
|
||||
|
@ -62,12 +62,11 @@ def cache(subsection):
|
||||
if cache_data is None:
|
||||
with cache_lock:
|
||||
if cache_data is None:
|
||||
if not os.path.isfile(cache_filename):
|
||||
cache_data = {}
|
||||
else:
|
||||
try:
|
||||
with open(cache_filename, "r", encoding="utf8") as file:
|
||||
cache_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
cache_data = {}
|
||||
except Exception:
|
||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
|
||||
|
@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
|
||||
shared.state.skipped = False
|
||||
shared.state.interrupted = False
|
||||
shared.state.stopping_generation = False
|
||||
shared.state.job_count = 0
|
||||
|
||||
if not add_stats:
|
||||
@ -99,8 +100,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
||||
sys_pct = sys_peak/max(sys_total, 1) * 100
|
||||
|
||||
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
||||
toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
|
||||
toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
|
||||
toltip_r = "Reserved: total amount of video memory allocated by the Torch library "
|
||||
toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
|
||||
|
||||
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
||||
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
||||
|
@ -1,7 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
from modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@ -19,21 +19,21 @@ parser.add_argument("--skip-install", action='store_true', help="launch.py argum
|
||||
parser.add_argument("--dump-sysinfo", action='store_true', help="launch.py argument: dump limited sysinfo file (without information about extensions, options) to disk and quit")
|
||||
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
|
||||
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
|
||||
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--data-dir", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
|
||||
parser.add_argument("--config", type=normalized_filepath, default=sd_default_config, help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=normalized_filepath, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=normalized_filepath, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--vae-dir", type=normalized_filepath, default=None, help="Path to directory with VAE files")
|
||||
parser.add_argument("--gfpgan-dir", type=normalized_filepath, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=normalized_filepath, help="GFPGAN model file name", default=None)
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
|
||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
|
||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--embeddings-dir", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
|
||||
parser.add_argument("--textual-inversion-templates-dir", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
|
||||
parser.add_argument("--hypernetwork-dir", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
|
||||
parser.add_argument("--localizations-dir", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help="localizations directory")
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
|
||||
@ -48,12 +48,13 @@ parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to g
|
||||
parser.add_argument("--ngrok-region", type=str, help="does not do anything.", default="")
|
||||
parser.add_argument("--ngrok-options", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \'{"authtoken_from_env":true, "basic_auth":"user:password", "oauth_provider":"google", "oauth_allow_emails":"user@asdf.com"}\'', default=dict())
|
||||
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--codeformer-models-path", type=normalized_filepath, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=normalized_filepath, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=normalized_filepath, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=normalized_filepath, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=normalized_filepath, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--dat-models-path", type=normalized_filepath, help="Path to directory with DAT model file(s).", default=os.path.join(models_path, 'DAT'))
|
||||
parser.add_argument("--clip-models-path", type=normalized_filepath, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
|
||||
@ -77,22 +78,24 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
|
||||
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
|
||||
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
|
||||
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
|
||||
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
|
||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
|
||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-auth-path", type=normalized_filepath, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
|
||||
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
|
||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
||||
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
@ -118,4 +121,6 @@ parser.add_argument('--api-server-stop', action='store_true', help='enable serve
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
||||
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
||||
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
||||
|
@ -1,276 +0,0 @@
|
||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
|
||||
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
|
||||
def adaptive_instance_normalization(content_feat, style_feat):
|
||||
"""Adaptive instance normalization.
|
||||
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is None:
|
||||
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransformerSALayer(nn.Module):
|
||||
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model - MLP
|
||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
||||
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward(self, tgt,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
# ffn
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
class Fuse_sft_block(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
||||
|
||||
self.scale = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
self.shift = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, enc_feat, dec_feat, w=1):
|
||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
||||
scale = self.scale(enc_feat)
|
||||
shift = self.shift(enc_feat)
|
||||
residual = w * (dec_feat * scale + shift)
|
||||
out = dec_feat + residual
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=('32', '64', '128', '256'),
|
||||
fix_modules=('quantize', 'generator')):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.connect_list = connect_list
|
||||
self.n_layers = n_layers
|
||||
self.dim_embd = dim_embd
|
||||
self.dim_mlp = dim_embd*2
|
||||
|
||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
'64': 256,
|
||||
'128': 128,
|
||||
'256': 128,
|
||||
'512': 64,
|
||||
}
|
||||
|
||||
# after second residual block for > 16, before attn layer for ==16
|
||||
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
||||
# after first residual block for > 16, before attn layer for ==16
|
||||
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
||||
|
||||
# fuse_convs_dict
|
||||
self.fuse_convs_dict = nn.ModuleDict()
|
||||
for f_size in self.connect_list:
|
||||
in_ch = self.channels[f_size]
|
||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
||||
# ################### Encoder #####################
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
lq_feat = x
|
||||
# ################# Transformer ###################
|
||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
||||
# BCHW -> BC(HW) -> (HW)BC
|
||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
||||
query_emb = feat_emb
|
||||
# Transformer encoder
|
||||
for layer in self.ft_layers:
|
||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
||||
|
||||
# output logits
|
||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
||||
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
||||
|
||||
if code_only: # for training stage II
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return logits, lq_feat
|
||||
|
||||
# ################# Quantization ###################
|
||||
# if self.training:
|
||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
||||
# # b(hw)c -> bc(hw) -> bchw
|
||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
||||
# ------------
|
||||
soft_one_hot = F.softmax(logits, dim=2)
|
||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
||||
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
||||
# preserve gradients
|
||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
||||
|
||||
if detach_16:
|
||||
quant_feat = quant_feat.detach() # for training stage III
|
||||
if adain:
|
||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
||||
|
||||
# ################## Generator ####################
|
||||
x = quant_feat
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
@ -1,435 +0,0 @@
|
||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||
|
||||
'''
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish(x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
||||
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.emb_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
||||
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
mean_distance = torch.mean(d)
|
||||
# find closest encodings
|
||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
||||
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
||||
# [0-1], higher score, higher confidence
|
||||
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
||||
|
||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, {
|
||||
"perplexity": perplexity,
|
||||
"min_encodings": min_encodings,
|
||||
"min_encoding_indices": min_encoding_indices,
|
||||
"min_encoding_scores": min_encoding_scores,
|
||||
"mean_distance": mean_distance
|
||||
}
|
||||
|
||||
def get_codebook_feat(self, indices, shape):
|
||||
# input indices: batch*token_num -> (batch*token_num)*1
|
||||
# shape: batch, height, width, channel
|
||||
indices = indices.view(-1,1)
|
||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
||||
min_encodings.scatter_(1, indices, 1)
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None: # reshape back to match original input shape
|
||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
||||
|
||||
def forward(self, z):
|
||||
hard = self.straight_through if self.training else True
|
||||
|
||||
logits = self.proj(z)
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
||||
|
||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
||||
|
||||
return z_q, diff, {
|
||||
"min_encoding_indices": min_encoding_indices
|
||||
}
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(ResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm1 = normalize(in_channels)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = normalize(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x_in):
|
||||
x = x_in
|
||||
x = self.norm1(x)
|
||||
x = swish(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = swish(x)
|
||||
x = self.conv2(x)
|
||||
if self.in_channels != self.out_channels:
|
||||
x_in = self.conv_out(x_in)
|
||||
|
||||
return x + x_in
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h*w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h*w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = F.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h*w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.attn_resolutions = attn_resolutions
|
||||
|
||||
curr_res = self.resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
|
||||
blocks = []
|
||||
# initial convultion
|
||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
||||
for i in range(self.num_resolutions):
|
||||
block_in_ch = nf * in_ch_mult[i]
|
||||
block_out_ch = nf * ch_mult[i]
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
if curr_res in attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != self.num_resolutions - 1:
|
||||
blocks.append(Downsample(block_in_ch))
|
||||
curr_res = curr_res // 2
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
# normalise and convert to latent size
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.num_resolutions = len(self.ch_mult)
|
||||
self.num_res_blocks = res_blocks
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.in_channels = emb_dim
|
||||
self.out_channels = 3
|
||||
block_in_ch = self.nf * self.ch_mult[-1]
|
||||
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
||||
|
||||
blocks = []
|
||||
# initial conv
|
||||
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
for i in reversed(range(self.num_resolutions)):
|
||||
block_out_ch = self.nf * self.ch_mult[i]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
|
||||
if curr_res in self.attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != 0:
|
||||
blocks.append(Upsample(block_in_ch))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.codebook_size = codebook_size
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions or [16]
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
if self.quantizer_type == "nearest":
|
||||
self.beta = beta #0.25
|
||||
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
||||
elif self.quantizer_type == "gumbel":
|
||||
self.gumbel_num_hiddens = emb_dim
|
||||
self.straight_through = gumbel_straight_through
|
||||
self.kl_weight = gumbel_kl_weight
|
||||
self.quantize = GumbelQuantizer(
|
||||
self.codebook_size,
|
||||
self.embed_dim,
|
||||
self.gumbel_num_hiddens,
|
||||
self.straight_through,
|
||||
self.kl_weight
|
||||
)
|
||||
self.generator = Generator(
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_ema' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
||||
else:
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
||||
x = self.generator(quant)
|
||||
return x, codebook_loss, quant_stats
|
||||
|
||||
|
||||
|
||||
# patch based discriminator
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQGANDiscriminator(nn.Module):
|
||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
||||
super().__init__()
|
||||
|
||||
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
||||
ndf_mult = 1
|
||||
ndf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n, 8)
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n_layers, 8)
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_d' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
else:
|
||||
raise ValueError('Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
@ -1,132 +1,64 @@
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader, errors
|
||||
from modules.paths import models_path
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||
# I am making a choice to include some files from codeformer to work around this issue.
|
||||
model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_download_name = 'codeformer-v0.1.0.pth'
|
||||
|
||||
codeformer = None
|
||||
# used by e.g. postprocessing_codeformer.py
|
||||
codeformer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
path = modules.paths.paths.get("CodeFormer", None)
|
||||
if path is None:
|
||||
return
|
||||
|
||||
try:
|
||||
from torchvision.transforms.functional import normalize
|
||||
from modules.codeformer.codeformer_arch import CodeFormer
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from facelib.detection.retinaface import retinaface
|
||||
|
||||
net_class = CodeFormer
|
||||
|
||||
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
||||
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def __init__(self, dirname):
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
self.cmd_dir = dirname
|
||||
def load_net(self) -> torch.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
expected_architecture='CodeFormer',
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
def create_models(self):
|
||||
def get_device(self):
|
||||
return devices.device_codeformer
|
||||
|
||||
if self.net is not None and self.face_helper is not None:
|
||||
self.net.to(devices.device_codeformer)
|
||||
return self.net, self.face_helper
|
||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
|
||||
if len(model_paths) != 0:
|
||||
ckpt_path = model_paths[0]
|
||||
else:
|
||||
print("Unable to load codeformer model.")
|
||||
return None, None
|
||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
def restore(self, np_image, w: float | None = None):
|
||||
if w is None:
|
||||
w = getattr(shared.opts, "code_former_weight", 0.5)
|
||||
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = devices.device_codeformer
|
||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, w=w, adain=True)[0]
|
||||
|
||||
self.net = net
|
||||
self.face_helper = face_helper
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
return net, face_helper
|
||||
|
||||
def send_model_to(self, device):
|
||||
self.net.to(device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, w=None):
|
||||
np_image = np_image[:, :, ::-1]
|
||||
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
self.create_models()
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.send_model_to(devices.device_codeformer)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
|
||||
restored_img = self.face_helper.paste_faces_to_input_image()
|
||||
restored_img = restored_img[:, :, ::-1]
|
||||
|
||||
if original_resolution != restored_img.shape[0:2]:
|
||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
return restored_img
|
||||
|
||||
def setup_model(dirname: str) -> None:
|
||||
global codeformer
|
||||
try:
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
|
||||
except Exception:
|
||||
errors.report("Error setting up CodeFormer", exc_info=True)
|
||||
|
||||
# sys.path = stored_sys_path
|
||||
|
79
modules/dat_model.py
Normal file
79
modules/dat_model.py
Normal file
@ -0,0 +1,79 @@
|
||||
import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerDAT(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "DAT"
|
||||
self.user_path = user_path
|
||||
self.scalers = []
|
||||
super().__init__()
|
||||
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
for model in get_dat_models(self):
|
||||
if model.name in opts.dat_enabled_models:
|
||||
self.scalers.append(model)
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
try:
|
||||
info = self.load_model(path)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load DAT model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="DAT",
|
||||
)
|
||||
return upscale_with_model(
|
||||
model_descriptor,
|
||||
img,
|
||||
tile_size=opts.DAT_tile,
|
||||
tile_overlap=opts.DAT_tile_overlap,
|
||||
)
|
||||
|
||||
def load_model(self, path):
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
)
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
|
||||
def get_dat_models(scaler):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="DAT x2",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x3",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
|
||||
scale=3,
|
||||
upscaler=scaler,
|
||||
),
|
||||
UpscalerData(
|
||||
name="DAT x4",
|
||||
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
),
|
||||
]
|
@ -3,7 +3,7 @@ import contextlib
|
||||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors, shared
|
||||
from modules import errors, shared, npu_specific
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
@ -23,6 +23,23 @@ def has_mps() -> bool:
|
||||
return mac_specific.has_mps
|
||||
|
||||
|
||||
def cuda_no_autocast(device_id=None) -> bool:
|
||||
if device_id is None:
|
||||
device_id = get_cuda_device_id()
|
||||
return (
|
||||
torch.cuda.get_device_capability(device_id) == (7, 5)
|
||||
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_device_id():
|
||||
return (
|
||||
int(shared.cmd_opts.device_id)
|
||||
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
|
||||
else 0
|
||||
) or torch.cuda.current_device()
|
||||
|
||||
|
||||
def get_cuda_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"cuda:{shared.cmd_opts.device_id}"
|
||||
@ -40,6 +57,9 @@ def get_optimal_device_name():
|
||||
if has_xpu():
|
||||
return xpu_specific.get_xpu_device_string()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
return npu_specific.get_npu_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
@ -67,14 +87,23 @@ def torch_gc():
|
||||
if has_xpu():
|
||||
xpu_specific.torch_xpu_gc()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
torch_npu_set_device()
|
||||
npu_specific.torch_npu_gc()
|
||||
|
||||
|
||||
def torch_npu_set_device():
|
||||
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||
if npu_specific.has_npu:
|
||||
torch.npu.set_device(0)
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
||||
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
||||
if cuda_no_autocast():
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@ -84,6 +113,7 @@ def enable_tf32():
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = None
|
||||
device_interrogate: torch.device = None
|
||||
device_gfpgan: torch.device = None
|
||||
@ -92,6 +122,7 @@ device_codeformer: torch.device = None
|
||||
dtype: torch.dtype = torch.float16
|
||||
dtype_vae: torch.dtype = torch.float16
|
||||
dtype_unet: torch.dtype = torch.float16
|
||||
dtype_inference: torch.dtype = torch.float16
|
||||
unet_needs_upcast = False
|
||||
|
||||
|
||||
@ -104,15 +135,89 @@ def cond_cast_float(input):
|
||||
|
||||
|
||||
nv_rng = None
|
||||
patch_module_list = [
|
||||
torch.nn.Linear,
|
||||
torch.nn.Conv2d,
|
||||
torch.nn.MultiheadAttention,
|
||||
torch.nn.GroupNorm,
|
||||
torch.nn.LayerNorm,
|
||||
]
|
||||
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
if any(
|
||||
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||
for arg in args
|
||||
):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
||||
org_dtype = target_dtype
|
||||
for param in self.parameters():
|
||||
if param.dtype != target_dtype:
|
||||
org_dtype = param.dtype
|
||||
break
|
||||
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
if org_dtype != target_dtype:
|
||||
self.to(org_dtype)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
if isinstance(result, tuple):
|
||||
result = tuple(
|
||||
i.to(dtype_inference)
|
||||
if isinstance(i, torch.Tensor)
|
||||
else i
|
||||
for i in result
|
||||
)
|
||||
elif isinstance(result, torch.Tensor):
|
||||
result = result.to(dtype_inference)
|
||||
return result
|
||||
return forward_wrapper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_cast(target_dtype):
|
||||
applied = False
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
continue
|
||||
applied = True
|
||||
org_forward = module_type.forward
|
||||
if module_type == torch.nn.MultiheadAttention:
|
||||
module_type.forward = manual_cast_forward(torch.float32)
|
||||
else:
|
||||
module_type.forward = manual_cast_forward(target_dtype)
|
||||
module_type.org_forward = org_forward
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if applied:
|
||||
for module_type in patch_module_list:
|
||||
if hasattr(module_type, "org_forward"):
|
||||
module_type.forward = module_type.org_forward
|
||||
delattr(module_type, "org_forward")
|
||||
|
||||
|
||||
def autocast(disable=False):
|
||||
if disable:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||
if fp8 and device==cpu:
|
||||
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
|
||||
|
||||
if fp8 and dtype_inference == torch.float32:
|
||||
return manual_cast(dtype)
|
||||
|
||||
if dtype == torch.float32 or dtype_inference == torch.float32:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
if has_xpu() or has_mps() or cuda_no_autocast():
|
||||
return manual_cast(dtype)
|
||||
|
||||
return torch.autocast("cuda")
|
||||
|
||||
|
||||
@ -154,7 +259,7 @@ def test_for_nans(x, where):
|
||||
def first_time_calculation():
|
||||
"""
|
||||
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
||||
spends about 2.7 seconds doing that, at least wih NVidia.
|
||||
spends about 2.7 seconds doing that, at least with NVidia.
|
||||
"""
|
||||
|
||||
x = torch.zeros((1, 1)).to(device, dtype)
|
||||
@ -164,4 +269,3 @@ def first_time_calculation():
|
||||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||
conv2d(x)
|
||||
|
||||
|
@ -107,8 +107,8 @@ def check_versions():
|
||||
import torch
|
||||
import gradio
|
||||
|
||||
expected_torch_version = "2.0.0"
|
||||
expected_xformers_version = "0.0.20"
|
||||
expected_torch_version = "2.1.2"
|
||||
expected_xformers_version = "0.0.23.post1"
|
||||
expected_gradio_version = "3.41.2"
|
||||
|
||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||
|
@ -1,121 +1,7 @@
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import modelloader, images, devices
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
def mod2normal(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if 'conv_first.weight' in state_dict:
|
||||
crt_net = {}
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
for k in items.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
||||
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
||||
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
||||
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
||||
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
||||
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
||||
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def resrgan2normal(state_dict, nb=23):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||
re8x = 0
|
||||
crt_net = {}
|
||||
items = list(state_dict)
|
||||
|
||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||
|
||||
for k in items.copy():
|
||||
if "rdb" in k:
|
||||
ori_k = k.replace('body.', 'model.1.sub.')
|
||||
ori_k = ori_k.replace('.rdb', '.RDB')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[ori_k] = state_dict[k]
|
||||
items.remove(k)
|
||||
|
||||
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
||||
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
||||
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
||||
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
||||
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
||||
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
||||
|
||||
if 'conv_up3.weight' in state_dict:
|
||||
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
||||
re8x = 3
|
||||
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
||||
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
||||
|
||||
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
||||
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
||||
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
||||
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
||||
|
||||
state_dict = crt_net
|
||||
return state_dict
|
||||
|
||||
|
||||
def infer_params(state_dict):
|
||||
# this code is copied from https://github.com/victorca25/iNNfer
|
||||
scale2x = 0
|
||||
scalemin = 6
|
||||
n_uplayer = 0
|
||||
plus = False
|
||||
|
||||
for block in list(state_dict):
|
||||
parts = block.split(".")
|
||||
n_parts = len(parts)
|
||||
if n_parts == 5 and parts[2] == "sub":
|
||||
nb = int(parts[3])
|
||||
elif n_parts == 3:
|
||||
part_num = int(parts[1])
|
||||
if (part_num > scalemin
|
||||
and parts[0] == "model"
|
||||
and parts[2] == "weight"):
|
||||
scale2x += 1
|
||||
if part_num > n_uplayer:
|
||||
n_uplayer = part_num
|
||||
out_nc = state_dict[block].shape[0]
|
||||
if not plus and "conv1x1" in block:
|
||||
plus = True
|
||||
|
||||
nf = state_dict["model.0.weight"].shape[0]
|
||||
in_nc = state_dict["model.0.weight"].shape[1]
|
||||
out_nc = out_nc
|
||||
scale = 2 ** scale2x
|
||||
|
||||
return in_nc, out_nc, nf, nb, plus, scale
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
|
||||
def do_upscale(self, img, selected_model):
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||
except Exception:
|
||||
errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
|
||||
return img
|
||||
model.to(devices.device_esrgan)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
return esrgan_upscale(model, img)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if path.startswith("http"):
|
||||
@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
|
||||
else:
|
||||
filename = path
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
if "params_ema" in state_dict:
|
||||
state_dict = state_dict["params_ema"]
|
||||
elif "params" in state_dict:
|
||||
state_dict = state_dict["params"]
|
||||
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
||||
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
||||
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
||||
state_dict = resrgan2normal(state_dict, nb)
|
||||
elif "conv_first.weight" in state_dict:
|
||||
state_dict = mod2normal(state_dict)
|
||||
elif "model.0.weight" not in state_dict:
|
||||
raise Exception("The file is not a recognized ESRGAN model.")
|
||||
|
||||
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
||||
|
||||
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
return Image.fromarray(output, 'RGB')
|
||||
return modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||
expected_architecture='ESRGAN',
|
||||
)
|
||||
|
||||
|
||||
def esrgan_upscale(model, img):
|
||||
if opts.ESRGAN_tile == 0:
|
||||
return upscale_without_tiling(model, img)
|
||||
|
||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
newtiles = []
|
||||
scale_factor = 1
|
||||
|
||||
for y, h, row in grid.tiles:
|
||||
newrow = []
|
||||
for tiledata in row:
|
||||
x, w, tile = tiledata
|
||||
|
||||
output = upscale_without_tiling(model, tile)
|
||||
scale_factor = output.width // tile.width
|
||||
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = images.combine_grid(newgrid)
|
||||
return output
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile,
|
||||
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
@ -1,465 +0,0 @@
|
||||
# this file is adapted from https://github.com/victorca25/iNNfer
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
####################
|
||||
# RRDBNet Generator
|
||||
####################
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
||||
finalact=None, gaussian_noise=False, plus=False):
|
||||
super(RRDBNet, self).__init__()
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
self.resrgan_scale = 0
|
||||
if in_nc % 16 == 0:
|
||||
self.resrgan_scale = 1
|
||||
elif in_nc != 4 and in_nc % 4 == 0:
|
||||
self.resrgan_scale = 2
|
||||
|
||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
||||
|
||||
if upsample_mode == 'upconv':
|
||||
upsample_block = upconv_block
|
||||
elif upsample_mode == 'pixelshuffle':
|
||||
upsample_block = pixelshuffle_block
|
||||
else:
|
||||
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
||||
if upscale == 3:
|
||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||
else:
|
||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
||||
|
||||
outact = act(finalact) if finalact else None
|
||||
|
||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
||||
*upsampler, HR_conv0, HR_conv1, outact)
|
||||
|
||||
def forward(self, x, outm=None):
|
||||
if self.resrgan_scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
elif self.resrgan_scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
else:
|
||||
feat = x
|
||||
|
||||
return self.model(feat)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""
|
||||
Residual in Residual Dense Block
|
||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
||||
"""
|
||||
|
||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(RRDB, self).__init__()
|
||||
# This is for backwards compatibility with existing models
|
||||
if nr == 3:
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus)
|
||||
else:
|
||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
||||
self.RDBs = nn.Sequential(*RDB_list)
|
||||
|
||||
def forward(self, x):
|
||||
if hasattr(self, 'RDB1'):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
else:
|
||||
out = self.RDBs(x)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
"""
|
||||
Residual Dense Block
|
||||
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
||||
Modified options that can be used:
|
||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||
- "Spectral normalization" arXiv:1802.05957
|
||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||
{Rakotonirina} and A. {Rasoanaivo}
|
||||
"""
|
||||
|
||||
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
self.noise = GaussianNoise() if gaussian_noise else None
|
||||
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
||||
|
||||
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
if mode == 'CNA':
|
||||
last_act = None
|
||||
else:
|
||||
last_act = act_type
|
||||
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
|
||||
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
|
||||
spectral_norm=spectral_norm)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv1(x)
|
||||
x2 = self.conv2(torch.cat((x, x1), 1))
|
||||
if self.conv1x1:
|
||||
x2 = x2 + self.conv1x1(x)
|
||||
x3 = self.conv3(torch.cat((x, x1, x2), 1))
|
||||
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
|
||||
if self.conv1x1:
|
||||
x4 = x4 + x2
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
if self.noise:
|
||||
return self.noise(x5.mul(0.2) + x)
|
||||
else:
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
####################
|
||||
# ESRGANplus
|
||||
####################
|
||||
|
||||
class GaussianNoise(nn.Module):
|
||||
def __init__(self, sigma=0.1, is_relative_detach=False):
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.is_relative_detach = is_relative_detach
|
||||
self.noise = torch.tensor(0, dtype=torch.float)
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.sigma != 0:
|
||||
self.noise = self.noise.to(x.device)
|
||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||
x = x + sampled_noise
|
||||
return x
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
####################
|
||||
# SRVGGNetCompact
|
||||
####################
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
"""A compact VGG-style network structure for super-resolution.
|
||||
This class is copied from https://github.com/xinntao/Real-ESRGAN
|
||||
"""
|
||||
|
||||
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
|
||||
super(SRVGGNetCompact, self).__init__()
|
||||
self.num_in_ch = num_in_ch
|
||||
self.num_out_ch = num_out_ch
|
||||
self.num_feat = num_feat
|
||||
self.num_conv = num_conv
|
||||
self.upscale = upscale
|
||||
self.act_type = act_type
|
||||
|
||||
self.body = nn.ModuleList()
|
||||
# the first conv
|
||||
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
|
||||
# the first activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the body structure
|
||||
for _ in range(num_conv):
|
||||
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
|
||||
# activation
|
||||
if act_type == 'relu':
|
||||
activation = nn.ReLU(inplace=True)
|
||||
elif act_type == 'prelu':
|
||||
activation = nn.PReLU(num_parameters=num_feat)
|
||||
elif act_type == 'leakyrelu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.body.append(activation)
|
||||
|
||||
# the last conv
|
||||
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
|
||||
# upsample
|
||||
self.upsampler = nn.PixelShuffle(upscale)
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
for i in range(0, len(self.body)):
|
||||
out = self.body[i](out)
|
||||
|
||||
out = self.upsampler(out)
|
||||
# add the nearest upsampled image, so that the network learns the residual
|
||||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
||||
####################
|
||||
# Upsampler
|
||||
####################
|
||||
|
||||
class Upsample(nn.Module):
|
||||
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
|
||||
The input data is assumed to be of the form
|
||||
`minibatch x channels x [optional depth] x [optional height] x width`.
|
||||
"""
|
||||
|
||||
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.size = size
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def extra_repr(self):
|
||||
if self.scale_factor is not None:
|
||||
info = f'scale_factor={self.scale_factor}'
|
||||
else:
|
||||
info = f'size={self.size}'
|
||||
info += f', mode={self.mode}'
|
||||
return info
|
||||
|
||||
|
||||
def pixel_unshuffle(x, scale):
|
||||
""" Pixel unshuffle.
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
|
||||
|
||||
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
|
||||
"""
|
||||
Pixel shuffle layer
|
||||
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
||||
Neural Network, CVPR17)
|
||||
"""
|
||||
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
a = act(act_type) if act_type else None
|
||||
return sequential(conv, pixel_shuffle, n, a)
|
||||
|
||||
|
||||
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
|
||||
""" Upconv layer """
|
||||
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
|
||||
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
|
||||
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
|
||||
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####################
|
||||
# Basic blocks
|
||||
####################
|
||||
|
||||
|
||||
def make_layer(basic_block, num_basic_block, **kwarg):
|
||||
"""Make layers by stacking the same blocks.
|
||||
Args:
|
||||
basic_block (nn.module): nn.module class for basic block. (block)
|
||||
num_basic_block (int): number of blocks. (n_layers)
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
||||
""" activation helper """
|
||||
act_type = act_type.lower()
|
||||
if act_type == 'relu':
|
||||
layer = nn.ReLU(inplace)
|
||||
elif act_type in ('leakyrelu', 'lrelu'):
|
||||
layer = nn.LeakyReLU(neg_slope, inplace)
|
||||
elif act_type == 'prelu':
|
||||
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
||||
elif act_type == 'tanh': # [-1, 1] range output
|
||||
layer = nn.Tanh()
|
||||
elif act_type == 'sigmoid': # [0, 1] range output
|
||||
layer = nn.Sigmoid()
|
||||
else:
|
||||
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
def __init__(self, *kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x, *kwargs):
|
||||
return x
|
||||
|
||||
|
||||
def norm(norm_type, nc):
|
||||
""" Return a normalization layer """
|
||||
norm_type = norm_type.lower()
|
||||
if norm_type == 'batch':
|
||||
layer = nn.BatchNorm2d(nc, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
layer = nn.InstanceNorm2d(nc, affine=False)
|
||||
elif norm_type == 'none':
|
||||
def norm_layer(x): return Identity()
|
||||
else:
|
||||
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
||||
return layer
|
||||
|
||||
|
||||
def pad(pad_type, padding):
|
||||
""" padding layer helper """
|
||||
pad_type = pad_type.lower()
|
||||
if padding == 0:
|
||||
return None
|
||||
if pad_type == 'reflect':
|
||||
layer = nn.ReflectionPad2d(padding)
|
||||
elif pad_type == 'replicate':
|
||||
layer = nn.ReplicationPad2d(padding)
|
||||
elif pad_type == 'zero':
|
||||
layer = nn.ZeroPad2d(padding)
|
||||
else:
|
||||
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
||||
return layer
|
||||
|
||||
|
||||
def get_valid_padding(kernel_size, dilation):
|
||||
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
||||
padding = (kernel_size - 1) // 2
|
||||
return padding
|
||||
|
||||
|
||||
class ShortcutBlock(nn.Module):
|
||||
""" Elementwise sum the output of a submodule to its input """
|
||||
def __init__(self, submodule):
|
||||
super(ShortcutBlock, self).__init__()
|
||||
self.sub = submodule
|
||||
|
||||
def forward(self, x):
|
||||
output = x + self.sub(x)
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
|
||||
|
||||
|
||||
def sequential(*args):
|
||||
""" Flatten Sequential. It unwraps nn.Sequential. """
|
||||
if len(args) == 1:
|
||||
if isinstance(args[0], OrderedDict):
|
||||
raise NotImplementedError('sequential does not support OrderedDict input.')
|
||||
return args[0] # No sequential is needed.
|
||||
modules = []
|
||||
for module in args:
|
||||
if isinstance(module, nn.Sequential):
|
||||
for submodule in module.children():
|
||||
modules.append(submodule)
|
||||
elif isinstance(module, nn.Module):
|
||||
modules.append(module)
|
||||
return nn.Sequential(*modules)
|
||||
|
||||
|
||||
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
|
||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||
spectral_norm=False):
|
||||
""" Conv layer with padding, normalization, activation """
|
||||
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||
padding = padding if pad_type == 'zero' else 0
|
||||
|
||||
if convtype=='PartialConv2D':
|
||||
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='DeformConv2D':
|
||||
from torchvision.ops import DeformConv2d # not tested
|
||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
elif convtype=='Conv3D':
|
||||
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
else:
|
||||
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
dilation=dilation, bias=bias, groups=groups)
|
||||
|
||||
if spectral_norm:
|
||||
c = nn.utils.spectral_norm(c)
|
||||
|
||||
a = act(act_type) if act_type else None
|
||||
if 'CNA' in mode:
|
||||
n = norm(norm_type, out_nc) if norm_type else None
|
||||
return sequential(p, c, n, a)
|
||||
elif mode == 'NAC':
|
||||
if norm_type is None and act_type is not None:
|
||||
a = act(act_type, inplace=False)
|
||||
n = norm(norm_type, in_nc) if norm_type else None
|
||||
return sequential(n, a, p, c)
|
@ -32,7 +32,8 @@ class ExtensionMetadata:
|
||||
self.config = configparser.ConfigParser()
|
||||
|
||||
filepath = os.path.join(path, self.filename)
|
||||
if os.path.isfile(filepath):
|
||||
# `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
|
||||
# so no need to check whether the file exists beforehand.
|
||||
try:
|
||||
self.config.read(filepath)
|
||||
except Exception:
|
||||
@ -223,13 +224,16 @@ def list_extensions():
|
||||
|
||||
# check for requirements
|
||||
for extension in extensions:
|
||||
if not extension.enabled:
|
||||
continue
|
||||
|
||||
for req in extension.metadata.requires:
|
||||
required_extension = loaded_extensions.get(req)
|
||||
if required_extension is None:
|
||||
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||
continue
|
||||
|
||||
if not extension.enabled:
|
||||
if not required_extension.enabled:
|
||||
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||
continue
|
||||
|
||||
|
@ -60,7 +60,7 @@ class ExtraNetwork:
|
||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||
separated by colon.
|
||||
|
||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
|
||||
in this case, all effects of this extra networks should be disabled.
|
||||
|
||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||
@ -206,7 +206,7 @@ def parse_prompts(prompts):
|
||||
return res, extra_data
|
||||
|
||||
|
||||
def get_user_metadata(filename):
|
||||
def get_user_metadata(filename, lister=None):
|
||||
if filename is None:
|
||||
return {}
|
||||
|
||||
@ -215,7 +215,8 @@ def get_user_metadata(filename):
|
||||
|
||||
metadata = {}
|
||||
try:
|
||||
if os.path.isfile(metadata_filename):
|
||||
exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
|
||||
if exists:
|
||||
with open(metadata_filename, "r", encoding="utf8") as file:
|
||||
metadata = json.load(file)
|
||||
except Exception as e:
|
||||
|
180
modules/face_restoration_utils.py
Normal file
180
modules/face_restoration_utils.py
Normal file
@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
|
||||
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
|
||||
assert img.shape[2] == 3, "image must be RGB"
|
||||
if img.dtype == "float64":
|
||||
img = img.astype("float32")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return torch.from_numpy(img.transpose(2, 0, 1)).float()
|
||||
|
||||
|
||||
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
|
||||
"""
|
||||
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
|
||||
"""
|
||||
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
|
||||
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
|
||||
assert tensor.dim() == 3, "tensor must be RGB"
|
||||
img_np = tensor.numpy().transpose(1, 2, 0)
|
||||
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
|
||||
return np.squeeze(img_np, axis=2)
|
||||
return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def create_face_helper(device) -> FaceRestoreHelper:
|
||||
from facexlib.detection import retinaface
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = device
|
||||
return FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def restore_with_face_helper(
|
||||
np_image: np.ndarray,
|
||||
face_helper: FaceRestoreHelper,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
|
||||
|
||||
`restore_face` should take a cropped face image and return a restored face image.
|
||||
"""
|
||||
from torchvision.transforms.functional import normalize
|
||||
np_image = np_image[:, :, ::-1]
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
try:
|
||||
logger.debug("Detecting faces...")
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np_image)
|
||||
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
|
||||
for cropped_face in face_helper.cropped_faces:
|
||||
cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
cropped_face_t = restore_face(cropped_face_t)
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed face-restoration inference', exc_info=True)
|
||||
|
||||
restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
|
||||
restored_face = (restored_face * 255.0).astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
logger.debug("Merging restored faces into image")
|
||||
face_helper.get_inverse_affine(None)
|
||||
img = face_helper.paste_faces_to_input_image()
|
||||
img = img[:, :, ::-1]
|
||||
if original_resolution != img.shape[0:2]:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(0, 0),
|
||||
fx=original_resolution[1] / img.shape[1],
|
||||
fy=original_resolution[0] / img.shape[0],
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
logger.debug("Face restoration complete")
|
||||
finally:
|
||||
face_helper.clean_all()
|
||||
return img
|
||||
|
||||
|
||||
class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
net: torch.Module | None
|
||||
model_url: str
|
||||
model_download_name: str
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__()
|
||||
self.net = None
|
||||
self.model_path = model_path
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
@cached_property
|
||||
def face_helper(self) -> FaceRestoreHelper:
|
||||
return create_face_helper(self.get_device())
|
||||
|
||||
def send_model_to(self, device):
|
||||
if self.net:
|
||||
logger.debug("Sending %s to %s", self.net, device)
|
||||
self.net.to(device)
|
||||
if self.face_helper:
|
||||
logger.debug("Sending face helper to %s", device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def get_device(self):
|
||||
raise NotImplementedError("get_device must be implemented by subclasses")
|
||||
|
||||
def load_net(self) -> torch.Module:
|
||||
raise NotImplementedError("load_net must be implemented by subclasses")
|
||||
|
||||
def restore_with_helper(
|
||||
self,
|
||||
np_image: np.ndarray,
|
||||
restore_face: Callable[[torch.Tensor], torch.Tensor],
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
if self.net is None:
|
||||
self.net = self.load_net()
|
||||
except Exception:
|
||||
logger.warning("Unable to load face-restoration model", exc_info=True)
|
||||
return np_image
|
||||
|
||||
try:
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
|
||||
def patch_facexlib(dirname: str) -> None:
|
||||
import facexlib.detection
|
||||
import facexlib.parsing
|
||||
|
||||
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
|
||||
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
|
||||
|
||||
def update_kwargs(kwargs):
|
||||
return dict(kwargs, save_dir=dirname, model_dir=None)
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return det_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return par_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
@ -1,125 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import facexlib
|
||||
import gfpgan
|
||||
import torch
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import paths, shared, devices, modelloader, errors
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_file_path = None
|
||||
logger = logging.getLogger(__name__)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
model_download_name = "GFPGANv1.4.pth"
|
||||
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
global model_file_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||
return loaded_gfpgan_model
|
||||
|
||||
if gfpgan_constructor is None:
|
||||
return None
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||
|
||||
if len(models) == 1 and models[0].startswith("http"):
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
gfp_models = []
|
||||
for item in models:
|
||||
if 'GFPGAN' in os.path.basename(item):
|
||||
gfp_models.append(item)
|
||||
latest_file = max(gfp_models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model_file_path = model_file
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def send_model_to(model, device):
|
||||
model.gfpgan.to(device)
|
||||
model.face_helper.face_det.to(device)
|
||||
model.face_helper.face_parse.to(device)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
|
||||
send_model_to(model, devices.device_gfpgan)
|
||||
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
model.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
send_model_to(model, devices.cpu)
|
||||
|
||||
return np_image
|
||||
|
||||
|
||||
gfpgan_constructor = None
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
try:
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
from gfpgan import GFPGANer
|
||||
from facexlib import detection, parsing # noqa: F401
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
global model_file_path
|
||||
|
||||
facexlib_path = model_path
|
||||
|
||||
if dirname is not None:
|
||||
facexlib_path = dirname
|
||||
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
def restore(self, np_image):
|
||||
return gfpgan_fix_faces(np_image)
|
||||
def get_device(self):
|
||||
return devices.device_gfpgan
|
||||
|
||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||
def load_net(self) -> torch.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
if 'GFPGAN' in os.path.basename(model_path):
|
||||
model = modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||
return model
|
||||
raise ValueError("No GFPGAN model found")
|
||||
|
||||
def restore(self, np_image):
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, return_rgb=False)[0]
|
||||
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
if gfpgan_face_restorer:
|
||||
return gfpgan_face_restorer.restore(np_image)
|
||||
logger.warning("GFPGAN face restorer not set up")
|
||||
return np_image
|
||||
|
||||
|
||||
def setup_model(dirname: str) -> None:
|
||||
global gfpgan_face_restorer
|
||||
|
||||
try:
|
||||
face_restoration_utils.patch_facexlib(dirname)
|
||||
gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
|
||||
shared.face_restorers.append(gfpgan_face_restorer)
|
||||
except Exception:
|
||||
errors.report("Error setting up GFPGAN", exc_info=True)
|
||||
|
@ -21,7 +21,10 @@ def calculate_sha256(filename):
|
||||
|
||||
def sha256_from_cache(filename, title, use_addnet_hash=False):
|
||||
hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
|
||||
try:
|
||||
ondisk_mtime = os.path.getmtime(filename)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
if title not in hashes:
|
||||
return None
|
||||
|
43
modules/hat_model.py
Normal file
43
modules/hat_model.py
Normal file
@ -0,0 +1,43 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerHAT(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "HAT"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
for file in self.find_models(ext_filter=[".pt", ".pth"]):
|
||||
name = modelloader.friendly_name(file)
|
||||
scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
|
||||
scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
except Exception as e:
|
||||
print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model.to(devices.device_esrgan) # TODO: should probably be device_hat
|
||||
return upscale_with_model(
|
||||
model,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
|
||||
tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
|
||||
)
|
||||
|
||||
def load_model(self, path: str):
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Model file {path} not found")
|
||||
return modelloader.load_spandrel_model(
|
||||
path,
|
||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||
expected_architecture='HAT',
|
||||
)
|
@ -95,6 +95,7 @@ class HypernetworkModule(torch.nn.Module):
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
devices.torch_npu_set_device()
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
|
@ -12,7 +12,7 @@ import re
|
||||
import numpy as np
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
|
||||
import string
|
||||
import json
|
||||
import hashlib
|
||||
@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||
return grid
|
||||
|
||||
|
||||
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
|
||||
class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
|
||||
@property
|
||||
def tile_count(self) -> int:
|
||||
"""
|
||||
The total number of tiles in the grid.
|
||||
"""
|
||||
return sum(len(row[2]) for row in self.tiles)
|
||||
|
||||
|
||||
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
w = image.width
|
||||
h = image.height
|
||||
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
|
||||
w, h = image.size
|
||||
|
||||
non_overlap_width = tile_w - overlap
|
||||
non_overlap_height = tile_h - overlap
|
||||
@ -316,13 +321,16 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
|
||||
return res
|
||||
|
||||
|
||||
invalid_filename_chars = '<>:"/\\|?*\n\r\t'
|
||||
if not shared.cmd_opts.unix_filenames_sanitization:
|
||||
invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
|
||||
else:
|
||||
invalid_filename_chars = '/'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||
max_filename_part_length = 128
|
||||
max_filename_part_length = shared.cmd_opts.filenames_max_length
|
||||
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||
|
||||
|
||||
@ -765,7 +773,7 @@ def image_data(data):
|
||||
import gradio as gr
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
image = read(io.BytesIO(data))
|
||||
textinfo, _ = read_info_from_image(image)
|
||||
return textinfo, None
|
||||
except Exception:
|
||||
@ -791,3 +799,31 @@ def flatten(img, bgcolor):
|
||||
img = background
|
||||
|
||||
return img.convert('RGB')
|
||||
|
||||
|
||||
def read(fp, **kwargs):
|
||||
image = Image.open(fp, **kwargs)
|
||||
image = fix_image(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def fix_image(image: Image.Image):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
image = fix_png_transparency(image)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def fix_png_transparency(image: Image.Image):
|
||||
if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
|
||||
return image
|
||||
|
||||
image = image.convert("RGBA")
|
||||
return image
|
||||
|
@ -6,8 +6,8 @@ import numpy as np
|
||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||
import gradio as gr
|
||||
|
||||
from modules import images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules import images
|
||||
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.sd_models import get_closet_checkpoint_match
|
||||
@ -21,7 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
output_dir = output_dir.strip()
|
||||
processing.fix_seed(p)
|
||||
|
||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||
batch_images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||
|
||||
is_inpaint_batch = False
|
||||
if inpaint_mask_dir:
|
||||
@ -31,9 +31,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
if is_inpaint_batch:
|
||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||
|
||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||
|
||||
state.job_count = len(images) * p.n_iter
|
||||
state.job_count = len(batch_images) * p.n_iter
|
||||
|
||||
# extract "default" params to use in case getting png info fails
|
||||
prompt = p.prompt
|
||||
@ -46,16 +46,16 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||
batch_results = None
|
||||
discard_further_results = False
|
||||
for i, image in enumerate(images):
|
||||
state.job = f"{i+1} out of {len(images)}"
|
||||
for i, image in enumerate(batch_images):
|
||||
state.job = f"{i+1} out of {len(batch_images)}"
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted:
|
||||
if state.interrupted or state.stopping_generation:
|
||||
break
|
||||
|
||||
try:
|
||||
img = Image.open(image)
|
||||
img = images.read(image)
|
||||
except UnidentifiedImageError as e:
|
||||
print(e)
|
||||
continue
|
||||
@ -86,7 +86,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
# otherwise user has many masks with the same name but different extensions
|
||||
mask_image_path = masks_found[0]
|
||||
|
||||
mask_image = Image.open(mask_image_path)
|
||||
mask_image = images.read(mask_image_path)
|
||||
p.image_mask = mask_image
|
||||
|
||||
if use_png_info:
|
||||
@ -94,8 +94,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
info_img = img
|
||||
if png_info_dir:
|
||||
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
||||
info_img = Image.open(info_img_path)
|
||||
geninfo, _ = imgutil.read_info_from_image(info_img)
|
||||
info_img = images.read(info_img_path)
|
||||
geninfo, _ = images.read_info_from_image(info_img)
|
||||
parsed_parameters = parse_generation_parameters(geninfo)
|
||||
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
||||
except Exception:
|
||||
@ -175,9 +175,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
image = None
|
||||
mask = None
|
||||
|
||||
# Use the EXIF orientation of photos taken by smartphones.
|
||||
if image is not None:
|
||||
image = ImageOps.exif_transpose(image)
|
||||
image = images.fix_image(image)
|
||||
mask = images.fix_image(mask)
|
||||
|
||||
if selected_scale_tab == 1 and not is_batch:
|
||||
assert image, "Can't scale by because no image is selected"
|
||||
@ -222,9 +221,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
if mask:
|
||||
p.extra_generation_params["Mask blur"] = mask_blur
|
||||
|
||||
with closing(p):
|
||||
if is_batch:
|
||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||
|
@ -4,12 +4,15 @@ import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser
|
||||
from PIL import Image
|
||||
|
||||
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
||||
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
@ -28,6 +31,19 @@ class ParamBinding:
|
||||
self.paste_field_names = paste_field_names or []
|
||||
|
||||
|
||||
class PasteField(tuple):
|
||||
def __new__(cls, component, target, *, api=None):
|
||||
return super().__new__(cls, (component, target))
|
||||
|
||||
def __init__(self, component, target, *, api=None):
|
||||
super().__init__()
|
||||
|
||||
self.api = api
|
||||
self.component = component
|
||||
self.label = target if isinstance(target, str) else None
|
||||
self.function = target if callable(target) else None
|
||||
|
||||
|
||||
paste_fields: dict[str, dict] = {}
|
||||
registered_param_bindings: list[ParamBinding] = []
|
||||
|
||||
@ -67,7 +83,7 @@ def image_from_url_text(filedata):
|
||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||
|
||||
filename = filename.rsplit('?', 1)[0]
|
||||
return Image.open(filename)
|
||||
return images.read(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
@ -79,11 +95,17 @@ def image_from_url_text(filedata):
|
||||
filedata = filedata[len("data:image/png;base64,"):]
|
||||
|
||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||
image = Image.open(io.BytesIO(filedata))
|
||||
image = images.read(io.BytesIO(filedata))
|
||||
return image
|
||||
|
||||
|
||||
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
|
||||
|
||||
if fields:
|
||||
for i in range(len(fields)):
|
||||
if not isinstance(fields[i], PasteField):
|
||||
fields[i] = PasteField(*fields[i])
|
||||
|
||||
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
|
||||
|
||||
# backwards compatibility for existing extensions
|
||||
@ -208,7 +230,7 @@ def restore_old_hires_fix_params(res):
|
||||
res['Hires resize-2'] = height
|
||||
|
||||
|
||||
def parse_generation_parameters(x: str):
|
||||
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
|
||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||
```
|
||||
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
|
||||
@ -218,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
|
||||
returns a dict with field values
|
||||
"""
|
||||
if skip_fields is None:
|
||||
skip_fields = shared.opts.infotext_skip_pasting
|
||||
|
||||
res = {}
|
||||
|
||||
@ -290,6 +314,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "Hires negative prompt" not in res:
|
||||
res["Hires negative prompt"] = ""
|
||||
|
||||
if "Mask mode" not in res:
|
||||
res["Mask mode"] = "Inpaint masked"
|
||||
|
||||
if "Masked content" not in res:
|
||||
res["Masked content"] = 'original'
|
||||
|
||||
if "Inpaint area" not in res:
|
||||
res["Inpaint area"] = "Whole picture"
|
||||
|
||||
if "Masked area padding" not in res:
|
||||
res["Masked area padding"] = 32
|
||||
|
||||
restore_old_hires_fix_params(res)
|
||||
|
||||
# Missing RNG means the default was set, which is GPU RNG
|
||||
@ -314,8 +350,25 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
||||
if "VAE Decoder" not in res:
|
||||
res["VAE Decoder"] = "Full"
|
||||
|
||||
skip = set(shared.opts.infotext_skip_pasting)
|
||||
res = {k: v for k, v in res.items() if k not in skip}
|
||||
if "FP8 weight" not in res:
|
||||
res["FP8 weight"] = "Disable"
|
||||
|
||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||
res["Cache FP16 weight for LoRA"] = False
|
||||
|
||||
prompt_attention = prompt_parser.parse_prompt_attention(prompt)
|
||||
prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
|
||||
prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
|
||||
if "Emphasis" not in res and prompt_uses_emphasis:
|
||||
res["Emphasis"] = "Original"
|
||||
|
||||
if "Refiner switch by sampling steps" not in res:
|
||||
res["Refiner switch by sampling steps"] = False
|
||||
|
||||
infotext_versions.backcompat(res)
|
||||
|
||||
for key in skip_fields:
|
||||
res.pop(key, None)
|
||||
|
||||
return res
|
||||
|
||||
@ -365,13 +418,57 @@ def create_override_settings_dict(text_pairs):
|
||||
return res
|
||||
|
||||
|
||||
def get_override_settings(params, *, skip_fields=None):
|
||||
"""Returns a list of settings overrides from the infotext parameters dictionary.
|
||||
|
||||
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
|
||||
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
|
||||
|
||||
It checks for conditions before adding an override:
|
||||
- ignores settings that match the current value
|
||||
- ignores parameter keys present in skip_fields argument.
|
||||
|
||||
Example input:
|
||||
{"Clip skip": "2"}
|
||||
|
||||
Example output:
|
||||
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
|
||||
"""
|
||||
|
||||
res = []
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in (skip_fields or {}):
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
res.append((param_name, setting_name, v))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||
def paste_func(prompt):
|
||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
||||
filename = os.path.join(data_path, "params.txt")
|
||||
if os.path.exists(filename):
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
prompt = file.read()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
params = parse_generation_parameters(prompt)
|
||||
script_callbacks.infotext_pasted_callback(prompt, params)
|
||||
@ -393,6 +490,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
|
||||
if valtype == bool and v == "False":
|
||||
val = False
|
||||
elif valtype == int:
|
||||
val = float(v)
|
||||
else:
|
||||
val = valtype(v)
|
||||
|
||||
@ -406,29 +505,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
||||
already_handled_fields = {key: 1 for _, key in paste_fields}
|
||||
|
||||
def paste_settings(params):
|
||||
vals = {}
|
||||
vals = get_override_settings(params, skip_fields=already_handled_fields)
|
||||
|
||||
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
|
||||
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
|
||||
if param_name in already_handled_fields:
|
||||
continue
|
||||
|
||||
v = params.get(param_name, None)
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
|
||||
continue
|
||||
|
||||
v = shared.opts.cast_value(setting_name, v)
|
||||
current_value = getattr(shared.opts, setting_name, None)
|
||||
|
||||
if v == current_value:
|
||||
continue
|
||||
|
||||
vals[param_name] = v
|
||||
|
||||
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
|
||||
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
|
||||
|
||||
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
|
||||
|
45
modules/infotext_versions.py
Normal file
45
modules/infotext_versions.py
Normal file
@ -0,0 +1,45 @@
|
||||
from modules import shared
|
||||
from packaging import version
|
||||
import re
|
||||
|
||||
|
||||
v160 = version.parse("1.6.0")
|
||||
v170_tsnr = version.parse("v1.7.0-225")
|
||||
v180 = version.parse("1.8.0")
|
||||
|
||||
|
||||
def parse_version(text):
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
m = re.match(r'([^-]+-[^-]+)-.*', text)
|
||||
if m:
|
||||
text = m.group(1)
|
||||
|
||||
try:
|
||||
return version.parse(text)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def backcompat(d):
|
||||
"""Checks infotext Version field, and enables backwards compatibility options according to it."""
|
||||
|
||||
if not shared.opts.auto_backcompat:
|
||||
return
|
||||
|
||||
ver = parse_version(d.get("Version"))
|
||||
if ver is None:
|
||||
return
|
||||
|
||||
if ver < v160 and '[' in d.get('Prompt', ''):
|
||||
d["Old prompt editing timelines"] = True
|
||||
|
||||
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
|
||||
d["Pad conds v0"] = True
|
||||
|
||||
if ver < v170_tsnr:
|
||||
d["Downcast alphas_cumprod"] = True
|
||||
|
||||
if ver < v180 and d.get('Refiner'):
|
||||
d["Refiner switch by sampling steps"] = True
|
@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from threading import Thread
|
||||
@ -18,6 +19,7 @@ def imports():
|
||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
import gradio # noqa: F401
|
||||
startup_timer.record("import gradio")
|
||||
|
||||
@ -54,9 +56,6 @@ def initialize():
|
||||
initialize_util.configure_sigint_handler()
|
||||
initialize_util.configure_opts_onchange()
|
||||
|
||||
from modules import modelloader
|
||||
modelloader.cleanup_models()
|
||||
|
||||
from modules import sd_models
|
||||
sd_models.setup_model()
|
||||
startup_timer.record("setup SD model")
|
||||
@ -140,16 +139,17 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
"""
|
||||
Accesses shared.sd_model property to load model.
|
||||
After it's available, if it has been loaded before this access by some extension,
|
||||
its optimization may be None because the list of optimizaers has neet been filled
|
||||
its optimization may be None because the list of optimizers has not been filled
|
||||
by that time, so we apply optimization again.
|
||||
"""
|
||||
from modules import devices
|
||||
devices.torch_npu_set_device()
|
||||
|
||||
shared.sd_model # noqa: B018
|
||||
|
||||
if sd_hijack.current_optimizer is None:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
from modules import devices
|
||||
devices.first_time_calculation()
|
||||
if not shared.cmd_opts.skip_load_model_at_start:
|
||||
Thread(target=load_model).start()
|
||||
|
@ -177,6 +177,8 @@ def configure_opts_onchange():
|
||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
||||
startup_timer.record("opts onchange")
|
||||
|
||||
|
||||
|
@ -10,14 +10,14 @@ import torch.hub
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||
from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
|
||||
|
||||
blip_image_eval_size = 384
|
||||
clip_model_name = 'ViT-L/14'
|
||||
|
||||
Category = namedtuple("Category", ["name", "topn", "items"])
|
||||
|
||||
re_topn = re.compile(r"\.top(\d+)\.")
|
||||
re_topn = re.compile(r"\.top(\d+)$")
|
||||
|
||||
def category_types():
|
||||
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
|
||||
@ -131,7 +131,7 @@ class InterrogateModels:
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = next(self.clip_model.parameters()).dtype
|
||||
self.dtype = torch_utils.get_param(self.clip_model).dtype
|
||||
|
||||
def send_clip_to_ram(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
|
@ -27,8 +27,7 @@ dir_repos = "repositories"
|
||||
# Whether to default to printing command output
|
||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||
|
||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||
|
||||
|
||||
def check_python_version():
|
||||
@ -56,7 +55,7 @@ and delete current Python and "venv" folder in WebUI's directory.
|
||||
|
||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre" if is_windows else ""}
|
||||
|
||||
Use --skip-python-version-check to suppress this warning.
|
||||
""")
|
||||
@ -189,7 +188,7 @@ def git_clone(url, dir, name, commithash=None):
|
||||
return
|
||||
|
||||
try:
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
run(f'"{git}" clone --config core.filemode=false "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
|
||||
except RuntimeError:
|
||||
shutil.rmtree(dir, ignore_errors=True)
|
||||
raise
|
||||
@ -245,11 +244,13 @@ def list_extensions(settings_file):
|
||||
settings = {}
|
||||
|
||||
try:
|
||||
if os.path.isfile(settings_file):
|
||||
with open(settings_file, "r", encoding="utf8") as file:
|
||||
settings = json.load(file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
errors.report("Could not load settings", exc_info=True)
|
||||
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||
|
||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||
@ -314,8 +315,8 @@ def requirements_met(requirements_file):
|
||||
|
||||
|
||||
def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
|
||||
if args.use_ipex:
|
||||
if platform.system() == "Windows":
|
||||
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||
@ -337,21 +338,22 @@ def prepare_environment():
|
||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||
requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', "requirements_npu.txt")
|
||||
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
|
||||
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
|
||||
try:
|
||||
@ -405,18 +407,14 @@ def prepare_environment():
|
||||
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
|
||||
startup_timer.record("clone repositores")
|
||||
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||
startup_timer.record("install CodeFormer requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file):
|
||||
requirements_file = os.path.join(script_path, requirements_file)
|
||||
|
||||
@ -424,6 +422,13 @@ def prepare_environment():
|
||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||
startup_timer.record("install requirements")
|
||||
|
||||
if not os.path.isfile(requirements_file_for_npu):
|
||||
requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)
|
||||
|
||||
if "torch_npu" in torch_command and not requirements_met(requirements_file_for_npu):
|
||||
run_pip(f"install -r \"{requirements_file_for_npu}\"", "requirements_for_npu")
|
||||
startup_timer.record("install requirements_for_npu")
|
||||
|
||||
if not args.skip_install:
|
||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||
|
||||
|
@ -1,41 +1,58 @@
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TqdmLoggingHandler(logging.Handler):
|
||||
def __init__(self, level=logging.INFO):
|
||||
super().__init__(level)
|
||||
def __init__(self, fallback_handler: logging.Handler):
|
||||
super().__init__()
|
||||
self.fallback_handler = fallback_handler
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
msg = self.format(record)
|
||||
tqdm.write(msg)
|
||||
self.flush()
|
||||
# If there are active tqdm progress bars,
|
||||
# attempt to not interfere with them.
|
||||
if tqdm._instances:
|
||||
tqdm.write(self.format(record))
|
||||
else:
|
||||
self.fallback_handler.emit(record)
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
self.fallback_handler.emit(record)
|
||||
|
||||
TQDM_IMPORTED = True
|
||||
except ImportError:
|
||||
# tqdm does not exist before first launch
|
||||
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||
TQDM_IMPORTED = False
|
||||
TqdmLoggingHandler = None
|
||||
|
||||
|
||||
def setup_logging(loglevel):
|
||||
if loglevel is None:
|
||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||
|
||||
loghandlers = []
|
||||
if not loglevel:
|
||||
return
|
||||
|
||||
if TQDM_IMPORTED:
|
||||
loghandlers.append(TqdmLoggingHandler())
|
||||
if logging.root.handlers:
|
||||
# Already configured, do not interfere
|
||||
return
|
||||
|
||||
if loglevel:
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
handlers=loghandlers
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||
'%Y-%m-%d %H:%M:%S',
|
||||
)
|
||||
|
||||
if os.environ.get("SD_WEBUI_RICH_LOG"):
|
||||
from rich.logging import RichHandler
|
||||
handler = RichHandler()
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
if TqdmLoggingHandler:
|
||||
handler = TqdmLoggingHandler(handler)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||
logging.root.setLevel(log_level)
|
||||
logging.root.addHandler(handler)
|
||||
|
@ -12,7 +12,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||
# use check `getattr` and try it for compatibility.
|
||||
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
||||
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
||||
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||
def check_for_mps() -> bool:
|
||||
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||
|
@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps
|
||||
|
||||
def get_crop_region(mask, pad=0):
|
||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||
|
||||
h, w = mask.shape
|
||||
|
||||
crop_left = 0
|
||||
for i in range(w):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_left += 1
|
||||
|
||||
crop_right = 0
|
||||
for i in reversed(range(w)):
|
||||
if not (mask[:, i] == 0).all():
|
||||
break
|
||||
crop_right += 1
|
||||
|
||||
crop_top = 0
|
||||
for i in range(h):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_top += 1
|
||||
|
||||
crop_bottom = 0
|
||||
for i in reversed(range(h)):
|
||||
if not (mask[i] == 0).all():
|
||||
break
|
||||
crop_bottom += 1
|
||||
|
||||
return (
|
||||
int(max(crop_left-pad, 0)),
|
||||
int(max(crop_top-pad, 0)),
|
||||
int(min(w - crop_right + pad, w)),
|
||||
int(min(h - crop_bottom + pad, h))
|
||||
)
|
||||
For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
|
||||
mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
|
||||
box = mask_img.getbbox()
|
||||
if box:
|
||||
x1, y1, x2, y2 = box
|
||||
else: # when no box is found
|
||||
x1, y1 = mask_img.size
|
||||
x2 = y2 = 0
|
||||
return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])
|
||||
|
||||
|
||||
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
|
||||
|
@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import spandrel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_file_from_url(
|
||||
@ -90,54 +97,6 @@ def friendly_name(file: str):
|
||||
return model_name
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||
# somehow auto-register and just do these things...
|
||||
root_path = script_path
|
||||
src_path = models_path
|
||||
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||
move_files(src_path, dest_path, ".ckpt")
|
||||
move_files(src_path, dest_path, ".safetensors")
|
||||
src_path = os.path.join(root_path, "ESRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(models_path, "BSRGAN")
|
||||
dest_path = os.path.join(models_path, "ESRGAN")
|
||||
move_files(src_path, dest_path, ".pth")
|
||||
src_path = os.path.join(root_path, "gfpgan")
|
||||
dest_path = os.path.join(models_path, "GFPGAN")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "SwinIR")
|
||||
dest_path = os.path.join(models_path, "SwinIR")
|
||||
move_files(src_path, dest_path)
|
||||
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||
dest_path = os.path.join(models_path, "LDSR")
|
||||
move_files(src_path, dest_path)
|
||||
|
||||
|
||||
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
try:
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
if os.path.exists(src_path):
|
||||
for file in os.listdir(src_path):
|
||||
fullpath = os.path.join(src_path, file)
|
||||
if os.path.isfile(fullpath):
|
||||
if ext_filter is not None:
|
||||
if ext_filter not in file:
|
||||
continue
|
||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||
try:
|
||||
shutil.move(fullpath, dest_path)
|
||||
except Exception:
|
||||
pass
|
||||
if len(os.listdir(src_path)) == 0:
|
||||
print(f"Removing empty folder: {src_path}")
|
||||
shutil.rmtree(src_path, True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
@ -151,7 +110,7 @@ def load_upscalers():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
datas = []
|
||||
data = []
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
|
||||
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||
@ -170,10 +129,41 @@ def load_upscalers():
|
||||
scaler = cls(commandline_model_path)
|
||||
scaler.user_path = commandline_model_path
|
||||
scaler.model_download_path = commandline_model_path or scaler.model_path
|
||||
datas += scaler.scalers
|
||||
data += scaler.scalers
|
||||
|
||||
shared.sd_upscalers = sorted(
|
||||
datas,
|
||||
data,
|
||||
# Special case for UpscalerNone keeps it at the beginning of the list.
|
||||
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
||||
)
|
||||
|
||||
|
||||
def load_spandrel_model(
|
||||
path: str | os.PathLike,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
prefer_half: bool = False,
|
||||
dtype: str | torch.dtype | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
) -> spandrel.ModelDescriptor:
|
||||
import spandrel
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||
logger.warning(
|
||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||
)
|
||||
half = False
|
||||
if prefer_half:
|
||||
if model_descriptor.supports_half:
|
||||
model_descriptor.model.half()
|
||||
half = True
|
||||
else:
|
||||
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||
if dtype:
|
||||
model_descriptor.model.to(dtype=dtype)
|
||||
model_descriptor.model.eval()
|
||||
logger.debug(
|
||||
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||
model_descriptor, path, device, half, dtype,
|
||||
)
|
||||
return model_descriptor
|
||||
|
@ -341,7 +341,7 @@ class DDPM(pl.LightningModule):
|
||||
elif self.parameterization == "x0":
|
||||
target = x_start
|
||||
else:
|
||||
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
||||
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
||||
|
||||
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
||||
|
||||
@ -901,7 +901,7 @@ class LatentDiffusion(DDPM):
|
||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
# hybrid case, cond is exptected to be a dict
|
||||
# hybrid case, cond is expected to be a dict
|
||||
pass
|
||||
else:
|
||||
if not isinstance(cond, list):
|
||||
@ -937,7 +937,7 @@ class LatentDiffusion(DDPM):
|
||||
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
||||
|
||||
elif self.cond_stage_key == 'coordinates_bbox':
|
||||
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
||||
assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
|
||||
|
||||
# assuming padding of unfold is always 0 and its dilation is always 1
|
||||
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
||||
@ -947,7 +947,7 @@ class LatentDiffusion(DDPM):
|
||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||
rescale_latent = 2 ** (num_downs)
|
||||
|
||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
||||
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
||||
|
31
modules/npu_specific.py
Normal file
31
modules/npu_specific.py
Normal file
@ -0,0 +1,31 @@
|
||||
import importlib
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def check_for_npu():
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
import torch_npu
|
||||
|
||||
try:
|
||||
# Will raise a RuntimeError if no NPU is found
|
||||
_ = torch_npu.npu.device_count()
|
||||
return torch.npu.is_available()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_npu_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"npu:{shared.cmd_opts.device_id}"
|
||||
return "npu:0"
|
||||
|
||||
|
||||
def torch_npu_gc():
|
||||
with torch.npu.device(get_npu_device_string()):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
has_npu = check_for_npu()
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
@ -6,6 +7,7 @@ import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.shared_cmd_options import cmd_opts
|
||||
from modules.paths_internal import script_path
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
@ -91,18 +93,35 @@ class Options:
|
||||
|
||||
if self.data is not None:
|
||||
if key in self.data or key in self.data_labels:
|
||||
|
||||
# Check that settings aren't globally frozen
|
||||
assert not cmd_opts.freeze_settings, "changing settings is disabled"
|
||||
|
||||
# Get the info related to the setting being changed
|
||||
info = self.data_labels.get(key, None)
|
||||
if info.do_not_save:
|
||||
return
|
||||
|
||||
# Restrict component arguments
|
||||
comp_args = info.component_args if info else None
|
||||
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted")
|
||||
|
||||
# Check that this section isn't frozen
|
||||
if cmd_opts.freeze_settings_in_sections is not None:
|
||||
frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names
|
||||
section_key = info.section[0]
|
||||
section_name = info.section[1]
|
||||
assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
|
||||
|
||||
# Check that this section of the settings isn't frozen
|
||||
if cmd_opts.freeze_specific_settings is not None:
|
||||
frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys
|
||||
assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
|
||||
|
||||
# Check shorthand option which disables editing options in "saving-paths"
|
||||
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
|
||||
raise RuntimeError(f"not possible to set {key} because it is restricted")
|
||||
raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")
|
||||
|
||||
self.data[key] = value
|
||||
return
|
||||
@ -176,9 +195,15 @@ class Options:
|
||||
return type_x == type_y
|
||||
|
||||
def load(self, filename):
|
||||
try:
|
||||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
except Exception:
|
||||
errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||
os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
|
||||
self.data = {}
|
||||
# 1.6.0 VAE defaults
|
||||
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
|
||||
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
|
||||
|
@ -38,7 +38,6 @@ mute_sdxl_imports()
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
]
|
||||
|
@ -4,6 +4,10 @@ import argparse
|
||||
import os
|
||||
import sys
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
normalized_filepath = lambda filepath: str(Path(filepath).absolute())
|
||||
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
@ -28,5 +32,6 @@ models_path = os.path.join(data_path, "models")
|
||||
extensions_dir = os.path.join(data_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||
config_states_dir = os.path.join(script_path, "config_states")
|
||||
default_output_dir = os.path.join(data_path, "output")
|
||||
|
||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
|
||||
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
@ -17,10 +17,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
if extras_mode == 1:
|
||||
for img in image_folder:
|
||||
if isinstance(img, Image.Image):
|
||||
image = img
|
||||
image = images.fix_image(img)
|
||||
fn = ''
|
||||
else:
|
||||
image = Image.open(os.path.abspath(img.name))
|
||||
image = images.read(os.path.abspath(img.name))
|
||||
fn = os.path.splitext(img.orig_name)[0]
|
||||
yield image, fn
|
||||
elif extras_mode == 2:
|
||||
@ -56,14 +56,12 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
|
||||
if isinstance(image_placeholder, str):
|
||||
try:
|
||||
image_data = Image.open(image_placeholder)
|
||||
image_data = images.read(image_placeholder)
|
||||
except Exception:
|
||||
continue
|
||||
else:
|
||||
image_data = image_placeholder
|
||||
|
||||
shared.state.assign_current_image(image_data)
|
||||
|
||||
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||
if parameters:
|
||||
existing_pnginfo["parameters"] = parameters
|
||||
@ -86,22 +84,25 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
||||
basename = ''
|
||||
forced_filename = None
|
||||
|
||||
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||
|
||||
if opts.enable_pnginfo:
|
||||
pp.image.info = existing_pnginfo
|
||||
pp.image.info["postprocessing"] = infotext
|
||||
|
||||
shared.state.assign_current_image(pp.image)
|
||||
|
||||
if save_output:
|
||||
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
|
||||
|
||||
if pp.caption:
|
||||
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
|
||||
if os.path.isfile(caption_filename):
|
||||
existing_caption = ""
|
||||
try:
|
||||
with open(caption_filename, encoding="utf8") as file:
|
||||
existing_caption = file.read().strip()
|
||||
else:
|
||||
existing_caption = ""
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
action = shared.opts.postprocessing_existing_caption_action
|
||||
if action == 'Prepend' and existing_caption:
|
||||
|
@ -16,7 +16,7 @@ from skimage import exposure
|
||||
from typing import Any
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||
from modules.rng import slerp # noqa: F401
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||
@ -62,28 +62,37 @@ def apply_color_correction(correction, original_image):
|
||||
return image.convert('RGB')
|
||||
|
||||
|
||||
def apply_overlay(image, paste_loc, index, overlays):
|
||||
if overlays is None or index >= len(overlays):
|
||||
return image
|
||||
|
||||
overlay = overlays[index]
|
||||
|
||||
if paste_loc is not None:
|
||||
def uncrop(image, dest_size, paste_loc):
|
||||
x, y, w, h = paste_loc
|
||||
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
||||
base_image = Image.new('RGBA', dest_size)
|
||||
image = images.resize_image(1, image, w, h)
|
||||
base_image.paste(image, (x, y))
|
||||
image = base_image
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def apply_overlay(image, paste_loc, overlay):
|
||||
if overlay is None:
|
||||
return image, image.copy()
|
||||
|
||||
if paste_loc is not None:
|
||||
image = uncrop(image, (overlay.width, overlay.height), paste_loc)
|
||||
|
||||
original_denoised_image = image.copy()
|
||||
|
||||
image = image.convert('RGBA')
|
||||
image.alpha_composite(overlay)
|
||||
image = image.convert('RGB')
|
||||
|
||||
return image
|
||||
return image, original_denoised_image
|
||||
|
||||
def create_binary_mask(image):
|
||||
def create_binary_mask(image, round=True):
|
||||
if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
|
||||
if round:
|
||||
image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
|
||||
else:
|
||||
image = image.split()[-1].convert("L")
|
||||
else:
|
||||
image = image.convert('L')
|
||||
return image
|
||||
@ -106,6 +115,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
||||
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
|
||||
|
||||
else:
|
||||
sd = sd_model.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
|
||||
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
|
||||
image_conditioning = images_tensor_to_samples(image_conditioning,
|
||||
approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
|
||||
# Add the fake full 1s mask to the first dimension.
|
||||
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||
image_conditioning = image_conditioning.to(x.dtype)
|
||||
|
||||
return image_conditioning
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or unclip models.
|
||||
# Still takes up a bit of memory, but no encoder call.
|
||||
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||
@ -157,6 +181,7 @@ class StableDiffusionProcessing:
|
||||
token_merging_ratio = 0
|
||||
token_merging_ratio_hr = 0
|
||||
disable_extra_networks: bool = False
|
||||
firstpass_image: Image = None
|
||||
|
||||
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
|
||||
script_args_value: list = field(default=None, init=False)
|
||||
@ -308,7 +333,7 @@ class StableDiffusionProcessing:
|
||||
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||
return c_adm
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
@ -320,8 +345,10 @@ class StableDiffusionProcessing:
|
||||
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||
|
||||
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
if round_image_mask:
|
||||
# Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
|
||||
conditioning_mask = torch.round(conditioning_mask)
|
||||
|
||||
else:
|
||||
conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
|
||||
|
||||
@ -345,7 +372,7 @@ class StableDiffusionProcessing:
|
||||
|
||||
return image_conditioning
|
||||
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
|
||||
source_image = devices.cond_cast_float(source_image)
|
||||
|
||||
# HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
|
||||
@ -357,11 +384,17 @@ class StableDiffusionProcessing:
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)
|
||||
|
||||
if self.sampler.conditioning_key == "crossattn-adm":
|
||||
return self.unclip_image_conditioning(source_image)
|
||||
|
||||
sd = self.sampler.model_wrap.inner_model.model.state_dict()
|
||||
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
# Dummy zero conditioning if we're not using inpainting or depth model.
|
||||
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
|
||||
|
||||
@ -422,6 +455,9 @@ class StableDiffusionProcessing:
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
opts.fp8_storage,
|
||||
opts.cache_fp16_weight,
|
||||
opts.emphasis,
|
||||
)
|
||||
|
||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||
@ -596,20 +632,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||
|
||||
if check_for_nans:
|
||||
|
||||
try:
|
||||
devices.test_for_nans(sample, "vae")
|
||||
except devices.NansException as e:
|
||||
if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
|
||||
if shared.opts.auto_vae_precision_bfloat16:
|
||||
autofix_dtype = torch.bfloat16
|
||||
autofix_dtype_text = "bfloat16"
|
||||
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
|
||||
autofix_dtype_comment = ""
|
||||
elif shared.opts.auto_vae_precision:
|
||||
autofix_dtype = torch.float32
|
||||
autofix_dtype_text = "32-bit float"
|
||||
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
|
||||
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||
else:
|
||||
raise e
|
||||
|
||||
if devices.dtype_vae == autofix_dtype:
|
||||
raise e
|
||||
|
||||
errors.print_error_explanation(
|
||||
"A tensor with all NaNs was produced in VAE.\n"
|
||||
"Web UI will now convert VAE into 32-bit float and retry.\n"
|
||||
"To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
|
||||
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
|
||||
f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
|
||||
f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
|
||||
)
|
||||
|
||||
devices.dtype_vae = torch.float32
|
||||
devices.dtype_vae = autofix_dtype
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
batch = batch.to(devices.dtype_vae)
|
||||
|
||||
@ -679,12 +728,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||
"FP8 weight": opts.fp8_storage if devices.fp8 else None,
|
||||
"Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
|
||||
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
||||
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Denoising strength": p.extra_generation_params.get("Denoising strength"),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||
@ -699,7 +750,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
||||
"User": p.user if opts.add_user_name_to_info else None,
|
||||
}
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
|
||||
negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
|
||||
@ -818,7 +869,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if state.skipped:
|
||||
state.skipped = False
|
||||
|
||||
if state.interrupted:
|
||||
if state.interrupted or state.stopping_generation:
|
||||
break
|
||||
|
||||
sd_models.reload_model_weights() # model can be changed for example by refiner
|
||||
@ -845,6 +896,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
|
||||
p.setup_conds()
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
# params.txt should be saved after scripts.process_batch, since the
|
||||
# infotext could be modified by that callback
|
||||
# Example: a wildcard processed by process_batch sets an extra model
|
||||
@ -854,19 +909,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
processed = Processed(p, [])
|
||||
file.write(processed.infotext(p, 0))
|
||||
|
||||
p.setup_conds()
|
||||
|
||||
for comment in model_hijack.comments:
|
||||
p.comment(comment)
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||
|
||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||
|
||||
if p.scripts is not None:
|
||||
ps = scripts.PostSampleArgs(samples_ddim)
|
||||
p.scripts.post_sample(p, ps)
|
||||
samples_ddim = ps.samples
|
||||
|
||||
if getattr(samples_ddim, 'already_decoded', False):
|
||||
x_samples_ddim = samples_ddim
|
||||
else:
|
||||
@ -922,13 +980,37 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
mask_for_overlay = getattr(p, "mask_for_overlay", None)
|
||||
|
||||
if not shared.opts.overlay_inpaint:
|
||||
overlay_image = None
|
||||
elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
|
||||
overlay_image = p.overlay_images[i]
|
||||
else:
|
||||
overlay_image = None
|
||||
|
||||
if p.scripts is not None:
|
||||
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
|
||||
p.scripts.postprocess_maskoverlay(p, ppmo)
|
||||
mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)
|
||||
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
|
||||
image = apply_color_correction(p.color_corrections[i], image)
|
||||
|
||||
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
# If the intention is to show the output from the model
|
||||
# that is being composited over the original image,
|
||||
# we need to keep the original image around
|
||||
# and use it in the composite step.
|
||||
image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image_after_composite(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if save_samples:
|
||||
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
|
||||
@ -938,16 +1020,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
if opts.enable_pnginfo:
|
||||
image.info["parameters"] = text
|
||||
output_images.append(image)
|
||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||
|
||||
if mask_for_overlay is not None:
|
||||
if opts.return_mask or opts.save_mask:
|
||||
image_mask = p.mask_for_overlay.convert('RGB')
|
||||
image_mask = mask_for_overlay.convert('RGB')
|
||||
if save_samples and opts.save_mask:
|
||||
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
|
||||
if opts.return_mask:
|
||||
output_images.append(image_mask)
|
||||
|
||||
if opts.return_mask_composite or opts.save_mask_composite:
|
||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||
if save_samples and opts.save_mask_composite:
|
||||
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
|
||||
if opts.return_mask_composite:
|
||||
@ -1025,6 +1108,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
hr_sampler_name: str = None
|
||||
hr_prompt: str = ''
|
||||
hr_negative_prompt: str = ''
|
||||
force_task_id: str = None
|
||||
|
||||
cached_hr_uc = [None, None]
|
||||
cached_hr_c = [None, None]
|
||||
@ -1097,7 +1181,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
if self.hr_checkpoint_name:
|
||||
self.extra_generation_params["Denoising strength"] = self.denoising_strength
|
||||
|
||||
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
|
||||
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
||||
|
||||
if self.hr_checkpoint_info is None:
|
||||
@ -1124,8 +1210,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
if not state.processing_has_refined_job_count:
|
||||
if state.job_count == -1:
|
||||
state.job_count = self.n_iter
|
||||
|
||||
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
||||
if getattr(self, 'txt2img_upscale', False):
|
||||
total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
|
||||
else:
|
||||
total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
|
||||
shared.total_tqdm.updateTotal(total_steps)
|
||||
state.job_count = state.job_count * 2
|
||||
state.processing_has_refined_job_count = True
|
||||
|
||||
@ -1138,12 +1227,39 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
if self.firstpass_image is not None and self.enable_hr:
|
||||
# here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
|
||||
samples = None
|
||||
decoded_samples = torch.asarray(np.expand_dims(image, 0))
|
||||
|
||||
else:
|
||||
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
|
||||
image = np.moveaxis(image, 2, 0)
|
||||
image = torch.from_numpy(np.expand_dims(image, axis=0))
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
|
||||
if opts.sd_vae_encode_method != 'Full':
|
||||
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
|
||||
|
||||
samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
|
||||
decoded_samples = None
|
||||
devices.torch_gc()
|
||||
|
||||
else:
|
||||
# here we generate an image normally
|
||||
|
||||
x = self.rng.next()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
@ -1351,12 +1467,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
mask_blur_x: int = 4
|
||||
mask_blur_y: int = 4
|
||||
mask_blur: int = None
|
||||
mask_round: bool = True
|
||||
inpainting_fill: int = 0
|
||||
inpaint_full_res: bool = True
|
||||
inpaint_full_res_padding: int = 0
|
||||
inpainting_mask_invert: int = 0
|
||||
initial_noise_multiplier: float = None
|
||||
latent_mask: Image = None
|
||||
force_task_id: str = None
|
||||
|
||||
image_mask: Any = field(default=None, init=False)
|
||||
|
||||
@ -1386,6 +1504,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
self.mask_blur_y = value
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.extra_generation_params["Denoising strength"] = self.denoising_strength
|
||||
|
||||
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
@ -1396,10 +1516,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if image_mask is not None:
|
||||
# image_mask is passed in as RGBA by Gradio to support alpha masks,
|
||||
# but we still want to support binary masks.
|
||||
image_mask = create_binary_mask(image_mask)
|
||||
image_mask = create_binary_mask(image_mask, round=self.mask_round)
|
||||
|
||||
if self.inpainting_mask_invert:
|
||||
image_mask = ImageOps.invert(image_mask)
|
||||
self.extra_generation_params["Mask mode"] = "Inpaint not masked"
|
||||
|
||||
if self.mask_blur_x > 0:
|
||||
np_mask = np.array(image_mask)
|
||||
@ -1413,16 +1534,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
||||
image_mask = Image.fromarray(np_mask)
|
||||
|
||||
if self.mask_blur_x > 0 or self.mask_blur_y > 0:
|
||||
self.extra_generation_params["Mask blur"] = self.mask_blur
|
||||
|
||||
if self.inpaint_full_res:
|
||||
self.mask_for_overlay = image_mask
|
||||
mask = image_mask.convert('L')
|
||||
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
||||
crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
|
||||
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
||||
x1, y1, x2, y2 = crop_region
|
||||
|
||||
mask = mask.crop(crop_region)
|
||||
image_mask = images.resize_image(2, mask, self.width, self.height)
|
||||
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
||||
|
||||
self.extra_generation_params["Inpaint area"] = "Only masked"
|
||||
self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
|
||||
else:
|
||||
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
||||
np_mask = np.array(image_mask)
|
||||
@ -1442,7 +1569,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
# Save init image
|
||||
if opts.save_init_img:
|
||||
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
|
||||
|
||||
image = images.flatten(img, opts.img2img_background_color)
|
||||
|
||||
@ -1464,6 +1591,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if self.inpainting_fill != 1:
|
||||
image = masking.fill(image, latent_mask)
|
||||
|
||||
if self.inpainting_fill == 0:
|
||||
self.extra_generation_params["Masked content"] = 'fill'
|
||||
|
||||
if add_color_corrections:
|
||||
self.color_corrections.append(setup_color_correction(image))
|
||||
|
||||
@ -1503,6 +1633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
||||
latmask = latmask[0]
|
||||
if self.mask_round:
|
||||
latmask = np.around(latmask)
|
||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||
|
||||
@ -1512,10 +1643,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
# this needs to be fixed to be done in sample() using actual seeds for batches
|
||||
if self.inpainting_fill == 2:
|
||||
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
||||
self.extra_generation_params["Masked content"] = 'latent noise'
|
||||
|
||||
elif self.inpainting_fill == 3:
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
self.extra_generation_params["Masked content"] = 'latent nothing'
|
||||
|
||||
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
|
||||
self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
x = self.rng.next()
|
||||
@ -1527,7 +1661,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
blended_samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
||||
if self.scripts is not None:
|
||||
mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
|
||||
self.scripts.on_mask_blend(self, mba)
|
||||
blended_samples = mba.blended_latent
|
||||
|
||||
samples = blended_samples
|
||||
|
||||
del x
|
||||
devices.torch_gc()
|
||||
|
42
modules/processing_scripts/comments.py
Normal file
42
modules/processing_scripts/comments.py
Normal file
@ -0,0 +1,42 @@
|
||||
from modules import scripts, shared, script_callbacks
|
||||
import re
|
||||
|
||||
|
||||
def strip_comments(text):
|
||||
text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text) # while line comment
|
||||
text = re.sub('#[^\n]*(\n|$)', '\n', text) # in the middle of the line comment
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class ScriptStripComments(scripts.Script):
|
||||
def title(self):
|
||||
return "Comments"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def process(self, p, *args):
|
||||
if not shared.opts.enable_prompt_comments:
|
||||
return
|
||||
|
||||
p.all_prompts = [strip_comments(x) for x in p.all_prompts]
|
||||
p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts]
|
||||
|
||||
p.main_prompt = strip_comments(p.main_prompt)
|
||||
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
|
||||
|
||||
|
||||
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
|
||||
if not shared.opts.enable_prompt_comments:
|
||||
return
|
||||
|
||||
params.prompt = strip_comments(params.prompt)
|
||||
|
||||
|
||||
script_callbacks.on_before_token_counter(before_token_counter)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), {
|
||||
"enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."),
|
||||
}))
|
@ -1,6 +1,7 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_models
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
|
||||
return None if info is None else info.title
|
||||
|
||||
self.infotext_fields = [
|
||||
(enable_refiner, lambda d: 'Refiner' in d),
|
||||
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||
(refiner_switch_at, 'Refiner switch at'),
|
||||
PasteField(enable_refiner, lambda d: 'Refiner' in d),
|
||||
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
|
||||
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
|
||||
]
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
|
@ -3,8 +3,10 @@ import json
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts, ui, errors
|
||||
from modules.infotext_utils import PasteField
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui_components import ToolButton
|
||||
from modules import infotext_utils
|
||||
|
||||
|
||||
class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
@ -51,12 +53,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||
|
||||
self.infotext_fields = [
|
||||
(self.seed, "Seed"),
|
||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
(subseed, "Variation seed"),
|
||||
(subseed_strength, "Variation seed strength"),
|
||||
(seed_resize_from_w, "Seed resize from-1"),
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
PasteField(self.seed, "Seed", api="seed"),
|
||||
PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||
PasteField(subseed, "Variation seed", api="subseed"),
|
||||
PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
|
||||
PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
|
||||
PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
|
||||
]
|
||||
|
||||
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||
@ -76,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
p.seed_resize_from_h = seed_resize_from_h
|
||||
|
||||
|
||||
|
||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||
@ -84,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
||||
|
||||
def copy_seed(gen_info_string: str, index):
|
||||
res = -1
|
||||
|
||||
try:
|
||||
gen_info = json.loads(gen_info_string)
|
||||
index -= gen_info.get('index_of_first_image', 0)
|
||||
|
||||
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||
else:
|
||||
all_seeds = gen_info.get('all_seeds', [-1])
|
||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
infotext = gen_info.get('infotexts')[index]
|
||||
gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
|
||||
res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
|
||||
except Exception:
|
||||
if gen_info_string:
|
||||
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||
errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)
|
||||
|
||||
return [res, gr.update()]
|
||||
|
||||
|
@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
|
||||
from modules.shared import opts
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
from collections import OrderedDict
|
||||
import string
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
current_task = None
|
||||
pending_tasks = {}
|
||||
pending_tasks = OrderedDict()
|
||||
finished_tasks = []
|
||||
recorded_results = []
|
||||
recorded_results_limit = 2
|
||||
@ -34,6 +37,11 @@ def finish_task(id_task):
|
||||
if len(finished_tasks) > 16:
|
||||
finished_tasks.pop(0)
|
||||
|
||||
def create_task_id(task_type):
|
||||
N = 7
|
||||
res = ''.join(random.choices(string.ascii_uppercase +
|
||||
string.digits, k=N))
|
||||
return f"task({task_type}-{res})"
|
||||
|
||||
def record_results(id_task, res):
|
||||
recorded_results.append((id_task, res))
|
||||
@ -44,6 +52,9 @@ def record_results(id_task, res):
|
||||
def add_task_to_queue(id_job):
|
||||
pending_tasks[id_job] = time.time()
|
||||
|
||||
class PendingTasksResponse(BaseModel):
|
||||
size: int = Field(title="Pending task size")
|
||||
tasks: List[str] = Field(title="Pending task ids")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
|
||||
@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):
|
||||
|
||||
|
||||
def setup_progress_api(app):
|
||||
app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
|
||||
return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
|
||||
|
||||
|
||||
def get_pending_tasks():
|
||||
pending_tasks_ids = list(pending_tasks)
|
||||
pending_len = len(pending_tasks_ids)
|
||||
return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
|
||||
|
||||
|
||||
def progressapi(req: ProgressRequest):
|
||||
active = req.id_task == current_task
|
||||
queued = req.id_task in pending_tasks
|
||||
|
@ -1,12 +1,9 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
@ -14,13 +11,9 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
self.name = "RealESRGAN"
|
||||
self.user_path = path
|
||||
super().__init__()
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||
from realesrgan import RealESRGANer # noqa: F401
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = self.load_models(path)
|
||||
scalers = get_realesrgan_models(self)
|
||||
|
||||
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||
for scaler in scalers:
|
||||
@ -33,11 +26,6 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
except Exception:
|
||||
errors.report("Error importing Real-ESRGAN", exc_info=True)
|
||||
self.enable = False
|
||||
self.scalers = []
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
if not self.enable:
|
||||
return img
|
||||
@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.local_data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
||||
)
|
||||
return upscale_with_model(
|
||||
model_descriptor,
|
||||
img,
|
||||
tile_size=opts.ESRGAN_tile,
|
||||
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||
# TODO: `outscale`?
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
for scaler in self.scalers:
|
||||
@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
||||
|
||||
def get_realesrgan_models(scaler):
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
models = [
|
||||
def get_realesrgan_models(scaler: UpscalerRealESRGAN):
|
||||
return [
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+ Anime6B",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
),
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 2x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
]
|
||||
return models
|
||||
except Exception:
|
||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||
|
@ -34,7 +34,7 @@ def randn_local(seed, shape):
|
||||
|
||||
|
||||
def randn_like(x):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
@ -48,7 +48,7 @@ def randn_like(x):
|
||||
|
||||
|
||||
def randn_without_seed(shape, generator=None):
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
|
||||
|
||||
Use either randn() or manual_seed() to initialize the generator."""
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import os
|
||||
from collections import namedtuple
|
||||
@ -41,7 +42,7 @@ class ExtraNoiseParams:
|
||||
|
||||
|
||||
class CFGDenoiserParams:
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
|
||||
self.x = x
|
||||
"""Latent image representation in the process of being denoised"""
|
||||
|
||||
@ -63,6 +64,9 @@ class CFGDenoiserParams:
|
||||
self.text_uncond = text_uncond
|
||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||
|
||||
self.denoiser = denoiser
|
||||
"""Current CFGDenoiser object with processing parameters"""
|
||||
|
||||
|
||||
class CFGDenoisedParams:
|
||||
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||
@ -103,6 +107,15 @@ class ImageGridLoopParams:
|
||||
self.rows = rows
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BeforeTokenCounterParams:
|
||||
prompt: str
|
||||
steps: int
|
||||
styles: list
|
||||
|
||||
is_positive: bool = True
|
||||
|
||||
|
||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
||||
callback_map = dict(
|
||||
callbacks_app_started=[],
|
||||
@ -125,6 +138,7 @@ callback_map = dict(
|
||||
callbacks_on_reload=[],
|
||||
callbacks_list_optimizers=[],
|
||||
callbacks_list_unets=[],
|
||||
callbacks_before_token_counter=[],
|
||||
)
|
||||
|
||||
|
||||
@ -306,6 +320,14 @@ def list_unets_callback():
|
||||
return res
|
||||
|
||||
|
||||
def before_token_counter_callback(params: BeforeTokenCounterParams):
|
||||
for c in callback_map['callbacks_before_token_counter']:
|
||||
try:
|
||||
c.callback(params)
|
||||
except Exception:
|
||||
report_exception(c, 'before_token_counter')
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if stack else 'unknown file'
|
||||
@ -480,3 +502,10 @@ def on_list_unets(callback):
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||
|
||||
|
||||
def on_before_token_counter(callback):
|
||||
"""register a function to be called when UI is counting tokens for a prompt.
|
||||
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
|
||||
|
||||
add_callback(callback_map['callbacks_before_token_counter'], callback)
|
||||
|
@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
class MaskBlendArgs:
|
||||
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
|
||||
self.current_latent = current_latent
|
||||
self.nmask = nmask
|
||||
self.init_latent = init_latent
|
||||
self.mask = mask
|
||||
self.blended_latent = blended_latent
|
||||
|
||||
self.denoiser = denoiser
|
||||
self.is_final_blend = denoiser is None
|
||||
self.sigma = sigma
|
||||
|
||||
class PostSampleArgs:
|
||||
def __init__(self, samples):
|
||||
self.samples = samples
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
class PostProcessMaskOverlayArgs:
|
||||
def __init__(self, index, mask_for_overlay, overlay_image):
|
||||
self.index = index
|
||||
self.mask_for_overlay = mask_for_overlay
|
||||
self.overlay_image = overlay_image
|
||||
|
||||
class PostprocessBatchListArgs:
|
||||
def __init__(self, images):
|
||||
@ -71,6 +91,9 @@ class Script:
|
||||
setup_for_ui_only = False
|
||||
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||
|
||||
controls = None
|
||||
"""A list of controls returned by the ui()."""
|
||||
|
||||
def title(self):
|
||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||
|
||||
@ -86,7 +109,7 @@ class Script:
|
||||
|
||||
def show(self, is_img2img):
|
||||
"""
|
||||
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
|
||||
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||
|
||||
This function should return:
|
||||
- False if the script should not be shown in UI at all
|
||||
@ -206,6 +229,25 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
|
||||
"""
|
||||
Called in inpainting mode when the original content is blended with the inpainted content.
|
||||
This is called at every step in the denoising process and once at the end.
|
||||
If is_final_blend is true, this is called for the final blending stage.
|
||||
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def post_sample(self, p, ps: PostSampleArgs, *args):
|
||||
"""
|
||||
Called after the samples have been generated,
|
||||
but before they have been decoded by the VAE, if applicable.
|
||||
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
@ -213,6 +255,22 @@ class Script:
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
Same as postprocess_image but after inpaint_full_res composite
|
||||
So that it operates on the full image instead of the inpaint_full_res crop region.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
@ -520,7 +578,12 @@ class ScriptRunner:
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_data in auto_processing_scripts + scripts_data:
|
||||
try:
|
||||
script = script_data.script_class()
|
||||
except Exception:
|
||||
errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
|
||||
continue
|
||||
|
||||
script.filename = script_data.path
|
||||
script.is_txt2img = not is_img2img
|
||||
script.is_img2img = is_img2img
|
||||
@ -573,6 +636,7 @@ class ScriptRunner:
|
||||
import modules.api.models as api_models
|
||||
|
||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||
script.controls = controls
|
||||
|
||||
if controls is None:
|
||||
return
|
||||
@ -645,6 +709,8 @@ class ScriptRunner:
|
||||
self.setup_ui_for_section(None, self.selectable_scripts)
|
||||
|
||||
def select_script(script_index):
|
||||
if script_index is None:
|
||||
script_index = 0
|
||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||
|
||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||
@ -688,7 +754,7 @@ class ScriptRunner:
|
||||
def run(self, p, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
if script_index == 0 or script_index is None:
|
||||
return None
|
||||
|
||||
script = self.selectable_scripts[script_index-1]
|
||||
@ -767,6 +833,22 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||
|
||||
def post_sample(self, p, ps: PostSampleArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.post_sample(p, ps, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||
|
||||
def on_mask_blend(self, p, mba: MaskBlendArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.on_mask_blend(p, mba, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
@ -775,6 +857,22 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_maskoverlay(p, ppmo, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||
|
||||
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image_after_composite(p, pp, *script_args)
|
||||
except Exception:
|
||||
errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||
try:
|
||||
@ -841,6 +939,35 @@ class ScriptRunner:
|
||||
except Exception:
|
||||
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||
|
||||
def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
|
||||
"""Locate an arg of a specific script in script_args and set its value
|
||||
Args:
|
||||
args: all script args of process p, p.script_args
|
||||
script_name: the name target script name to
|
||||
arg_elem_id: the elem_id of the target arg
|
||||
value: the value to set
|
||||
fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
|
||||
Returns:
|
||||
Updated script args
|
||||
when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
|
||||
"""
|
||||
script = next((x for x in self.scripts if x.name == script_name), None)
|
||||
if script is None:
|
||||
raise RuntimeError(f"script {script_name} not found")
|
||||
|
||||
for i, control in enumerate(script.controls):
|
||||
if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
|
||||
index = script.args_from + i
|
||||
|
||||
if isinstance(args, tuple):
|
||||
return args[:index] + (value,) + args[index + 1:]
|
||||
elif isinstance(args, list):
|
||||
args[index] = value
|
||||
return args
|
||||
else:
|
||||
raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
|
||||
raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
|
||||
|
||||
|
||||
scripts_txt2img: ScriptRunner = None
|
||||
scripts_img2img: ScriptRunner = None
|
||||
|
70
modules/sd_emphasis.py
Normal file
70
modules/sd_emphasis.py
Normal file
@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
|
||||
|
||||
class Emphasis:
|
||||
"""Emphasis class decides how to death with (emphasized:1.1) text in prompts"""
|
||||
|
||||
name: str = "Base"
|
||||
description: str = ""
|
||||
|
||||
tokens: list[list[int]]
|
||||
"""tokens from the chunk of the prompt"""
|
||||
|
||||
multipliers: torch.Tensor
|
||||
"""tensor with multipliers, once for each token"""
|
||||
|
||||
z: torch.Tensor
|
||||
"""output of cond transformers network (CLIP)"""
|
||||
|
||||
def after_transformers(self):
|
||||
"""Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class EmphasisNone(Emphasis):
|
||||
name = "None"
|
||||
description = "disable the mechanism entirely and treat (:.1.1) as literal characters"
|
||||
|
||||
|
||||
class EmphasisIgnore(Emphasis):
|
||||
name = "Ignore"
|
||||
description = "treat all empasised words as if they have no emphasis"
|
||||
|
||||
|
||||
class EmphasisOriginal(Emphasis):
|
||||
name = "Original"
|
||||
description = "the original emphasis implementation"
|
||||
|
||||
def after_transformers(self):
|
||||
original_mean = self.z.mean()
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
new_mean = self.z.mean()
|
||||
self.z = self.z * (original_mean / new_mean)
|
||||
|
||||
|
||||
class EmphasisOriginalNoNorm(EmphasisOriginal):
|
||||
name = "No norm"
|
||||
description = "same as original, but without normalization (seems to work better for SDXL)"
|
||||
|
||||
def after_transformers(self):
|
||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||
|
||||
|
||||
def get_current_option(emphasis_option_name):
|
||||
return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)
|
||||
|
||||
|
||||
def get_options_descriptions():
|
||||
return ", ".join(f"{x.name}: {x.description}" for x in options)
|
||||
|
||||
|
||||
options = [
|
||||
EmphasisNone,
|
||||
EmphasisIgnore,
|
||||
EmphasisOriginal,
|
||||
EmphasisOriginalNoNorm,
|
||||
]
|
@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices, sd_hijack
|
||||
from modules import prompt_parser, devices, sd_hijack, sd_emphasis
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class PromptChunk:
|
||||
|
||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
||||
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||
chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
"""
|
||||
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
|
||||
converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
|
||||
All python lists with tokens are assumed to have same length, usually 77.
|
||||
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
||||
model - can be 768 and 1024.
|
||||
@ -88,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
Returns the list and the total number of tokens in the prompt.
|
||||
"""
|
||||
|
||||
if opts.enable_emphasis:
|
||||
if opts.emphasis != "None":
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
@ -136,7 +136,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
if token == self.comma_token:
|
||||
last_comma = len(chunk.tokens)
|
||||
|
||||
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||
# this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
break_location = last_comma + 1
|
||||
@ -206,7 +206,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||
An example shape returned by this function can be: (2, 77, 768).
|
||||
For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
|
||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
|
||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
"""
|
||||
|
||||
@ -230,7 +230,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
for fixes in self.hijack.fixes:
|
||||
for _position, embedding in fixes:
|
||||
used_embeddings[embedding.name] = embedding
|
||||
|
||||
devices.torch_npu_set_device()
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
zs.append(z)
|
||||
|
||||
@ -249,6 +249,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
@ -274,12 +277,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
|
||||
pooled = getattr(z, 'pooled', None)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
emphasis = sd_emphasis.get_current_option(opts.emphasis)()
|
||||
emphasis.tokens = remade_batch_tokens
|
||||
emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
emphasis.z = z
|
||||
|
||||
emphasis.after_transformers()
|
||||
|
||||
z = emphasis.z
|
||||
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
@ -32,7 +32,7 @@ def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase,
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None
|
||||
mult_change = self.token_mults.get(token) if shared.opts.emphasis != "None" else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
|
@ -11,10 +11,14 @@ class CondFunc:
|
||||
break
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
for attr_name in func_path[i:-1]:
|
||||
resolved_obj = getattr(resolved_obj, attr_name)
|
||||
orig_func = getattr(resolved_obj, func_path[-1])
|
||||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
|
||||
except AttributeError:
|
||||
print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
|
||||
pass
|
||||
self.__init__(orig_func, sub_func, cond_func)
|
||||
return lambda *args, **kwargs: self(*args, **kwargs)
|
||||
def __init__(self, orig_func, sub_func, cond_func):
|
||||
|
@ -15,6 +15,7 @@ from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
from modules.shared import opts
|
||||
import tomesd
|
||||
import numpy as np
|
||||
|
||||
@ -348,10 +349,28 @@ class SkipWritingToConfig:
|
||||
SkipWritingToConfig.skip = self.previous
|
||||
|
||||
|
||||
def check_fp8(model):
|
||||
if model is None:
|
||||
return None
|
||||
if devices.get_optimal_device_name() == "mps":
|
||||
enable_fp8 = False
|
||||
elif shared.opts.fp8_storage == "Enable":
|
||||
enable_fp8 = True
|
||||
elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
|
||||
enable_fp8 = True
|
||||
else:
|
||||
enable_fp8 = False
|
||||
return enable_fp8
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if devices.fp8:
|
||||
# prevent model to load state dict in fp8
|
||||
model.half()
|
||||
|
||||
if not SkipWritingToConfig.skip:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
@ -383,6 +402,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
model.float()
|
||||
model.alphas_cumprod_original = model.alphas_cumprod
|
||||
devices.dtype_unet = torch.float32
|
||||
timer.record("apply float()")
|
||||
else:
|
||||
@ -396,7 +416,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
alphas_cumprod = model.alphas_cumprod
|
||||
model.alphas_cumprod = None
|
||||
model.half()
|
||||
model.alphas_cumprod = alphas_cumprod
|
||||
model.alphas_cumprod_original = alphas_cumprod
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
@ -404,6 +428,30 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
devices.dtype_unet = torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
apply_alpha_schedule_override(model)
|
||||
|
||||
for module in model.modules():
|
||||
if hasattr(module, 'fp16_weight'):
|
||||
del module.fp16_weight
|
||||
if hasattr(module, 'fp16_bias'):
|
||||
del module.fp16_bias
|
||||
|
||||
if check_fp8(model):
|
||||
devices.fp8 = True
|
||||
first_stage = model.first_stage_model
|
||||
model.first_stage_model = None
|
||||
for module in model.modules():
|
||||
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
|
||||
if shared.opts.cache_fp16_weight:
|
||||
module.fp16_weight = module.weight.data.clone().cpu().half()
|
||||
if module.bias is not None:
|
||||
module.fp16_bias = module.bias.data.clone().cpu().half()
|
||||
module.to(torch.float8_e4m3fn)
|
||||
model.first_stage_model = first_stage
|
||||
timer.record("apply fp8")
|
||||
else:
|
||||
devices.fp8 = False
|
||||
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
@ -505,6 +553,48 @@ def repair_config(sd_config):
|
||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return alphas_bar
|
||||
|
||||
|
||||
def apply_alpha_schedule_override(sd_model, p=None):
|
||||
"""
|
||||
Applies an override to the alpha schedule of the model according to settings.
|
||||
- downcasts the alpha schedule to half precision
|
||||
- rescales the alpha schedule to have zero terminal SNR
|
||||
"""
|
||||
|
||||
if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'):
|
||||
return
|
||||
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||
|
||||
if opts.use_downcasted_alpha_bar:
|
||||
if p is not None:
|
||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||
|
||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||
if p is not None:
|
||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
@ -651,6 +741,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
else:
|
||||
weight_dtype_conversion = {
|
||||
'first_stage_model': None,
|
||||
'alphas_cumprod': None,
|
||||
'': torch.float16,
|
||||
}
|
||||
|
||||
@ -693,7 +784,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||
If no such model exists, returns None.
|
||||
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||
Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||
"""
|
||||
|
||||
already_loaded = None
|
||||
@ -746,7 +837,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||
return None
|
||||
|
||||
|
||||
def reload_model_weights(sd_model=None, info=None):
|
||||
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||
checkpoint_info = info or select_checkpoint()
|
||||
|
||||
timer = Timer()
|
||||
@ -758,11 +849,14 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
if check_fp8(sd_model) != devices.fp8:
|
||||
# load from state dict again to prevent extra numerical errors
|
||||
forced_reload = True
|
||||
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
|
||||
return sd_model
|
||||
|
||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||
return sd_model
|
||||
|
||||
if sd_model is not None:
|
||||
@ -793,13 +887,13 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not sd_model.lowvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
model_data.set_sd_model(sd_model)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user