mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 11:50:18 +08:00
Merge branch 'release_candidate'
This commit is contained in:
commit
cf2772fab0
@ -74,6 +74,7 @@ module.exports = {
|
|||||||
create_submit_args: "readonly",
|
create_submit_args: "readonly",
|
||||||
restart_reload: "readonly",
|
restart_reload: "readonly",
|
||||||
updateInput: "readonly",
|
updateInput: "readonly",
|
||||||
|
onEdit: "readonly",
|
||||||
//extraNetworks.js
|
//extraNetworks.js
|
||||||
requestGet: "readonly",
|
requestGet: "readonly",
|
||||||
popup: "readonly",
|
popup: "readonly",
|
||||||
|
73
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
73
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -1,25 +1,45 @@
|
|||||||
name: Bug Report
|
name: Bug Report
|
||||||
description: You think somethings is broken in the UI
|
description: You think something is broken in the UI
|
||||||
title: "[Bug]: "
|
title: "[Bug]: "
|
||||||
labels: ["bug-report"]
|
labels: ["bug-report"]
|
||||||
|
|
||||||
body:
|
body:
|
||||||
- type: checkboxes
|
|
||||||
attributes:
|
|
||||||
label: Is there an existing issue for this?
|
|
||||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
|
||||||
options:
|
|
||||||
- label: I have searched the existing issues and checked the recent builds/commits
|
|
||||||
required: true
|
|
||||||
- type: markdown
|
- type: markdown
|
||||||
attributes:
|
attributes:
|
||||||
value: |
|
value: |
|
||||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
> The title of the bug report should be short and descriptive.
|
||||||
|
> Use relevant keywords for searchability.
|
||||||
|
> Do not leave it blank, but also do not put an entire error log in it.
|
||||||
|
- type: checkboxes
|
||||||
|
attributes:
|
||||||
|
label: Checklist
|
||||||
|
description: |
|
||||||
|
Please perform basic debugging to see if extensions or configuration is the cause of the issue.
|
||||||
|
Basic debug procedure
|
||||||
|
1. Disable all third-party extensions - check if extension is the cause
|
||||||
|
2. Update extensions and webui - sometimes things just need to be updated
|
||||||
|
3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
|
||||||
|
4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
|
||||||
|
5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
|
||||||
|
Before making a issue report please, check that the issue hasn't been reported recently.
|
||||||
|
options:
|
||||||
|
- label: The issue exists after disabling all extensions
|
||||||
|
- label: The issue exists on a clean installation of webui
|
||||||
|
- label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
|
||||||
|
- label: The issue exists in the current version of the webui
|
||||||
|
- label: The issue has not been reported before recently
|
||||||
|
- label: The issue has been reported before but has not been fixed yet
|
||||||
|
- type: markdown
|
||||||
|
attributes:
|
||||||
|
value: |
|
||||||
|
> Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: what-did
|
id: what-did
|
||||||
attributes:
|
attributes:
|
||||||
label: What happened?
|
label: What happened?
|
||||||
description: Tell us what happened in a very clear and simple way
|
description: Tell us what happened in a very clear and simple way
|
||||||
|
placeholder: |
|
||||||
|
txt2img is not working as intended.
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: textarea
|
- type: textarea
|
||||||
@ -27,9 +47,9 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: Steps to reproduce the problem
|
label: Steps to reproduce the problem
|
||||||
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
||||||
value: |
|
placeholder: |
|
||||||
1. Go to ....
|
1. Go to ...
|
||||||
2. Press ....
|
2. Press ...
|
||||||
3. ...
|
3. ...
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
@ -38,13 +58,8 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: What should have happened?
|
label: What should have happened?
|
||||||
description: Tell us what you think the normal behavior should be
|
description: Tell us what you think the normal behavior should be
|
||||||
validations:
|
placeholder: |
|
||||||
required: true
|
WebUI should ...
|
||||||
- type: textarea
|
|
||||||
id: sysinfo
|
|
||||||
attributes:
|
|
||||||
label: Sysinfo
|
|
||||||
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
@ -58,12 +73,25 @@ body:
|
|||||||
- Brave
|
- Brave
|
||||||
- Apple Safari
|
- Apple Safari
|
||||||
- Microsoft Edge
|
- Microsoft Edge
|
||||||
|
- Android
|
||||||
|
- iOS
|
||||||
- Other
|
- Other
|
||||||
|
- type: textarea
|
||||||
|
id: sysinfo
|
||||||
|
attributes:
|
||||||
|
label: Sysinfo
|
||||||
|
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
||||||
|
placeholder: |
|
||||||
|
1. Go to WebUI Settings -> Sysinfo -> Download system info.
|
||||||
|
If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
|
||||||
|
2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
- type: textarea
|
- type: textarea
|
||||||
id: logs
|
id: logs
|
||||||
attributes:
|
attributes:
|
||||||
label: Console logs
|
label: Console logs
|
||||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
|
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
|
||||||
render: Shell
|
render: Shell
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
@ -71,4 +99,7 @@ body:
|
|||||||
id: misc
|
id: misc
|
||||||
attributes:
|
attributes:
|
||||||
label: Additional information
|
label: Additional information
|
||||||
description: Please provide us with any relevant additional info or context.
|
description: |
|
||||||
|
Please provide us with any relevant additional info or context.
|
||||||
|
Examples:
|
||||||
|
I have updated my GPU driver recently.
|
||||||
|
2
.github/workflows/on_pull_request.yaml
vendored
2
.github/workflows/on_pull_request.yaml
vendored
@ -20,7 +20,7 @@ jobs:
|
|||||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
# of PyTorch and other dependencies.
|
# of PyTorch and other dependencies.
|
||||||
- name: Install Ruff
|
- name: Install Ruff
|
||||||
run: pip install ruff==0.0.272
|
run: pip install ruff==0.1.6
|
||||||
- name: Run Ruff
|
- name: Run Ruff
|
||||||
run: ruff .
|
run: ruff .
|
||||||
lint-js:
|
lint-js:
|
||||||
|
162
CHANGELOG.md
162
CHANGELOG.md
@ -1,3 +1,165 @@
|
|||||||
|
## 1.7.0
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* settings tab rework: add search field, add categories, split UI settings page into many
|
||||||
|
* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364))
|
||||||
|
* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610))
|
||||||
|
* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568))
|
||||||
|
* option to move prompt from top row into generation parameters
|
||||||
|
* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865))
|
||||||
|
* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692))
|
||||||
|
* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944))
|
||||||
|
* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948))
|
||||||
|
* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170))
|
||||||
|
* remove Train->Preprocessing tab and put all its functionality into Extras tab
|
||||||
|
* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171))
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* allow reading model hash from images in img2img batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767))
|
||||||
|
* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))
|
||||||
|
* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838))
|
||||||
|
* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909))
|
||||||
|
* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975))
|
||||||
|
* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253))
|
||||||
|
* write infotext to gif images
|
||||||
|
* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068))
|
||||||
|
* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189))
|
||||||
|
* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444))
|
||||||
|
* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480))
|
||||||
|
* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631))
|
||||||
|
* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459))
|
||||||
|
* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638))
|
||||||
|
* add an option to not print stack traces on ctrl+c
|
||||||
|
* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644))
|
||||||
|
* update prompts_from_file script to allow concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733))
|
||||||
|
* added a visible checkbox to input accordion
|
||||||
|
* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826))
|
||||||
|
* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968))
|
||||||
|
* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931))
|
||||||
|
* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009))
|
||||||
|
* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
|
||||||
|
* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
|
||||||
|
* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
|
||||||
|
* allow use of mutiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
|
||||||
|
|
||||||
|
### Extensions and API:
|
||||||
|
* update gradio to 3.41.2
|
||||||
|
* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774))
|
||||||
|
* update pnginfo API to return dict with parsed values
|
||||||
|
* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856))
|
||||||
|
* show extension datetime in UTC ([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281))
|
||||||
|
* add an option to choose how to combine hires fix and refiner
|
||||||
|
* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135))
|
||||||
|
* sd_unet support for SDXL
|
||||||
|
* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276))
|
||||||
|
* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266))
|
||||||
|
* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077))
|
||||||
|
* add onEdit function for js and rework token-counter.js to use it
|
||||||
|
* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567))
|
||||||
|
* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463))
|
||||||
|
* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762))
|
||||||
|
* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884))
|
||||||
|
* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059))
|
||||||
|
* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119))
|
||||||
|
* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120))
|
||||||
|
* use extension name for determining an extension is installed in the index ([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063))
|
||||||
|
* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192))
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix pix2pix producing bad results
|
||||||
|
* fix defaults settings page breaking when any of main UI tabs are hidden
|
||||||
|
* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt
|
||||||
|
* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working
|
||||||
|
* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795))
|
||||||
|
* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))
|
||||||
|
* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))
|
||||||
|
* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))
|
||||||
|
* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs
|
||||||
|
* hide --gradio-auth and --api-auth values from /internal/sysinfo report
|
||||||
|
* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819))
|
||||||
|
* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))
|
||||||
|
* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))
|
||||||
|
* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))
|
||||||
|
* get progressbar to display correctly in extensions tab
|
||||||
|
* keep order in list of checkpoints when loading model that doesn't have a checksum
|
||||||
|
* fix inpainting models in txt2img creating black pictures
|
||||||
|
* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876))
|
||||||
|
* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926))
|
||||||
|
* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084))
|
||||||
|
* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995))
|
||||||
|
* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924))
|
||||||
|
* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118))
|
||||||
|
* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412))
|
||||||
|
* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411))
|
||||||
|
* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395))
|
||||||
|
* fix: --sd_model in "Prompts from file or textbox" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302))
|
||||||
|
* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231))
|
||||||
|
* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210))
|
||||||
|
* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178))
|
||||||
|
* account for customizable extra network separators whyen removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877))
|
||||||
|
* re fix batch img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170))
|
||||||
|
* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139))
|
||||||
|
* consolidated allowed preview formats, Fix extra network `.gif` not woking as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121))
|
||||||
|
* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469))
|
||||||
|
* repair unload sd checkpoint button
|
||||||
|
* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533))
|
||||||
|
* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718))
|
||||||
|
* properly apply sort order for extra network cards when selected from dropdown
|
||||||
|
* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962))
|
||||||
|
* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014))
|
||||||
|
* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156))
|
||||||
|
* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121))
|
||||||
|
* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131))
|
||||||
|
* extras tab batch: actually use original filename
|
||||||
|
* make webui not crash when running with --disable-all-extensions option
|
||||||
|
|
||||||
|
### Other:
|
||||||
|
* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814))
|
||||||
|
* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827))
|
||||||
|
* remove xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842))
|
||||||
|
* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837))
|
||||||
|
* revert SGM noise multiplier change for img2img because it breaks hires fix
|
||||||
|
* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))
|
||||||
|
* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839))
|
||||||
|
* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851))
|
||||||
|
* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929))
|
||||||
|
* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028))
|
||||||
|
* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986))
|
||||||
|
* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213))
|
||||||
|
* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976))
|
||||||
|
* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880))
|
||||||
|
* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119))
|
||||||
|
* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846))
|
||||||
|
* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418))
|
||||||
|
* update dragdrop.js ([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372))
|
||||||
|
* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313))
|
||||||
|
* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282))
|
||||||
|
* initialize state.time_start befroe state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229))
|
||||||
|
* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458))
|
||||||
|
* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466))
|
||||||
|
* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475))
|
||||||
|
* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630))
|
||||||
|
* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535))
|
||||||
|
* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991))
|
||||||
|
* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))
|
||||||
|
* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829))
|
||||||
|
* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797))
|
||||||
|
* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855))
|
||||||
|
* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004))
|
||||||
|
* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996))
|
||||||
|
* hotfix: call shared.state.end() after postprocessing done ([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977))
|
||||||
|
* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929))
|
||||||
|
* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035))
|
||||||
|
* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084))
|
||||||
|
* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936))
|
||||||
|
* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108))
|
||||||
|
* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957))
|
||||||
|
* alternate implementation for unet forward replacement that does not depend on hijack being applied
|
||||||
|
* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178))
|
||||||
|
* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177))
|
||||||
|
* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181))
|
||||||
|
* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186))
|
||||||
|
|
||||||
## 1.6.1
|
## 1.6.1
|
||||||
|
|
||||||
### Bug Fixes:
|
### Bug Fixes:
|
||||||
|
12
README.md
12
README.md
@ -88,9 +88,10 @@ A browser interface based on Gradio library for Stable Diffusion.
|
|||||||
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
||||||
- Now without any bad letters!
|
- Now without any bad letters!
|
||||||
- Load checkpoints in safetensors format
|
- Load checkpoints in safetensors format
|
||||||
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
|
- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
|
||||||
- Now with a license!
|
- Now with a license!
|
||||||
- Reorder elements in the UI from settings screen
|
- Reorder elements in the UI from settings screen
|
||||||
|
- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||||
@ -103,7 +104,7 @@ Alternatively, use online services (like Google Colab):
|
|||||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||||
|
|
||||||
### Installation on Windows 10/11 with NVidia-GPUs using release package
|
### Installation on Windows 10/11 with NVidia-GPUs using release package
|
||||||
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
|
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
|
||||||
2. Run `update.bat`.
|
2. Run `update.bat`.
|
||||||
3. Run `run.bat`.
|
3. Run `run.bat`.
|
||||||
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
|
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
|
||||||
@ -120,7 +121,9 @@ Alternatively, use online services (like Google Colab):
|
|||||||
# Debian-based:
|
# Debian-based:
|
||||||
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
|
||||||
# Red Hat-based:
|
# Red Hat-based:
|
||||||
sudo dnf install wget git python3
|
sudo dnf install wget git python3 gperftools-libs libglvnd-glx
|
||||||
|
# openSUSE-based:
|
||||||
|
sudo zypper install wget git python3 libtcmalloc4 libglvnd
|
||||||
# Arch-based:
|
# Arch-based:
|
||||||
sudo pacman -S wget git python3
|
sudo pacman -S wget git python3
|
||||||
```
|
```
|
||||||
@ -146,7 +149,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
|
|||||||
## Credits
|
## Credits
|
||||||
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
|
||||||
|
|
||||||
- Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
|
- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
|
||||||
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
- k-diffusion - https://github.com/crowsonkb/k-diffusion.git
|
||||||
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
- GFPGAN - https://github.com/TencentARC/GFPGAN.git
|
||||||
- CodeFormer - https://github.com/sczhou/CodeFormer
|
- CodeFormer - https://github.com/sczhou/CodeFormer
|
||||||
@ -173,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||||
- LyCORIS - KohakuBlueleaf
|
- LyCORIS - KohakuBlueleaf
|
||||||
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
|
||||||
|
- Hypertile - tfernd - https://github.com/tfernd/HyperTile
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
73
configs/alt-diffusion-m18-inference.yaml
Normal file
73
configs/alt-diffusion-m18-inference.yaml
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_head_channels: 64
|
||||||
|
use_spatial_transformer: True
|
||||||
|
use_linear_in_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1024
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: modules.xlmr_m18.BertSeriesModelWithTransformation
|
||||||
|
params:
|
||||||
|
name: "XLMR-Large"
|
33
extensions-builtin/Lora/lora_logger.py
Normal file
33
extensions-builtin/Lora/lora_logger.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import sys
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class ColoredFormatter(logging.Formatter):
|
||||||
|
COLORS = {
|
||||||
|
"DEBUG": "\033[0;36m", # CYAN
|
||||||
|
"INFO": "\033[0;32m", # GREEN
|
||||||
|
"WARNING": "\033[0;33m", # YELLOW
|
||||||
|
"ERROR": "\033[0;31m", # RED
|
||||||
|
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
|
||||||
|
"RESET": "\033[0m", # RESET COLOR
|
||||||
|
}
|
||||||
|
|
||||||
|
def format(self, record):
|
||||||
|
colored_record = copy.copy(record)
|
||||||
|
levelname = colored_record.levelname
|
||||||
|
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
||||||
|
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
|
||||||
|
return super().format(colored_record)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("lora")
|
||||||
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
if not logger.handlers:
|
||||||
|
handler = logging.StreamHandler(sys.stdout)
|
||||||
|
handler.setFormatter(
|
||||||
|
ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
|
||||||
|
)
|
||||||
|
logger.addHandler(handler)
|
@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
|
|||||||
up = up.reshape(up.size(0), -1)
|
up = up.reshape(up.size(0), -1)
|
||||||
down = down.reshape(down.size(0), -1)
|
down = down.reshape(down.size(0), -1)
|
||||||
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
|
||||||
|
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
||||||
|
'''
|
||||||
|
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||||
|
second value is higher or equal than first value.
|
||||||
|
|
||||||
|
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||||
|
secon value is a value for weight.
|
||||||
|
|
||||||
|
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||||
|
|
||||||
|
examples)
|
||||||
|
factor
|
||||||
|
-1 2 4 8 16 ...
|
||||||
|
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||||
|
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||||
|
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||||
|
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||||
|
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||||
|
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||||
|
'''
|
||||||
|
|
||||||
|
if factor > 0 and (dimension % factor) == 0:
|
||||||
|
m = factor
|
||||||
|
n = dimension // factor
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
if factor < 0:
|
||||||
|
factor = dimension
|
||||||
|
m, n = 1, dimension
|
||||||
|
length = m + n
|
||||||
|
while m<n:
|
||||||
|
new_m = m + 1
|
||||||
|
while dimension%new_m != 0:
|
||||||
|
new_m += 1
|
||||||
|
new_n = dimension // new_m
|
||||||
|
if new_m + new_n > length or new_m>factor:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
m, n = new_m, new_n
|
||||||
|
if m > n:
|
||||||
|
n, m = m, n
|
||||||
|
return m, n
|
||||||
|
|
||||||
|
@ -93,6 +93,7 @@ class Network: # LoraModule
|
|||||||
self.unet_multiplier = 1.0
|
self.unet_multiplier = 1.0
|
||||||
self.dyn_dim = None
|
self.dyn_dim = None
|
||||||
self.modules = {}
|
self.modules = {}
|
||||||
|
self.bundle_embeddings = {}
|
||||||
self.mtime = None
|
self.mtime = None
|
||||||
|
|
||||||
self.mentioned_name = None
|
self.mentioned_name = None
|
||||||
|
33
extensions-builtin/Lora/network_glora.py
Normal file
33
extensions-builtin/Lora/network_glora.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
import network
|
||||||
|
|
||||||
|
class ModuleTypeGLora(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
|
||||||
|
return NetworkModuleGLora(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
|
||||||
|
class NetworkModuleGLora(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
if hasattr(self.sd_module, 'weight'):
|
||||||
|
self.shape = self.sd_module.weight.shape
|
||||||
|
|
||||||
|
self.w1a = weights.w["a1.weight"]
|
||||||
|
self.w1b = weights.w["b1.weight"]
|
||||||
|
self.w2a = weights.w["a2.weight"]
|
||||||
|
self.w2b = weights.w["b2.weight"]
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
output_shape = [w1a.size(0), w1b.size(1)]
|
||||||
|
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
82
extensions-builtin/Lora/network_oft.py
Normal file
82
extensions-builtin/Lora/network_oft.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import torch
|
||||||
|
import network
|
||||||
|
from lyco_helpers import factorization
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeOFT(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["oft_blocks"]) or all(x in weights.w for x in ["oft_diag"]):
|
||||||
|
return NetworkModuleOFT(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Supports both kohya-ss' implementation of COFT https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
|
||||||
|
# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
|
||||||
|
class NetworkModuleOFT(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.lin_module = None
|
||||||
|
self.org_module: list[torch.Module] = [self.sd_module]
|
||||||
|
|
||||||
|
self.scale = 1.0
|
||||||
|
|
||||||
|
# kohya-ss
|
||||||
|
if "oft_blocks" in weights.w.keys():
|
||||||
|
self.is_kohya = True
|
||||||
|
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||||
|
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||||
|
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||||
|
# LyCORIS
|
||||||
|
elif "oft_diag" in weights.w.keys():
|
||||||
|
self.is_kohya = False
|
||||||
|
self.oft_blocks = weights.w["oft_diag"]
|
||||||
|
# self.alpha is unused
|
||||||
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||||
|
|
||||||
|
if is_linear:
|
||||||
|
self.out_dim = self.sd_module.out_features
|
||||||
|
elif is_conv:
|
||||||
|
self.out_dim = self.sd_module.out_channels
|
||||||
|
elif is_other_linear:
|
||||||
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
|
if self.is_kohya:
|
||||||
|
self.constraint = self.alpha * self.out_dim
|
||||||
|
self.num_blocks = self.dim
|
||||||
|
self.block_size = self.out_dim // self.dim
|
||||||
|
else:
|
||||||
|
self.constraint = None
|
||||||
|
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
eye = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||||
|
|
||||||
|
if self.is_kohya:
|
||||||
|
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
|
||||||
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||||
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
|
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
|
||||||
|
|
||||||
|
R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||||
|
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||||
|
merged_weight = torch.einsum(
|
||||||
|
'k n m, k n ... -> k m ...',
|
||||||
|
R,
|
||||||
|
merged_weight
|
||||||
|
)
|
||||||
|
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||||
|
|
||||||
|
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
|
||||||
|
output_shape = orig_weight.shape
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape)
|
@ -5,16 +5,21 @@ import re
|
|||||||
import lora_patches
|
import lora_patches
|
||||||
import network
|
import network
|
||||||
import network_lora
|
import network_lora
|
||||||
|
import network_glora
|
||||||
import network_hada
|
import network_hada
|
||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
import network_norm
|
import network_norm
|
||||||
|
import network_oft
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||||
|
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||||
|
|
||||||
|
from lora_logger import logger
|
||||||
|
|
||||||
module_types = [
|
module_types = [
|
||||||
network_lora.ModuleTypeLora(),
|
network_lora.ModuleTypeLora(),
|
||||||
@ -23,6 +28,8 @@ module_types = [
|
|||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
network_norm.ModuleTypeNorm(),
|
network_norm.ModuleTypeNorm(),
|
||||||
|
network_glora.ModuleTypeGLora(),
|
||||||
|
network_oft.ModuleTypeOFT(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -149,9 +156,20 @@ def load_network(name, network_on_disk):
|
|||||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
|
||||||
|
|
||||||
matched_networks = {}
|
matched_networks = {}
|
||||||
|
bundle_embeddings = {}
|
||||||
|
|
||||||
for key_network, weight in sd.items():
|
for key_network, weight in sd.items():
|
||||||
key_network_without_network_parts, network_part = key_network.split(".", 1)
|
key_network_without_network_parts, _, network_part = key_network.partition(".")
|
||||||
|
|
||||||
|
if key_network_without_network_parts == "bundle_emb":
|
||||||
|
emb_name, vec_name = network_part.split(".", 1)
|
||||||
|
emb_dict = bundle_embeddings.get(emb_name, {})
|
||||||
|
if vec_name.split('.')[0] == 'string_to_param':
|
||||||
|
_, k2 = vec_name.split('.', 1)
|
||||||
|
emb_dict['string_to_param'] = {k2: weight}
|
||||||
|
else:
|
||||||
|
emb_dict[vec_name] = weight
|
||||||
|
bundle_embeddings[emb_name] = emb_dict
|
||||||
|
|
||||||
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
|
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
|
||||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
@ -174,6 +192,17 @@ def load_network(name, network_on_disk):
|
|||||||
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
|
||||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# kohya_ss OFT module
|
||||||
|
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
|
||||||
|
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
|
# KohakuBlueLeaf OFT module
|
||||||
|
if sd_module is None and "oft_diag" in key:
|
||||||
|
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
|
||||||
|
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||||
|
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
keys_failed_to_match[key_network] = key
|
keys_failed_to_match[key_network] = key
|
||||||
continue
|
continue
|
||||||
@ -195,6 +224,14 @@ def load_network(name, network_on_disk):
|
|||||||
|
|
||||||
net.modules[key] = net_module
|
net.modules[key] = net_module
|
||||||
|
|
||||||
|
embeddings = {}
|
||||||
|
for emb_name, data in bundle_embeddings.items():
|
||||||
|
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
|
||||||
|
embedding.loaded = None
|
||||||
|
embeddings[emb_name] = embedding
|
||||||
|
|
||||||
|
net.bundle_embeddings = embeddings
|
||||||
|
|
||||||
if keys_failed_to_match:
|
if keys_failed_to_match:
|
||||||
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||||
|
|
||||||
@ -210,11 +247,15 @@ def purge_networks_from_memory():
|
|||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
|
emb_db = sd_hijack.model_hijack.embedding_db
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
for net in loaded_networks:
|
for net in loaded_networks:
|
||||||
if net.name in names:
|
if net.name in names:
|
||||||
already_loaded[net.name] = net
|
already_loaded[net.name] = net
|
||||||
|
for emb_name, embedding in net.bundle_embeddings.items():
|
||||||
|
if embedding.loaded:
|
||||||
|
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
|
||||||
|
|
||||||
loaded_networks.clear()
|
loaded_networks.clear()
|
||||||
|
|
||||||
@ -257,6 +298,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
|
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
|
||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
|
for emb_name, embedding in net.bundle_embeddings.items():
|
||||||
|
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
|
||||||
|
logger.warning(
|
||||||
|
f'Skip bundle embedding: "{emb_name}"'
|
||||||
|
' as it was already loaded from embeddings folder'
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
embedding.loaded = False
|
||||||
|
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
|
||||||
|
embedding.loaded = True
|
||||||
|
emb_db.register_embedding(embedding, shared.sd_model)
|
||||||
|
else:
|
||||||
|
emb_db.skipped_embeddings[name] = embedding
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
@ -418,6 +474,7 @@ def network_forward(module, input, original_forward):
|
|||||||
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||||
self.network_current_names = ()
|
self.network_current_names = ()
|
||||||
self.network_weights_backup = None
|
self.network_weights_backup = None
|
||||||
|
self.network_bias_backup = None
|
||||||
|
|
||||||
|
|
||||||
def network_Linear_forward(self, input):
|
def network_Linear_forward(self, input):
|
||||||
@ -564,6 +621,7 @@ extra_network_lora = None
|
|||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
loaded_networks = []
|
loaded_networks = []
|
||||||
|
loaded_bundle_embeddings = {}
|
||||||
networks_in_memory = {}
|
networks_in_memory = {}
|
||||||
available_network_hash_lookup = {}
|
available_network_hash_lookup = {}
|
||||||
forbidden_network_aliases = {}
|
forbidden_network_aliases = {}
|
||||||
|
@ -17,6 +17,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
lora_on_disk = networks.available_networks.get(name)
|
lora_on_disk = networks.available_networks.get(name)
|
||||||
|
if lora_on_disk is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
|
|
||||||
@ -66,9 +68,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return item
|
return item
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(networks.available_networks):
|
# instantiate a list to protect against concurrent modification
|
||||||
|
names = list(networks.available_networks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
item = self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
|
||||||
if item is not None:
|
if item is not None:
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
self.setting_names = []
|
self.setting_names = []
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
|
||||||
|
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
|
||||||
|
|
||||||
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
|
||||||
|
|
||||||
with gr.Blocks() as interface:
|
with gr.Blocks() as interface:
|
||||||
with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group():
|
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
|
||||||
|
|
||||||
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)
|
||||||
|
|
||||||
@ -64,11 +65,14 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
p.override_settings[name] = value
|
p.override_settings[name] = value
|
||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('ui', "User interface"), {
|
shared.options_templates.update(shared.options_section(('settings_in_ui', "Settings in UI", "ui"), {
|
||||||
"extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
"settings_in_ui": shared.OptionHTML("""
|
||||||
"extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
This page allows you to add some settings to the main interface of txt2img and img2img tabs.
|
||||||
"extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(),
|
"""),
|
||||||
"extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui()
|
"extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(),
|
||||||
|
"extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(),
|
||||||
|
"extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(),
|
||||||
|
"extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
351
extensions-builtin/hypertile/hypertile.py
Normal file
351
extensions-builtin/hypertile/hypertile.py
Normal file
@ -0,0 +1,351 @@
|
|||||||
|
"""
|
||||||
|
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
|
||||||
|
Warn: The patch works well only if the input image has a width and height that are multiples of 128
|
||||||
|
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from functools import wraps, cache
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch.nn as nn
|
||||||
|
import random
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HypertileParams:
|
||||||
|
depth = 0
|
||||||
|
layer_name = ""
|
||||||
|
tile_size: int = 0
|
||||||
|
swap_size: int = 0
|
||||||
|
aspect_ratio: float = 1.0
|
||||||
|
forward = None
|
||||||
|
enabled = False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO add SD-XL layers
|
||||||
|
DEPTH_LAYERS = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.9.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.10.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.11.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
"decoder.mid.attn_1",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.6.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
3: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# XL layers, thanks for GitHub@gel-crabs for the help
|
||||||
|
DEPTH_LAYERS_XL = {
|
||||||
|
0: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 VAE
|
||||||
|
"decoder.mid_block.attentions.0",
|
||||||
|
"decoder.mid.attn_1",
|
||||||
|
],
|
||||||
|
1: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||||
|
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"input_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.3.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.4.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.5.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.0.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.0.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.1.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.1.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.2.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.2.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.3.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.3.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.4.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.4.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.5.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.5.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.6.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.6.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.7.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.7.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.8.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.8.attn1",
|
||||||
|
"input_blocks.7.1.transformer_blocks.9.attn1",
|
||||||
|
"input_blocks.8.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.0.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.1.1.transformer_blocks.9.attn1",
|
||||||
|
"output_blocks.2.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
2: [
|
||||||
|
# SD 1.5 U-Net (diffusers)
|
||||||
|
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||||
|
# SD 1.5 U-Net (ldm)
|
||||||
|
"middle_block.1.transformer_blocks.0.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.1.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.2.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.3.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.4.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.5.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.6.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.7.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.8.attn1",
|
||||||
|
"middle_block.1.transformer_blocks.9.attn1",
|
||||||
|
],
|
||||||
|
3 : [] # TODO - separate layers for SD-XL
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
RNG_INSTANCE = random.Random()
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:
|
||||||
|
"""
|
||||||
|
Returns divisors of value that
|
||||||
|
x * min_value <= value
|
||||||
|
in big -> small order, amount of divisors is limited by max_options
|
||||||
|
"""
|
||||||
|
max_options = max(1, max_options) # at least 1 option should be returned
|
||||||
|
min_value = min(min_value, value)
|
||||||
|
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
|
||||||
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
|
||||||
|
return ns
|
||||||
|
|
||||||
|
|
||||||
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
|
"""
|
||||||
|
Returns a random divisor of value that
|
||||||
|
x * min_value <= value
|
||||||
|
if max_options is 1, the behavior is deterministic
|
||||||
|
"""
|
||||||
|
ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors
|
||||||
|
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
|
||||||
|
|
||||||
|
return ns[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def set_hypertile_seed(seed: int) -> None:
|
||||||
|
RNG_INSTANCE.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def largest_tile_size_available(width: int, height: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the largest tile size available for a given width and height
|
||||||
|
Tile size is always a power of 2
|
||||||
|
"""
|
||||||
|
gcd = math.gcd(width, height)
|
||||||
|
largest_tile_size_available = 1
|
||||||
|
while gcd % (largest_tile_size_available * 2) == 0:
|
||||||
|
largest_tile_size_available *= 2
|
||||||
|
return largest_tile_size_available
|
||||||
|
|
||||||
|
|
||||||
|
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
We check all possible divisors of hw and return the closest to the aspect ratio
|
||||||
|
"""
|
||||||
|
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
|
||||||
|
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
|
||||||
|
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
|
||||||
|
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
|
||||||
|
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
|
||||||
|
return closest_pair
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
"""
|
||||||
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
# find h and w such that h*w = hw and h/w = aspect_ratio
|
||||||
|
if h * w != hw:
|
||||||
|
w_candidate = hw / h
|
||||||
|
# check if w is an integer
|
||||||
|
if not w_candidate.is_integer():
|
||||||
|
h_candidate = hw / w
|
||||||
|
# check if h is an integer
|
||||||
|
if not h_candidate.is_integer():
|
||||||
|
return iterative_closest_divisors(hw, aspect_ratio)
|
||||||
|
else:
|
||||||
|
h = int(h_candidate)
|
||||||
|
else:
|
||||||
|
w = int(w_candidate)
|
||||||
|
return h, w
|
||||||
|
|
||||||
|
|
||||||
|
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
|
||||||
|
|
||||||
|
@wraps(params.forward)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not params.enabled:
|
||||||
|
return params.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
latent_tile_size = max(128, params.tile_size) // 8
|
||||||
|
x = args[0]
|
||||||
|
|
||||||
|
# VAE
|
||||||
|
if x.ndim == 4:
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
|
||||||
|
nh = random_divisor(h, latent_tile_size, params.swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size, params.swap_size)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
|
||||||
|
|
||||||
|
out = params.forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
|
||||||
|
|
||||||
|
# U-Net
|
||||||
|
else:
|
||||||
|
hw: int = x.size(1)
|
||||||
|
h, w = find_hw_candidates(hw, params.aspect_ratio)
|
||||||
|
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
|
||||||
|
|
||||||
|
factor = 2 ** params.depth if scale_depth else 1
|
||||||
|
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
|
||||||
|
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
|
||||||
|
out = params.forward(x, *args[1:], **kwargs)
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||||
|
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
|
||||||
|
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
|
||||||
|
if hypertile_layers is None:
|
||||||
|
if not enable:
|
||||||
|
return
|
||||||
|
|
||||||
|
hypertile_layers = {}
|
||||||
|
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
|
||||||
|
|
||||||
|
for depth in range(4):
|
||||||
|
for layer_name, module in model.named_modules():
|
||||||
|
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
|
||||||
|
params = HypertileParams()
|
||||||
|
module.__webui_hypertile_params = params
|
||||||
|
params.forward = module.forward
|
||||||
|
params.depth = depth
|
||||||
|
params.layer_name = layer_name
|
||||||
|
module.forward = self_attn_forward(params)
|
||||||
|
|
||||||
|
hypertile_layers[layer_name] = 1
|
||||||
|
|
||||||
|
model.__webui_hypertile_layers = hypertile_layers
|
||||||
|
|
||||||
|
aspect_ratio = width / height
|
||||||
|
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
|
||||||
|
|
||||||
|
for layer_name, module in model.named_modules():
|
||||||
|
if layer_name in hypertile_layers:
|
||||||
|
params = module.__webui_hypertile_params
|
||||||
|
|
||||||
|
params.tile_size = tile_size
|
||||||
|
params.swap_size = swap_size
|
||||||
|
params.aspect_ratio = aspect_ratio
|
||||||
|
params.enabled = enable and params.depth <= max_depth
|
109
extensions-builtin/hypertile/scripts/hypertile_script.py
Normal file
109
extensions-builtin/hypertile/scripts/hypertile_script.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
import hypertile
|
||||||
|
from modules import scripts, script_callbacks, shared
|
||||||
|
from scripts.hypertile_xyz import add_axis_options
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptHypertile(scripts.Script):
|
||||||
|
name = "Hypertile"
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def process(self, p, *args):
|
||||||
|
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||||
|
|
||||||
|
configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)
|
||||||
|
|
||||||
|
self.add_infotext(p)
|
||||||
|
|
||||||
|
def before_hr(self, p, *args):
|
||||||
|
|
||||||
|
enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet
|
||||||
|
|
||||||
|
# exclusive hypertile seed for the second pass
|
||||||
|
if enable:
|
||||||
|
hypertile.set_hypertile_seed(p.all_seeds[0])
|
||||||
|
|
||||||
|
configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)
|
||||||
|
|
||||||
|
if enable and not shared.opts.hypertile_enable_unet:
|
||||||
|
p.extra_generation_params["Hypertile U-Net second pass"] = True
|
||||||
|
|
||||||
|
self.add_infotext(p, add_unet_params=True)
|
||||||
|
|
||||||
|
def add_infotext(self, p, add_unet_params=False):
|
||||||
|
def option(name):
|
||||||
|
value = getattr(shared.opts, name)
|
||||||
|
default_value = shared.opts.get_default(name)
|
||||||
|
return None if value == default_value else value
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_unet:
|
||||||
|
p.extra_generation_params["Hypertile U-Net"] = True
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_unet or add_unet_params:
|
||||||
|
p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet')
|
||||||
|
p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet')
|
||||||
|
p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet')
|
||||||
|
|
||||||
|
if shared.opts.hypertile_enable_vae:
|
||||||
|
p.extra_generation_params["Hypertile VAE"] = True
|
||||||
|
p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae')
|
||||||
|
p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae')
|
||||||
|
p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae')
|
||||||
|
|
||||||
|
|
||||||
|
def configure_hypertile(width, height, enable_unet=True):
|
||||||
|
hypertile.hypertile_hook_model(
|
||||||
|
shared.sd_model.first_stage_model,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
swap_size=shared.opts.hypertile_swap_size_vae,
|
||||||
|
max_depth=shared.opts.hypertile_max_depth_vae,
|
||||||
|
tile_size_max=shared.opts.hypertile_max_tile_vae,
|
||||||
|
enable=shared.opts.hypertile_enable_vae,
|
||||||
|
)
|
||||||
|
|
||||||
|
hypertile.hypertile_hook_model(
|
||||||
|
shared.sd_model.model,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
swap_size=shared.opts.hypertile_swap_size_unet,
|
||||||
|
max_depth=shared.opts.hypertile_max_depth_unet,
|
||||||
|
tile_size_max=shared.opts.hypertile_max_tile_unet,
|
||||||
|
enable=enable_unet,
|
||||||
|
is_sdxl=shared.sd_model.is_sdxl
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
options = {
|
||||||
|
"hypertile_explanation": shared.OptionHTML("""
|
||||||
|
<a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
|
||||||
|
resulting in a reduction in computation time ranging from 1 to 4 times. The larger the generated image is, the greater the
|
||||||
|
benefit.
|
||||||
|
"""),
|
||||||
|
|
||||||
|
"hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"),
|
||||||
|
"hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"),
|
||||||
|
"hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"),
|
||||||
|
"hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"),
|
||||||
|
"hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"),
|
||||||
|
|
||||||
|
"hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"),
|
||||||
|
"hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"),
|
||||||
|
"hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"),
|
||||||
|
"hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, opt in options.items():
|
||||||
|
opt.section = ('hypertile', "Hypertile")
|
||||||
|
shared.opts.add_option(name, opt)
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
script_callbacks.on_before_ui(add_axis_options)
|
51
extensions-builtin/hypertile/scripts/hypertile_xyz.py
Normal file
51
extensions-builtin/hypertile/scripts/hypertile_xyz.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from modules import scripts
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module
|
||||||
|
|
||||||
|
def int_applier(value_name:str, min_range:int = -1, max_range:int = -1):
|
||||||
|
"""
|
||||||
|
Returns a function that applies the given value to the given value_name in opts.data.
|
||||||
|
"""
|
||||||
|
def validate(value_name:str, value:str):
|
||||||
|
value = int(value)
|
||||||
|
# validate value
|
||||||
|
if not min_range == -1:
|
||||||
|
assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
|
||||||
|
if not max_range == -1:
|
||||||
|
assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
|
||||||
|
def apply_int(p, x, xs):
|
||||||
|
validate(value_name, x)
|
||||||
|
opts.data[value_name] = int(x)
|
||||||
|
return apply_int
|
||||||
|
|
||||||
|
def bool_applier(value_name:str):
|
||||||
|
"""
|
||||||
|
Returns a function that applies the given value to the given value_name in opts.data.
|
||||||
|
"""
|
||||||
|
def validate(value_name:str, value:str):
|
||||||
|
assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
|
||||||
|
def apply_bool(p, x, xs):
|
||||||
|
validate(value_name, x)
|
||||||
|
value_boolean = x.lower() == "true"
|
||||||
|
opts.data[value_name] = value_boolean
|
||||||
|
return apply_bool
|
||||||
|
|
||||||
|
def add_axis_options():
|
||||||
|
extra_axis_options = [
|
||||||
|
xyz_grid.AxisOption("[Hypertile] Unet First pass Enabled", str, bool_applier("hypertile_enable_unet"), choices=xyz_grid.boolean_choice(reverse=True)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] Unet Second pass Enabled", str, bool_applier("hypertile_enable_unet_secondpass"), choices=xyz_grid.boolean_choice(reverse=True)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] Unet Max Depth", int, int_applier("hypertile_max_depth_unet", 0, 3), choices=lambda: [str(x) for x in range(4)]),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] Unet Max Tile Size", int, int_applier("hypertile_max_tile_unet", 0, 512)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] Unet Swap Size", int, int_applier("hypertile_swap_size_unet", 0, 64)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] VAE Enabled", str, bool_applier("hypertile_enable_vae"), choices=xyz_grid.boolean_choice(reverse=True)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] VAE Max Depth", int, int_applier("hypertile_max_depth_vae", 0, 3), choices=lambda: [str(x) for x in range(4)]),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] VAE Max Tile Size", int, int_applier("hypertile_max_tile_vae", 0, 512)),
|
||||||
|
xyz_grid.AxisOption("[Hypertile] VAE Swap Size", int, int_applier("hypertile_swap_size_vae", 0, 64)),
|
||||||
|
]
|
||||||
|
set_a = {opt.label for opt in xyz_grid.axis_options}
|
||||||
|
set_b = {opt.label for opt in extra_axis_options}
|
||||||
|
if set_a.intersection(set_b):
|
||||||
|
return
|
||||||
|
|
||||||
|
xyz_grid.axis_options.extend(extra_axis_options)
|
@ -12,6 +12,8 @@ function isMobile() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function reportWindowSize() {
|
function reportWindowSize() {
|
||||||
|
if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
|
||||||
|
|
||||||
var currentlyMobile = isMobile();
|
var currentlyMobile = isMobile();
|
||||||
if (currentlyMobile == isSetupForMobile) return;
|
if (currentlyMobile == isSetupForMobile) return;
|
||||||
isSetupForMobile = currentlyMobile;
|
isSetupForMobile = currentlyMobile;
|
||||||
|
2
javascript/dragdrop.js
vendored
2
javascript/dragdrop.js
vendored
@ -119,7 +119,7 @@ window.addEventListener('paste', e => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const firstFreeImageField = visibleImageFields
|
const firstFreeImageField = visibleImageFields
|
||||||
.filter(el => el.querySelector('input[type=file]'))?.[0];
|
.filter(el => !el.querySelector('img'))?.[0];
|
||||||
|
|
||||||
dropReplaceImage(
|
dropReplaceImage(
|
||||||
firstFreeImageField ?
|
firstFreeImageField ?
|
||||||
|
@ -18,37 +18,43 @@ function keyupEditAttention(event) {
|
|||||||
const before = text.substring(0, selectionStart);
|
const before = text.substring(0, selectionStart);
|
||||||
let beforeParen = before.lastIndexOf(OPEN);
|
let beforeParen = before.lastIndexOf(OPEN);
|
||||||
if (beforeParen == -1) return false;
|
if (beforeParen == -1) return false;
|
||||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
|
||||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
let beforeClosingParen = before.lastIndexOf(CLOSE);
|
||||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;
|
||||||
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find closing parenthesis around current cursor
|
// Find closing parenthesis around current cursor
|
||||||
const after = text.substring(selectionStart);
|
const after = text.substring(selectionStart);
|
||||||
let afterParen = after.indexOf(CLOSE);
|
let afterParen = after.indexOf(CLOSE);
|
||||||
if (afterParen == -1) return false;
|
if (afterParen == -1) return false;
|
||||||
let afterParenOpen = after.indexOf(OPEN);
|
|
||||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
let afterOpeningParen = after.indexOf(OPEN);
|
||||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
|
||||||
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
|
||||||
}
|
|
||||||
if (beforeParen === -1 || afterParen === -1) return false;
|
|
||||||
|
|
||||||
// Set the selection to the text between the parenthesis
|
// Set the selection to the text between the parenthesis
|
||||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||||
const lastColon = parenContent.lastIndexOf(":");
|
if (/.*:-?[\d.]+/s.test(parenContent)) {
|
||||||
selectionStart = beforeParen + 1;
|
const lastColon = parenContent.lastIndexOf(":");
|
||||||
selectionEnd = selectionStart + lastColon;
|
selectionStart = beforeParen + 1;
|
||||||
|
selectionEnd = selectionStart + lastColon;
|
||||||
|
} else {
|
||||||
|
selectionStart = beforeParen + 1;
|
||||||
|
selectionEnd = selectionStart + parenContent.length;
|
||||||
|
}
|
||||||
|
|
||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
function selectCurrentWord() {
|
function selectCurrentWord() {
|
||||||
if (selectionStart !== selectionEnd) return false;
|
if (selectionStart !== selectionEnd) return false;
|
||||||
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"};
|
||||||
|
let delimiters = opts.keyedit_delimiters;
|
||||||
|
|
||||||
// seek backward until to find beggining
|
for (let i of opts.keyedit_delimiters_whitespace) {
|
||||||
|
delimiters += whitespace_delimiters[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// seek backward to find beginning
|
||||||
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
||||||
selectionStart--;
|
selectionStart--;
|
||||||
}
|
}
|
||||||
@ -63,7 +69,7 @@ function keyupEditAttention(event) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||||
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
|
||||||
selectCurrentWord();
|
selectCurrentWord();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,33 +77,54 @@ function keyupEditAttention(event) {
|
|||||||
|
|
||||||
var closeCharacter = ')';
|
var closeCharacter = ')';
|
||||||
var delta = opts.keyedit_precision_attention;
|
var delta = opts.keyedit_precision_attention;
|
||||||
|
var start = selectionStart > 0 ? text[selectionStart - 1] : "";
|
||||||
|
var end = text[selectionEnd];
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<') {
|
if (start == '<') {
|
||||||
closeCharacter = '>';
|
closeCharacter = '>';
|
||||||
delta = opts.keyedit_precision_extra;
|
delta = opts.keyedit_precision_extra;
|
||||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
} else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
|
||||||
|
let numParen = 0;
|
||||||
|
|
||||||
|
while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
|
||||||
|
numParen++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (start == "[") {
|
||||||
|
weight = (1 / 1.1) ** numParen;
|
||||||
|
} else {
|
||||||
|
weight = 1.1 ** numParen;
|
||||||
|
}
|
||||||
|
|
||||||
|
weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
|
||||||
|
|
||||||
|
text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
|
||||||
|
selectionStart -= numParen - 1;
|
||||||
|
selectionEnd -= numParen - 1;
|
||||||
|
} else if (start != '(') {
|
||||||
// do not include spaces at the end
|
// do not include spaces at the end
|
||||||
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
||||||
selectionEnd -= 1;
|
selectionEnd--;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (selectionStart == selectionEnd) {
|
if (selectionStart == selectionEnd) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||||
|
|
||||||
selectionStart += 1;
|
selectionStart++;
|
||||||
selectionEnd += 1;
|
selectionEnd++;
|
||||||
}
|
}
|
||||||
|
|
||||||
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
if (text[selectionEnd] != ':') return;
|
||||||
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
weight = parseFloat(weight.toPrecision(12));
|
weight = parseFloat(weight.toPrecision(12));
|
||||||
if (String(weight).length == 1) weight += ".0";
|
if (Number.isInteger(weight)) weight += ".0";
|
||||||
|
|
||||||
if (closeCharacter == ')' && weight == 1) {
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
||||||
@ -105,7 +132,7 @@ function keyupEditAttention(event) {
|
|||||||
selectionStart--;
|
selectionStart--;
|
||||||
selectionEnd--;
|
selectionEnd--;
|
||||||
} else {
|
} else {
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
|
@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||||
|
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
|
||||||
|
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
|
||||||
|
|
||||||
sort.dataset.sortkey = 'sortDefault';
|
|
||||||
tabs.appendChild(searchDiv);
|
tabs.appendChild(searchDiv);
|
||||||
tabs.appendChild(sort);
|
tabs.appendChild(sort);
|
||||||
tabs.appendChild(sortOrder);
|
tabs.appendChild(sortOrder);
|
||||||
@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
|
|
||||||
elem.style.display = visible ? "" : "none";
|
elem.style.display = visible ? "" : "none";
|
||||||
});
|
});
|
||||||
|
|
||||||
|
applySort();
|
||||||
};
|
};
|
||||||
|
|
||||||
var applySort = function() {
|
var applySort = function() {
|
||||||
|
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
||||||
|
|
||||||
var reverse = sortOrder.classList.contains("sortReverse");
|
var reverse = sortOrder.classList.contains("sortReverse");
|
||||||
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
|
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
||||||
sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
|
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
||||||
var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
|
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
|
||||||
if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
|
|
||||||
|
if (sortKeyStore == sort.dataset.sortkey) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.dataset.sortkey = sortKeyStore;
|
sort.dataset.sortkey = sortKeyStore;
|
||||||
|
|
||||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
|
||||||
cards.forEach(function(card) {
|
cards.forEach(function(card) {
|
||||||
card.originalParentElement = card.parentElement;
|
card.originalParentElement = card.parentElement;
|
||||||
});
|
});
|
||||||
@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
search.addEventListener("input", applyFilter);
|
search.addEventListener("input", applyFilter);
|
||||||
applyFilter();
|
|
||||||
["change", "blur", "click"].forEach(function(evt) {
|
|
||||||
sort.querySelector("input").addEventListener(evt, applySort);
|
|
||||||
});
|
|
||||||
sortOrder.addEventListener("click", function() {
|
sortOrder.addEventListener("click", function() {
|
||||||
sortOrder.classList.toggle("sortReverse");
|
sortOrder.classList.toggle("sortReverse");
|
||||||
applySort();
|
applySort();
|
||||||
});
|
});
|
||||||
|
applyFilter();
|
||||||
|
|
||||||
|
extraNetworksApplySort[tabname] = applySort;
|
||||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
|
|
||||||
var showDirsUpdate = function() {
|
var showDirsUpdate = function() {
|
||||||
@ -109,11 +111,51 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
showDirsUpdate();
|
showDirsUpdate();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
|
||||||
|
if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
|
||||||
|
|
||||||
|
var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
|
||||||
|
var prompt = gradioApp().getElementById(tabname + '_prompt_row');
|
||||||
|
var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
|
||||||
|
var elem = id ? gradioApp().getElementById(id) : null;
|
||||||
|
|
||||||
|
if (showNegativePrompt && elem) {
|
||||||
|
elem.insertBefore(negPrompt, elem.firstChild);
|
||||||
|
} else {
|
||||||
|
promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (showPrompt && elem) {
|
||||||
|
elem.insertBefore(prompt, elem.firstChild);
|
||||||
|
} else {
|
||||||
|
promptContainer.insertBefore(prompt, promptContainer.firstChild);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (elem) {
|
||||||
|
elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
|
||||||
|
extraNetworksMovePromptToTab(tabname, '', false, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
|
||||||
|
extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
function applyExtraNetworkFilter(tabname) {
|
function applyExtraNetworkFilter(tabname) {
|
||||||
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function applyExtraNetworkSort(tabname) {
|
||||||
|
setTimeout(extraNetworksApplySort[tabname], 1);
|
||||||
|
}
|
||||||
|
|
||||||
var extraNetworksApplyFilter = {};
|
var extraNetworksApplyFilter = {};
|
||||||
|
var extraNetworksApplySort = {};
|
||||||
var activePromptTextarea = {};
|
var activePromptTextarea = {};
|
||||||
|
|
||||||
function setupExtraNetworks() {
|
function setupExtraNetworks() {
|
||||||
@ -140,14 +182,15 @@ function setupExtraNetworks() {
|
|||||||
|
|
||||||
onUiLoaded(setupExtraNetworks);
|
onUiLoaded(setupExtraNetworks);
|
||||||
|
|
||||||
var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
|
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
|
||||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
|
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
|
||||||
|
|
||||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||||
var m = text.match(re_extranet);
|
var m = text.match(re_extranet);
|
||||||
var replaced = false;
|
var replaced = false;
|
||||||
var newTextareaText;
|
var newTextareaText;
|
||||||
if (m) {
|
if (m) {
|
||||||
|
var extraTextBeforeNet = opts.extra_networks_add_text_separator;
|
||||||
var extraTextAfterNet = m[2];
|
var extraTextAfterNet = m[2];
|
||||||
var partToSearch = m[1];
|
var partToSearch = m[1];
|
||||||
var foundAtPosition = -1;
|
var foundAtPosition = -1;
|
||||||
@ -161,8 +204,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
|||||||
return found;
|
return found;
|
||||||
});
|
});
|
||||||
|
|
||||||
if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
if (foundAtPosition >= 0) {
|
||||||
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
|
if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
||||||
|
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
|
||||||
|
}
|
||||||
|
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
|
||||||
|
newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
|
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
|
||||||
@ -216,27 +264,24 @@ function extraNetworksSearchButton(tabs_id, event) {
|
|||||||
|
|
||||||
var globalPopup = null;
|
var globalPopup = null;
|
||||||
var globalPopupInner = null;
|
var globalPopupInner = null;
|
||||||
|
|
||||||
function closePopup() {
|
function closePopup() {
|
||||||
if (!globalPopup) return;
|
if (!globalPopup) return;
|
||||||
|
|
||||||
globalPopup.style.display = "none";
|
globalPopup.style.display = "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
function popup(contents) {
|
function popup(contents) {
|
||||||
if (!globalPopup) {
|
if (!globalPopup) {
|
||||||
globalPopup = document.createElement('div');
|
globalPopup = document.createElement('div');
|
||||||
globalPopup.onclick = closePopup;
|
|
||||||
globalPopup.classList.add('global-popup');
|
globalPopup.classList.add('global-popup');
|
||||||
|
|
||||||
var close = document.createElement('div');
|
var close = document.createElement('div');
|
||||||
close.classList.add('global-popup-close');
|
close.classList.add('global-popup-close');
|
||||||
close.onclick = closePopup;
|
close.addEventListener("click", closePopup);
|
||||||
close.title = "Close";
|
close.title = "Close";
|
||||||
globalPopup.appendChild(close);
|
globalPopup.appendChild(close);
|
||||||
|
|
||||||
globalPopupInner = document.createElement('div');
|
globalPopupInner = document.createElement('div');
|
||||||
globalPopupInner.onclick = function(event) {
|
|
||||||
event.stopPropagation(); return false;
|
|
||||||
};
|
|
||||||
globalPopupInner.classList.add('global-popup-inner');
|
globalPopupInner.classList.add('global-popup-inner');
|
||||||
globalPopup.appendChild(globalPopupInner);
|
globalPopup.appendChild(globalPopupInner);
|
||||||
|
|
||||||
@ -335,7 +380,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
|
|||||||
function extraNetworksRefreshSingleCard(page, tabname, name) {
|
function extraNetworksRefreshSingleCard(page, tabname, name) {
|
||||||
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
|
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
|
||||||
if (data && data.html) {
|
if (data && data.html) {
|
||||||
var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
|
var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`);
|
||||||
|
|
||||||
var newDiv = document.createElement('DIV');
|
var newDiv = document.createElement('DIV');
|
||||||
newDiv.innerHTML = data.html;
|
newDiv.innerHTML = data.html;
|
||||||
@ -347,3 +392,9 @@ function extraNetworksRefreshSingleCard(page, tabname, name) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
window.addEventListener("keydown", function(event) {
|
||||||
|
if (event.key == "Escape") {
|
||||||
|
closePopup();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
|
|||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
if (modalImage && modalImage.offsetParent) {
|
if (modalImage && modalImage.offsetParent) {
|
||||||
let currentButton = selected_gallery_button();
|
let currentButton = selected_gallery_button();
|
||||||
|
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
||||||
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {
|
||||||
|
// show preview image if available
|
||||||
|
modalImage.src = preview[preview.length - 1].src;
|
||||||
|
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||||
modalImage.src = currentButton.children[0].src;
|
modalImage.src = currentButton.children[0].src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
|
@ -1,37 +1,68 @@
|
|||||||
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
|
||||||
mutations.forEach(function(mutationRecord) {
|
|
||||||
var elem = mutationRecord.target;
|
|
||||||
var open = elem.classList.contains('open');
|
|
||||||
|
|
||||||
var accordion = elem.parentNode;
|
|
||||||
accordion.classList.toggle('input-accordion-open', open);
|
|
||||||
|
|
||||||
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
|
||||||
checkbox.checked = open;
|
|
||||||
updateInput(checkbox);
|
|
||||||
|
|
||||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
|
||||||
if (extra) {
|
|
||||||
extra.style.display = open ? "" : "none";
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
function inputAccordionChecked(id, checked) {
|
function inputAccordionChecked(id, checked) {
|
||||||
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
var accordion = gradioApp().getElementById(id);
|
||||||
if (label.classList.contains('open') != checked) {
|
accordion.visibleCheckbox.checked = checked;
|
||||||
label.click();
|
accordion.onVisibleCheckboxChange();
|
||||||
|
}
|
||||||
|
|
||||||
|
function setupAccordion(accordion) {
|
||||||
|
var labelWrap = accordion.querySelector('.label-wrap');
|
||||||
|
var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||||
|
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||||
|
var span = labelWrap.querySelector('span');
|
||||||
|
var linked = true;
|
||||||
|
|
||||||
|
var isOpen = function() {
|
||||||
|
return labelWrap.classList.contains('open');
|
||||||
|
};
|
||||||
|
|
||||||
|
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||||
|
mutations.forEach(function(mutationRecord) {
|
||||||
|
accordion.classList.toggle('input-accordion-open', isOpen());
|
||||||
|
|
||||||
|
if (linked) {
|
||||||
|
accordion.visibleCheckbox.checked = isOpen();
|
||||||
|
accordion.onVisibleCheckboxChange();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||||
|
|
||||||
|
if (extra) {
|
||||||
|
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accordion.onChecked = function(checked) {
|
||||||
|
if (isOpen() != checked) {
|
||||||
|
labelWrap.click();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
var visibleCheckbox = document.createElement('INPUT');
|
||||||
|
visibleCheckbox.type = 'checkbox';
|
||||||
|
visibleCheckbox.checked = isOpen();
|
||||||
|
visibleCheckbox.id = accordion.id + "-visible-checkbox";
|
||||||
|
visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
|
||||||
|
span.insertBefore(visibleCheckbox, span.firstChild);
|
||||||
|
|
||||||
|
accordion.visibleCheckbox = visibleCheckbox;
|
||||||
|
accordion.onVisibleCheckboxChange = function() {
|
||||||
|
if (linked && isOpen() != visibleCheckbox.checked) {
|
||||||
|
labelWrap.click();
|
||||||
|
}
|
||||||
|
|
||||||
|
gradioCheckbox.checked = visibleCheckbox.checked;
|
||||||
|
updateInput(gradioCheckbox);
|
||||||
|
};
|
||||||
|
|
||||||
|
visibleCheckbox.addEventListener('click', function(event) {
|
||||||
|
linked = false;
|
||||||
|
event.stopPropagation();
|
||||||
|
});
|
||||||
|
visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiLoaded(function() {
|
onUiLoaded(function() {
|
||||||
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||||
var labelWrap = accordion.querySelector('.label-wrap');
|
setupAccordion(accordion);
|
||||||
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
|
||||||
|
|
||||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
|
||||||
if (extra) {
|
|
||||||
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
|
|||||||
lastHeadImg = headImg;
|
lastHeadImg = headImg;
|
||||||
|
|
||||||
// play notification sound if available
|
// play notification sound if available
|
||||||
gradioApp().querySelector('#audio_notification audio')?.play();
|
const notificationAudio = gradioApp().querySelector('#audio_notification audio');
|
||||||
|
if (notificationAudio) {
|
||||||
|
notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
|
||||||
|
notificationAudio.play();
|
||||||
|
}
|
||||||
|
|
||||||
if (document.hasFocus()) return;
|
if (document.hasFocus()) return;
|
||||||
|
|
||||||
|
71
javascript/settings.js
Normal file
71
javascript/settings.js
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
let settingsExcludeTabsFromShowAll = {
|
||||||
|
settings_tab_defaults: 1,
|
||||||
|
settings_tab_sysinfo: 1,
|
||||||
|
settings_tab_actions: 1,
|
||||||
|
settings_tab_licenses: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
function settingsShowAllTabs() {
|
||||||
|
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
|
||||||
|
if (settingsExcludeTabsFromShowAll[elem.id]) return;
|
||||||
|
|
||||||
|
elem.style.display = "block";
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function settingsShowOneTab() {
|
||||||
|
gradioApp().querySelector('#settings_show_one_page').click();
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiLoaded(function() {
|
||||||
|
var edit = gradioApp().querySelector('#settings_search');
|
||||||
|
var editTextarea = gradioApp().querySelector('#settings_search > label > input');
|
||||||
|
var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');
|
||||||
|
var settings_tabs = gradioApp().querySelector('#settings div');
|
||||||
|
|
||||||
|
onEdit('settingsSearch', editTextarea, 250, function() {
|
||||||
|
var searchText = (editTextarea.value || "").trim().toLowerCase();
|
||||||
|
|
||||||
|
gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {
|
||||||
|
var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;
|
||||||
|
elem.style.display = visible ? "" : "none";
|
||||||
|
});
|
||||||
|
|
||||||
|
if (searchText != "") {
|
||||||
|
settingsShowAllTabs();
|
||||||
|
} else {
|
||||||
|
settingsShowOneTab();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
settings_tabs.insertBefore(edit, settings_tabs.firstChild);
|
||||||
|
settings_tabs.appendChild(buttonShowAllPages);
|
||||||
|
|
||||||
|
|
||||||
|
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
onOptionsChanged(function() {
|
||||||
|
if (gradioApp().querySelector('#settings .settings-category')) return;
|
||||||
|
|
||||||
|
var sectionMap = {};
|
||||||
|
gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {
|
||||||
|
sectionMap[x.textContent.trim()] = x;
|
||||||
|
});
|
||||||
|
|
||||||
|
opts._categories.forEach(function(x) {
|
||||||
|
var section = x[0];
|
||||||
|
var category = x[1];
|
||||||
|
|
||||||
|
var span = document.createElement('SPAN');
|
||||||
|
span.textContent = category;
|
||||||
|
span.className = 'settings-category';
|
||||||
|
|
||||||
|
var sectionElem = sectionMap[section];
|
||||||
|
if (!sectionElem) return;
|
||||||
|
|
||||||
|
sectionElem.parentElement.insertBefore(span, sectionElem);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -1,10 +1,9 @@
|
|||||||
let promptTokenCountDebounceTime = 800;
|
let promptTokenCountUpdateFunctions = {};
|
||||||
let promptTokenCountTimeouts = {};
|
|
||||||
var promptTokenCountUpdateFunctions = {};
|
|
||||||
|
|
||||||
function update_txt2img_tokens(...args) {
|
function update_txt2img_tokens(...args) {
|
||||||
// Called from Gradio
|
// Called from Gradio
|
||||||
update_token_counter("txt2img_token_button");
|
update_token_counter("txt2img_token_button");
|
||||||
|
update_token_counter("txt2img_negative_token_button");
|
||||||
if (args.length == 2) {
|
if (args.length == 2) {
|
||||||
return args[0];
|
return args[0];
|
||||||
}
|
}
|
||||||
@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) {
|
|||||||
function update_img2img_tokens(...args) {
|
function update_img2img_tokens(...args) {
|
||||||
// Called from Gradio
|
// Called from Gradio
|
||||||
update_token_counter("img2img_token_button");
|
update_token_counter("img2img_token_button");
|
||||||
|
update_token_counter("img2img_negative_token_button");
|
||||||
if (args.length == 2) {
|
if (args.length == 2) {
|
||||||
return args[0];
|
return args[0];
|
||||||
}
|
}
|
||||||
@ -21,16 +21,7 @@ function update_img2img_tokens(...args) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function update_token_counter(button_id) {
|
||||||
if (opts.disable_token_counters) {
|
promptTokenCountUpdateFunctions[button_id]?.();
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (promptTokenCountTimeouts[button_id]) {
|
|
||||||
clearTimeout(promptTokenCountTimeouts[button_id]);
|
|
||||||
}
|
|
||||||
promptTokenCountTimeouts[button_id] = setTimeout(
|
|
||||||
() => gradioApp().getElementById(button_id)?.click(),
|
|
||||||
promptTokenCountDebounceTime,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) {
|
|||||||
prompt.parentElement.insertBefore(counter, prompt);
|
prompt.parentElement.insertBefore(counter, prompt);
|
||||||
prompt.parentElement.style.position = "relative";
|
prompt.parentElement.style.position = "relative";
|
||||||
|
|
||||||
promptTokenCountUpdateFunctions[id] = function() {
|
var func = onEdit(id, textarea, 800, function() {
|
||||||
update_token_counter(id_button);
|
gradioApp().getElementById(id_button)?.click();
|
||||||
};
|
});
|
||||||
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
|
promptTokenCountUpdateFunctions[id] = func;
|
||||||
|
promptTokenCountUpdateFunctions[id_button] = func;
|
||||||
}
|
}
|
||||||
|
|
||||||
function setupTokenCounters() {
|
function setupTokenCounters() {
|
||||||
|
@ -170,6 +170,23 @@ function submit_img2img() {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function submit_extras() {
|
||||||
|
showSubmitButtons('extras', false);
|
||||||
|
|
||||||
|
var id = randomId();
|
||||||
|
|
||||||
|
requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
|
||||||
|
showSubmitButtons('extras', true);
|
||||||
|
});
|
||||||
|
|
||||||
|
var res = create_submit_args(arguments);
|
||||||
|
|
||||||
|
res[0] = id;
|
||||||
|
|
||||||
|
console.log(res);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
function restoreProgressTxt2img() {
|
function restoreProgressTxt2img() {
|
||||||
showRestoreProgressButton("txt2img", false);
|
showRestoreProgressButton("txt2img", false);
|
||||||
var id = localGet("txt2img_task_id");
|
var id = localGet("txt2img_task_id");
|
||||||
@ -198,9 +215,33 @@ function restoreProgressImg2img() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configure the width and height elements on `tabname` to accept
|
||||||
|
* pasting of resolutions in the form of "width x height".
|
||||||
|
*/
|
||||||
|
function setupResolutionPasting(tabname) {
|
||||||
|
var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);
|
||||||
|
var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);
|
||||||
|
for (const el of [width, height]) {
|
||||||
|
el.addEventListener('paste', function(event) {
|
||||||
|
var pasteData = event.clipboardData.getData('text/plain');
|
||||||
|
var parsed = pasteData.match(/^\s*(\d+)\D+(\d+)\s*$/);
|
||||||
|
if (parsed) {
|
||||||
|
width.value = parsed[1];
|
||||||
|
height.value = parsed[2];
|
||||||
|
updateInput(width);
|
||||||
|
updateInput(height);
|
||||||
|
event.preventDefault();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onUiLoaded(function() {
|
onUiLoaded(function() {
|
||||||
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
showRestoreProgressButton('txt2img', localGet("txt2img_task_id"));
|
||||||
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
showRestoreProgressButton('img2img', localGet("img2img_task_id"));
|
||||||
|
setupResolutionPasting('txt2img');
|
||||||
|
setupResolutionPasting('img2img');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
@ -263,21 +304,6 @@ onAfterUiUpdate(function() {
|
|||||||
json_elem.parentElement.style.display = "none";
|
json_elem.parentElement.style.display = "none";
|
||||||
|
|
||||||
setupTokenCounters();
|
setupTokenCounters();
|
||||||
|
|
||||||
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
|
|
||||||
var settings_tabs = gradioApp().querySelector('#settings div');
|
|
||||||
if (show_all_pages && settings_tabs) {
|
|
||||||
settings_tabs.appendChild(show_all_pages);
|
|
||||||
show_all_pages.onclick = function() {
|
|
||||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
|
|
||||||
if (elem.id == "settings_tab_licenses") {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.style.display = "block";
|
|
||||||
});
|
|
||||||
};
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
onOptionsChanged(function() {
|
onOptionsChanged(function() {
|
||||||
@ -366,3 +392,20 @@ function switchWidthHeight(tabname) {
|
|||||||
updateInput(height);
|
updateInput(height);
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
var onEditTimers = {};
|
||||||
|
|
||||||
|
// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
|
||||||
|
function onEdit(editId, elem, afterMs, func) {
|
||||||
|
var edited = function() {
|
||||||
|
var existingTimer = onEditTimers[editId];
|
||||||
|
if (existingTimer) clearTimeout(existingTimer);
|
||||||
|
|
||||||
|
onEditTimers[editId] = setTimeout(func, afterMs);
|
||||||
|
};
|
||||||
|
|
||||||
|
elem.addEventListener("input", edited);
|
||||||
|
|
||||||
|
return edited;
|
||||||
|
}
|
||||||
|
@ -17,19 +17,17 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin,Image
|
from PIL import PngImagePlugin, Image
|
||||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import Dict, List, Any
|
from typing import Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
@ -103,7 +101,8 @@ def decode_base64_to_image(encoding):
|
|||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
|
if isinstance(image, str):
|
||||||
|
return image
|
||||||
if opts.samples_format.lower() == 'png':
|
if opts.samples_format.lower() == 'png':
|
||||||
use_metadata = False
|
use_metadata = False
|
||||||
metadata = PngImagePlugin.PngInfo()
|
metadata = PngImagePlugin.PngInfo()
|
||||||
@ -221,28 +220,28 @@ class Api:
|
|||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
||||||
|
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
||||||
|
|
||||||
if shared.cmd_opts.api_server_stop:
|
if shared.cmd_opts.api_server_stop:
|
||||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||||
@ -473,9 +472,6 @@ class Api:
|
|||||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
|
||||||
return models.PNGInfoResponse(info="")
|
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return models.PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
@ -484,9 +480,10 @@ class Api:
|
|||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
geninfo = ""
|
geninfo = ""
|
||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||||
|
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||||
|
|
||||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||||
|
|
||||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
@ -541,12 +538,12 @@ class Api:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
def unloadapi(self):
|
def unloadapi(self):
|
||||||
unload_model_weights()
|
sd_models.unload_model_weights()
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def reloadapi(self):
|
def reloadapi(self):
|
||||||
reload_model_weights()
|
sd_models.send_model_to_device(shared.sd_model)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -564,9 +561,9 @@ class Api:
|
|||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
||||||
def set_config(self, req: Dict[str, Any]):
|
def set_config(self, req: dict[str, Any]):
|
||||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||||
|
|
||||||
for k, v in req.items():
|
for k, v in req.items():
|
||||||
@ -676,19 +673,6 @@ class Api:
|
|||||||
finally:
|
finally:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
|
||||||
try:
|
|
||||||
shared.state.begin(job="preprocess")
|
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
|
||||||
shared.state.end()
|
|
||||||
return models.PreprocessResponse(info='preprocess complete')
|
|
||||||
except KeyError as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
|
||||||
finally:
|
|
||||||
shared.state.end()
|
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin(job="train_embedding")
|
shared.state.begin(job="train_embedding")
|
||||||
@ -770,6 +754,25 @@ class Api:
|
|||||||
cuda = {'error': f'{err}'}
|
cuda = {'error': f'{err}'}
|
||||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
|
def get_extensions_list(self):
|
||||||
|
from modules import extensions
|
||||||
|
extensions.list_extensions()
|
||||||
|
ext_list = []
|
||||||
|
for ext in extensions.extensions:
|
||||||
|
ext: extensions.Extension
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
if ext.remote is not None:
|
||||||
|
ext_list.append({
|
||||||
|
"name": ext.name,
|
||||||
|
"remote": ext.remote,
|
||||||
|
"branch": ext.branch,
|
||||||
|
"commit_hash":ext.commit_hash,
|
||||||
|
"commit_date":ext.commit_date,
|
||||||
|
"version":ext.version,
|
||||||
|
"enabled":ext.enabled
|
||||||
|
})
|
||||||
|
return ext_list
|
||||||
|
|
||||||
def launch(self, server_name, port, root_path):
|
def launch(self, server_name, port, root_path):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, create_model
|
from pydantic import BaseModel, Field, create_model
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, Literal
|
||||||
from typing_extensions import Literal
|
|
||||||
from inflection import underscore
|
from inflection import underscore
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||||
from modules.shared import sd_upscalers, opts, parser
|
from modules.shared import sd_upscalers, opts, parser
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
API_NOT_ALLOWED = [
|
API_NOT_ALLOWED = [
|
||||||
"self",
|
"self",
|
||||||
@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
|||||||
).generate_model()
|
).generate_model()
|
||||||
|
|
||||||
class TextToImageResponse(BaseModel):
|
class TextToImageResponse(BaseModel):
|
||||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||||
parameters: dict
|
parameters: dict
|
||||||
info: str
|
info: str
|
||||||
|
|
||||||
class ImageToImageResponse(BaseModel):
|
class ImageToImageResponse(BaseModel):
|
||||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||||
parameters: dict
|
parameters: dict
|
||||||
info: str
|
info: str
|
||||||
|
|
||||||
@ -168,17 +166,18 @@ class FileData(BaseModel):
|
|||||||
name: str = Field(title="File name")
|
name: str = Field(title="File name")
|
||||||
|
|
||||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||||
|
|
||||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||||
|
|
||||||
class PNGInfoRequest(BaseModel):
|
class PNGInfoRequest(BaseModel):
|
||||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||||
|
|
||||||
class PNGInfoResponse(BaseModel):
|
class PNGInfoResponse(BaseModel):
|
||||||
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
||||||
items: dict = Field(title="Items", description="An object containing all the info the image had")
|
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
|
||||||
|
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
|
||||||
|
|
||||||
class ProgressRequest(BaseModel):
|
class ProgressRequest(BaseModel):
|
||||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||||
@ -203,9 +202,6 @@ class TrainResponse(BaseModel):
|
|||||||
class CreateResponse(BaseModel):
|
class CreateResponse(BaseModel):
|
||||||
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
|
||||||
|
|
||||||
class PreprocessResponse(BaseModel):
|
|
||||||
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
|
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for key, metadata in opts.data_labels.items():
|
for key, metadata in opts.data_labels.items():
|
||||||
value = opts.data.get(key)
|
value = opts.data.get(key)
|
||||||
@ -232,8 +228,8 @@ FlagsModel = create_model("Flags", **flags)
|
|||||||
|
|
||||||
class SamplerItem(BaseModel):
|
class SamplerItem(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
aliases: List[str] = Field(title="Aliases")
|
aliases: list[str] = Field(title="Aliases")
|
||||||
options: Dict[str, str] = Field(title="Options")
|
options: dict[str, str] = Field(title="Options")
|
||||||
|
|
||||||
class UpscalerItem(BaseModel):
|
class UpscalerItem(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
@ -284,8 +280,8 @@ class EmbeddingItem(BaseModel):
|
|||||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||||
|
|
||||||
class EmbeddingsResponse(BaseModel):
|
class EmbeddingsResponse(BaseModel):
|
||||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||||
|
|
||||||
class MemoryResponse(BaseModel):
|
class MemoryResponse(BaseModel):
|
||||||
ram: dict = Field(title="RAM", description="System memory stats")
|
ram: dict = Field(title="RAM", description="System memory stats")
|
||||||
@ -303,11 +299,20 @@ class ScriptArg(BaseModel):
|
|||||||
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||||
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||||
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||||
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||||
|
|
||||||
|
|
||||||
class ScriptInfo(BaseModel):
|
class ScriptInfo(BaseModel):
|
||||||
name: str = Field(default=None, title="Name", description="Script name")
|
name: str = Field(default=None, title="Name", description="Script name")
|
||||||
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||||
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||||
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||||
|
|
||||||
|
class ExtensionItem(BaseModel):
|
||||||
|
name: str = Field(title="Name", description="Extension name")
|
||||||
|
remote: str = Field(title="Remote", description="Extension Repository URL")
|
||||||
|
branch: str = Field(title="Branch", description="Extension Repository Branch")
|
||||||
|
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
||||||
|
version: str = Field(title="Version", description="Extension Version")
|
||||||
|
commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||||
|
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||||
|
@ -32,7 +32,7 @@ def dump_cache():
|
|||||||
with cache_lock:
|
with cache_lock:
|
||||||
cache_filename_tmp = cache_filename + "-"
|
cache_filename_tmp = cache_filename + "-"
|
||||||
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
os.replace(cache_filename_tmp, cache_filename)
|
os.replace(cache_filename_tmp, cache_filename)
|
||||||
|
|
||||||
|
@ -70,6 +70,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
|
|||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
|
||||||
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
|
||||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||||
|
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
|
||||||
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||||
@ -90,7 +91,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
|||||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
||||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||||
@ -107,13 +108,14 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
|
|||||||
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
|
||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
|
||||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||||
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
|
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||||
|
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
|
||||||
|
@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@ -38,7 +37,7 @@ def list_config_states():
|
|||||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
for cs in config_states:
|
for cs in config_states:
|
||||||
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
name = cs.get("name", "Config")
|
name = cs.get("name", "Config")
|
||||||
full_name = f"{name}: {timestamp}"
|
full_name = f"{name}: {timestamp}"
|
||||||
all_config_states[full_name] = cs
|
all_config_states[full_name] = cs
|
||||||
|
@ -8,6 +8,13 @@ from modules import errors, shared
|
|||||||
if sys.platform == "darwin":
|
if sys.platform == "darwin":
|
||||||
from modules import mac_specific
|
from modules import mac_specific
|
||||||
|
|
||||||
|
if shared.cmd_opts.use_ipex:
|
||||||
|
from modules import xpu_specific
|
||||||
|
|
||||||
|
|
||||||
|
def has_xpu() -> bool:
|
||||||
|
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu
|
||||||
|
|
||||||
|
|
||||||
def has_mps() -> bool:
|
def has_mps() -> bool:
|
||||||
if sys.platform != "darwin":
|
if sys.platform != "darwin":
|
||||||
@ -30,6 +37,9 @@ def get_optimal_device_name():
|
|||||||
if has_mps():
|
if has_mps():
|
||||||
return "mps"
|
return "mps"
|
||||||
|
|
||||||
|
if has_xpu():
|
||||||
|
return xpu_specific.get_xpu_device_string()
|
||||||
|
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +48,7 @@ def get_optimal_device():
|
|||||||
|
|
||||||
|
|
||||||
def get_device_for(task):
|
def get_device_for(task):
|
||||||
if task in shared.cmd_opts.use_cpu:
|
if task in shared.cmd_opts.use_cpu or "all" in shared.cmd_opts.use_cpu:
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
return get_optimal_device()
|
return get_optimal_device()
|
||||||
@ -54,13 +64,17 @@ def torch_gc():
|
|||||||
if has_mps():
|
if has_mps():
|
||||||
mac_specific.torch_mps_gc()
|
mac_specific.torch_mps_gc()
|
||||||
|
|
||||||
|
if has_xpu():
|
||||||
|
xpu_specific.torch_xpu_gc()
|
||||||
|
|
||||||
|
|
||||||
def enable_tf32():
|
def enable_tf32():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
||||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||||
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
||||||
|
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -6,6 +6,21 @@ import traceback
|
|||||||
exception_records = []
|
exception_records = []
|
||||||
|
|
||||||
|
|
||||||
|
def format_traceback(tb):
|
||||||
|
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
||||||
|
|
||||||
|
|
||||||
|
def format_exception(e, tb):
|
||||||
|
return {"exception": str(e), "traceback": format_traceback(tb)}
|
||||||
|
|
||||||
|
|
||||||
|
def get_exceptions():
|
||||||
|
try:
|
||||||
|
return list(reversed(exception_records))
|
||||||
|
except Exception as e:
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
|
||||||
def record_exception():
|
def record_exception():
|
||||||
_, e, tb = sys.exc_info()
|
_, e, tb = sys.exc_info()
|
||||||
if e is None:
|
if e is None:
|
||||||
@ -14,8 +29,7 @@ def record_exception():
|
|||||||
if exception_records and exception_records[-1] == e:
|
if exception_records and exception_records[-1] == e:
|
||||||
return
|
return
|
||||||
|
|
||||||
from modules import sysinfo
|
exception_records.append(format_exception(e, tb))
|
||||||
exception_records.append(sysinfo.format_exception(e, tb))
|
|
||||||
|
|
||||||
if len(exception_records) > 5:
|
if len(exception_records) > 5:
|
||||||
exception_records.pop(0)
|
exception_records.pop(0)
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import configparser
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import re
|
||||||
|
|
||||||
from modules import shared, errors, cache, scripts
|
from modules import shared, errors, cache, scripts
|
||||||
from modules.gitpython_hack import Repo
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
|
||||||
|
|
||||||
os.makedirs(extensions_dir, exist_ok=True)
|
os.makedirs(extensions_dir, exist_ok=True)
|
||||||
|
|
||||||
@ -19,11 +22,55 @@ def active():
|
|||||||
return [x for x in extensions if x.enabled]
|
return [x for x in extensions if x.enabled]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionMetadata:
|
||||||
|
filename = "metadata.ini"
|
||||||
|
config: configparser.ConfigParser
|
||||||
|
canonical_name: str
|
||||||
|
requires: list
|
||||||
|
|
||||||
|
def __init__(self, path, canonical_name):
|
||||||
|
self.config = configparser.ConfigParser()
|
||||||
|
|
||||||
|
filepath = os.path.join(path, self.filename)
|
||||||
|
if os.path.isfile(filepath):
|
||||||
|
try:
|
||||||
|
self.config.read(filepath)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
|
||||||
|
|
||||||
|
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||||
|
self.canonical_name = canonical_name.lower().strip()
|
||||||
|
|
||||||
|
self.requires = self.get_script_requirements("Requires", "Extension")
|
||||||
|
|
||||||
|
def get_script_requirements(self, field, section, extra_section=None):
|
||||||
|
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
||||||
|
like Requires or Before, and section is the name of the [section] in the ini file; additionally,
|
||||||
|
reads more requirements from [extra_section] if specified."""
|
||||||
|
|
||||||
|
x = self.config.get(section, field, fallback='')
|
||||||
|
|
||||||
|
if extra_section:
|
||||||
|
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
||||||
|
|
||||||
|
return self.parse_list(x.lower())
|
||||||
|
|
||||||
|
def parse_list(self, text):
|
||||||
|
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# both "," and " " are accepted as separator
|
||||||
|
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
|
||||||
|
metadata: ExtensionMetadata
|
||||||
|
|
||||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.path = path
|
self.path = path
|
||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
@ -36,6 +83,8 @@ class Extension:
|
|||||||
self.branch = None
|
self.branch = None
|
||||||
self.remote = None
|
self.remote = None
|
||||||
self.have_info_from_repo = False
|
self.have_info_from_repo = False
|
||||||
|
self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
|
||||||
|
self.canonical_name = metadata.canonical_name
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {x: getattr(self, x) for x in self.cached_fields}
|
return {x: getattr(self, x) for x in self.cached_fields}
|
||||||
@ -56,6 +105,7 @@ class Extension:
|
|||||||
self.do_read_info_from_repo()
|
self.do_read_info_from_repo()
|
||||||
|
|
||||||
return self.to_dict()
|
return self.to_dict()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
|
||||||
self.from_dict(d)
|
self.from_dict(d)
|
||||||
@ -136,9 +186,6 @@ class Extension:
|
|||||||
def list_extensions():
|
def list_extensions():
|
||||||
extensions.clear()
|
extensions.clear()
|
||||||
|
|
||||||
if not os.path.isdir(extensions_dir):
|
|
||||||
return
|
|
||||||
|
|
||||||
if shared.cmd_opts.disable_all_extensions:
|
if shared.cmd_opts.disable_all_extensions:
|
||||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||||
elif shared.opts.disable_all_extensions == "all":
|
elif shared.opts.disable_all_extensions == "all":
|
||||||
@ -148,18 +195,43 @@ def list_extensions():
|
|||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.opts.disable_all_extensions == "extra":
|
||||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||||
|
|
||||||
extension_paths = []
|
loaded_extensions = {}
|
||||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
|
||||||
|
# scan through extensions directory and load metadata
|
||||||
|
for dirname in [extensions_builtin_dir, extensions_dir]:
|
||||||
if not os.path.isdir(dirname):
|
if not os.path.isdir(dirname):
|
||||||
return
|
continue
|
||||||
|
|
||||||
for extension_dirname in sorted(os.listdir(dirname)):
|
for extension_dirname in sorted(os.listdir(dirname)):
|
||||||
path = os.path.join(dirname, extension_dirname)
|
path = os.path.join(dirname, extension_dirname)
|
||||||
if not os.path.isdir(path):
|
if not os.path.isdir(path):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
canonical_name = extension_dirname
|
||||||
|
metadata = ExtensionMetadata(path, canonical_name)
|
||||||
|
|
||||||
for dirname, path, is_builtin in extension_paths:
|
# check for duplicated canonical names
|
||||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
|
||||||
extensions.append(extension)
|
if already_loaded_extension is not None:
|
||||||
|
errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_builtin = dirname == extensions_builtin_dir
|
||||||
|
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
||||||
|
extensions.append(extension)
|
||||||
|
loaded_extensions[canonical_name] = extension
|
||||||
|
|
||||||
|
# check for requirements
|
||||||
|
for extension in extensions:
|
||||||
|
for req in extension.metadata.requires:
|
||||||
|
required_extension = loaded_extensions.get(req)
|
||||||
|
if required_extension is None:
|
||||||
|
errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not extension.enabled:
|
||||||
|
errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
extensions: list[Extension] = []
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
@ -9,15 +10,12 @@ from modules.paths import data_path
|
|||||||
from modules import shared, ui_tempdir, script_callbacks, processing
|
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||||
re_param = re.compile(re_param_code)
|
re_param = re.compile(re_param_code)
|
||||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||||
type_of_gr_update = type(gr.update())
|
type_of_gr_update = type(gr.update())
|
||||||
|
|
||||||
paste_fields = {}
|
|
||||||
registered_param_bindings = []
|
|
||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||||
@ -30,6 +28,10 @@ class ParamBinding:
|
|||||||
self.paste_field_names = paste_field_names or []
|
self.paste_field_names = paste_field_names or []
|
||||||
|
|
||||||
|
|
||||||
|
paste_fields: dict[str, dict] = {}
|
||||||
|
registered_param_bindings: list[ParamBinding] = []
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
paste_fields.clear()
|
paste_fields.clear()
|
||||||
registered_param_bindings.clear()
|
registered_param_bindings.clear()
|
||||||
@ -113,7 +115,6 @@ def register_paste_params_button(binding: ParamBinding):
|
|||||||
|
|
||||||
|
|
||||||
def connect_paste_params_buttons():
|
def connect_paste_params_buttons():
|
||||||
binding: ParamBinding
|
|
||||||
for binding in registered_param_bindings:
|
for binding in registered_param_bindings:
|
||||||
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
destination_image_component = paste_fields[binding.tabname]["init_img"]
|
||||||
fields = paste_fields[binding.tabname]["fields"]
|
fields = paste_fields[binding.tabname]["fields"]
|
||||||
@ -313,6 +314,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "VAE Decoder" not in res:
|
if "VAE Decoder" not in res:
|
||||||
res["VAE Decoder"] = "Full"
|
res["VAE Decoder"] = "Full"
|
||||||
|
|
||||||
|
skip = set(shared.opts.infotext_skip_pasting)
|
||||||
|
res = {k: v for k, v in res.items() if k not in skip}
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -443,3 +447,4 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||||||
outputs=[],
|
outputs=[],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
|
|||||||
model_dir = "GFPGAN"
|
model_dir = "GFPGAN"
|
||||||
user_path = None
|
user_path = None
|
||||||
model_path = os.path.join(paths.models_path, model_dir)
|
model_path = os.path.join(paths.models_path, model_dir)
|
||||||
|
model_file_path = None
|
||||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||||
have_gfpgan = False
|
have_gfpgan = False
|
||||||
loaded_gfpgan_model = None
|
loaded_gfpgan_model = None
|
||||||
@ -17,6 +18,7 @@ loaded_gfpgan_model = None
|
|||||||
def gfpgann():
|
def gfpgann():
|
||||||
global loaded_gfpgan_model
|
global loaded_gfpgan_model
|
||||||
global model_path
|
global model_path
|
||||||
|
global model_file_path
|
||||||
if loaded_gfpgan_model is not None:
|
if loaded_gfpgan_model is not None:
|
||||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||||
return loaded_gfpgan_model
|
return loaded_gfpgan_model
|
||||||
@ -24,17 +26,24 @@ def gfpgann():
|
|||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||||
|
|
||||||
if len(models) == 1 and models[0].startswith("http"):
|
if len(models) == 1 and models[0].startswith("http"):
|
||||||
model_file = models[0]
|
model_file = models[0]
|
||||||
elif len(models) != 0:
|
elif len(models) != 0:
|
||||||
latest_file = max(models, key=os.path.getctime)
|
gfp_models = []
|
||||||
|
for item in models:
|
||||||
|
if 'GFPGAN' in os.path.basename(item):
|
||||||
|
gfp_models.append(item)
|
||||||
|
latest_file = max(gfp_models, key=os.path.getctime)
|
||||||
model_file = latest_file
|
model_file = latest_file
|
||||||
else:
|
else:
|
||||||
print("Unable to load gfpgan model!")
|
print("Unable to load gfpgan model!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||||
|
model_file_path = model_file
|
||||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
@ -77,19 +86,25 @@ def setup_model(dirname):
|
|||||||
global user_path
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
global model_file_path
|
||||||
|
|
||||||
|
facexlib_path = model_path
|
||||||
|
|
||||||
|
if dirname is not None:
|
||||||
|
facexlib_path = dirname
|
||||||
|
|
||||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||||
|
|
||||||
def my_load_file_from_url(**kwargs):
|
def my_load_file_from_url(**kwargs):
|
||||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||||
|
|
||||||
def facex_load_file_from_url(**kwargs):
|
def facex_load_file_from_url(**kwargs):
|
||||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||||
|
|
||||||
def facex_load_file_from_url2(**kwargs):
|
def facex_load_file_from_url2(**kwargs):
|
||||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||||
|
|
||||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||||
|
@ -23,7 +23,7 @@ class Git(git.Git):
|
|||||||
)
|
)
|
||||||
return self._parse_object_header(ret)
|
return self._parse_object_header(ret)
|
||||||
|
|
||||||
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
|
def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
|
||||||
# Not really streaming, per se; this buffers the entire object in memory.
|
# Not really streaming, per se; this buffers the entire object in memory.
|
||||||
# Shouldn't be a problem for our use case, since we're only using this for
|
# Shouldn't be a problem for our use case, since we're only using this for
|
||||||
# object headers (commit objects).
|
# object headers (commit objects).
|
||||||
|
@ -47,10 +47,20 @@ def Block_get_config(self):
|
|||||||
|
|
||||||
|
|
||||||
def BlockContext_init(self, *args, **kwargs):
|
def BlockContext_init(self, *args, **kwargs):
|
||||||
|
if scripts.scripts_current is not None:
|
||||||
|
scripts.scripts_current.before_component(self, **kwargs)
|
||||||
|
|
||||||
|
scripts.script_callbacks.before_component_callback(self, **kwargs)
|
||||||
|
|
||||||
res = original_BlockContext_init(self, *args, **kwargs)
|
res = original_BlockContext_init(self, *args, **kwargs)
|
||||||
|
|
||||||
add_classes_to_gradio_component(self)
|
add_classes_to_gradio_component(self)
|
||||||
|
|
||||||
|
scripts.script_callbacks.after_component_callback(self, **kwargs)
|
||||||
|
|
||||||
|
if scripts.scripts_current is not None:
|
||||||
|
scripts.scripts_current.after_component(self, **kwargs)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
|||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
|
||||||
from modules import images, processing
|
from modules import images, processing
|
||||||
|
|
||||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||||
@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
p.prompt = preview_prompt
|
p.prompt = preview_prompt
|
||||||
p.negative_prompt = preview_negative_prompt
|
p.negative_prompt = preview_negative_prompt
|
||||||
p.steps = preview_steps
|
p.steps = preview_steps
|
||||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
||||||
p.cfg_scale = preview_cfg_scale
|
p.cfg_scale = preview_cfg_scale
|
||||||
p.seed = preview_seed
|
p.seed = preview_seed
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
|
@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
|
|||||||
})
|
})
|
||||||
|
|
||||||
piexif.insert(exif_bytes, filename)
|
piexif.insert(exif_bytes, filename)
|
||||||
|
elif extension.lower() == ".gif":
|
||||||
|
image.save(filename, format=image_format, comment=geninfo)
|
||||||
else:
|
else:
|
||||||
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
||||||
|
|
||||||
@ -661,7 +663,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
|
|
||||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||||
|
|
||||||
os.replace(temp_file_path, filename_without_extension + extension)
|
filename = filename_without_extension + extension
|
||||||
|
if shared.opts.save_images_replace_action != "Replace":
|
||||||
|
n = 0
|
||||||
|
while os.path.exists(filename):
|
||||||
|
n += 1
|
||||||
|
filename = f"{filename_without_extension}-{n}{extension}"
|
||||||
|
os.replace(temp_file_path, filename)
|
||||||
|
|
||||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||||
if hasattr(os, 'statvfs'):
|
if hasattr(os, 'statvfs'):
|
||||||
@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
|||||||
geninfo = items.pop('parameters', None)
|
geninfo = items.pop('parameters', None)
|
||||||
|
|
||||||
if "exif" in items:
|
if "exif" in items:
|
||||||
exif = piexif.load(items["exif"])
|
exif_data = items["exif"]
|
||||||
|
try:
|
||||||
|
exif = piexif.load(exif_data)
|
||||||
|
except OSError:
|
||||||
|
# memory / exif was not valid so piexif tried to read from a file
|
||||||
|
exif = None
|
||||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||||
try:
|
try:
|
||||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||||
@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
|||||||
if exif_comment:
|
if exif_comment:
|
||||||
items['exif comment'] = exif_comment
|
items['exif comment'] = exif_comment
|
||||||
geninfo = exif_comment
|
geninfo = exif_comment
|
||||||
|
elif "comment" in items: # for gif
|
||||||
|
geninfo = items["comment"].decode('utf8', errors="ignore")
|
||||||
|
|
||||||
for field in IGNORED_INFO_KEYS:
|
for field in IGNORED_INFO_KEYS:
|
||||||
items.pop(field, None)
|
items.pop(field, None)
|
||||||
|
@ -10,6 +10,7 @@ from modules import images as imgutil
|
|||||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
|
from modules.sd_models import get_closet_checkpoint_match
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
@ -41,7 +42,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
cfg_scale = p.cfg_scale
|
cfg_scale = p.cfg_scale
|
||||||
sampler_name = p.sampler_name
|
sampler_name = p.sampler_name
|
||||||
steps = p.steps
|
steps = p.steps
|
||||||
|
override_settings = p.override_settings
|
||||||
|
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||||
|
batch_results = None
|
||||||
|
discard_further_results = False
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(images):
|
||||||
state.job = f"{i+1} out of {len(images)}"
|
state.job = f"{i+1} out of {len(images)}"
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
@ -104,16 +108,42 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||||
p.steps = int(parsed_parameters.get("Steps", steps))
|
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||||
|
|
||||||
|
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
|
||||||
|
if model_info is not None:
|
||||||
|
p.override_settings['sd_model_checkpoint'] = model_info.name
|
||||||
|
elif sd_model_checkpoint_override:
|
||||||
|
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
|
||||||
|
else:
|
||||||
|
p.override_settings.pop("sd_model_checkpoint", None)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
p.outpath_samples = output_dir
|
||||||
|
p.override_settings['save_to_dirs'] = False
|
||||||
|
p.override_settings['save_images_replace_action'] = "Add number suffix"
|
||||||
|
if p.n_iter > 1 or p.batch_size > 1:
|
||||||
|
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||||
|
else:
|
||||||
|
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||||
|
|
||||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
|
|
||||||
if proc is None:
|
if proc is None:
|
||||||
if output_dir:
|
p.override_settings.pop('save_images_replace_action', None)
|
||||||
p.outpath_samples = output_dir
|
proc = process_images(p)
|
||||||
p.override_settings['save_to_dirs'] = False
|
|
||||||
if p.n_iter > 1 or p.batch_size > 1:
|
if not discard_further_results and proc:
|
||||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
if batch_results:
|
||||||
else:
|
batch_results.images.extend(proc.images)
|
||||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
batch_results.infotexts.extend(proc.infotexts)
|
||||||
process_images(p)
|
else:
|
||||||
|
batch_results = proc
|
||||||
|
|
||||||
|
if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
|
||||||
|
discard_further_results = True
|
||||||
|
batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||||
|
batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
|
||||||
|
|
||||||
|
return batch_results
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
@ -189,7 +219,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
|
|
||||||
p.user = request.username
|
p.user = request.username
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.opts.enable_console_prompts:
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
if mask:
|
if mask:
|
||||||
@ -198,10 +228,10 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
with closing(p):
|
with closing(p):
|
||||||
if is_batch:
|
if is_batch:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
|
||||||
|
processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
||||||
|
|
||||||
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
|
if processed is None:
|
||||||
|
processed = Processed(p, [], p.seed, "")
|
||||||
processed = Processed(p, [], p.seed, "")
|
|
||||||
else:
|
else:
|
||||||
processed = modules.scripts.scripts_img2img.run(p, *args)
|
processed = modules.scripts.scripts_img2img.run(p, *args)
|
||||||
if processed is None:
|
if processed is None:
|
||||||
|
@ -3,3 +3,14 @@ import sys
|
|||||||
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
|
||||||
if "--xformers" not in "".join(sys.argv):
|
if "--xformers" not in "".join(sys.argv):
|
||||||
sys.modules["xformers"] = None
|
sys.modules["xformers"] = None
|
||||||
|
|
||||||
|
# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks
|
||||||
|
# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985
|
||||||
|
try:
|
||||||
|
import torchvision.transforms.functional_tensor # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import torchvision.transforms.functional as functional
|
||||||
|
sys.modules["torchvision.transforms.functional_tensor"] = functional
|
||||||
|
except ImportError:
|
||||||
|
pass # shrug...
|
||||||
|
@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
|
|
||||||
from modules import devices
|
from modules import devices
|
||||||
devices.first_time_calculation()
|
devices.first_time_calculation()
|
||||||
|
if not shared.cmd_opts.skip_load_model_at_start:
|
||||||
Thread(target=load_model).start()
|
Thread(target=load_model).start()
|
||||||
|
|
||||||
from modules import shared_items
|
from modules import shared_items
|
||||||
shared_items.reload_hypernetworks()
|
shared_items.reload_hypernetworks()
|
||||||
|
@ -150,10 +150,14 @@ def dumpstacks():
|
|||||||
|
|
||||||
def configure_sigint_handler():
|
def configure_sigint_handler():
|
||||||
# make the program just exit at ctrl+c without waiting for anything
|
# make the program just exit at ctrl+c without waiting for anything
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
def sigint_handler(sig, frame):
|
def sigint_handler(sig, frame):
|
||||||
print(f'Interrupted with signal {sig} in {frame}')
|
print(f'Interrupted with signal {sig} in {frame}')
|
||||||
|
|
||||||
dumpstacks()
|
if shared.opts.dump_stacks_on_signal:
|
||||||
|
dumpstacks()
|
||||||
|
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import importlib.metadata
|
||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -64,7 +65,7 @@ Use --skip-python-version-check to suppress this warning.
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def commit_hash():
|
def commit_hash():
|
||||||
try:
|
try:
|
||||||
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
return "<none>"
|
return "<none>"
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ def commit_hash():
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def git_tag():
|
def git_tag():
|
||||||
try:
|
try:
|
||||||
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
@ -119,11 +120,16 @@ def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_
|
|||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
try:
|
try:
|
||||||
spec = importlib.util.find_spec(package)
|
dist = importlib.metadata.distribution(package)
|
||||||
except ModuleNotFoundError:
|
except importlib.metadata.PackageNotFoundError:
|
||||||
return False
|
try:
|
||||||
|
spec = importlib.util.find_spec(package)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
return False
|
||||||
|
|
||||||
return spec is not None
|
return spec is not None
|
||||||
|
|
||||||
|
return dist is not None
|
||||||
|
|
||||||
|
|
||||||
def repo_dir(name):
|
def repo_dir(name):
|
||||||
@ -310,6 +316,26 @@ def requirements_met(requirements_file):
|
|||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||||
|
if args.use_ipex:
|
||||||
|
if platform.system() == "Windows":
|
||||||
|
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
|
||||||
|
# This is NOT an Intel official release so please use it at your own risk!!
|
||||||
|
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
|
||||||
|
#
|
||||||
|
# Strengths (over official IPEX 2.0.110 windows release):
|
||||||
|
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
|
||||||
|
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
|
||||||
|
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
|
||||||
|
# Limitation:
|
||||||
|
# - Only works for python 3.10
|
||||||
|
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
|
||||||
|
else:
|
||||||
|
# Using official IPEX release for linux since it's already an AOT build.
|
||||||
|
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
|
||||||
|
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
|
||||||
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
|
||||||
@ -352,6 +378,8 @@ def prepare_environment():
|
|||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
startup_timer.record("install torch")
|
startup_timer.record("install torch")
|
||||||
|
|
||||||
|
if args.use_ipex:
|
||||||
|
args.skip_torch_cuda_test = True
|
||||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Torch is not able to use GPU; '
|
'Torch is not able to use GPU; '
|
||||||
@ -441,7 +469,7 @@ def dump_sysinfo():
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
text = sysinfo.get()
|
text = sysinfo.get()
|
||||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
file.write(text)
|
file.write(text)
|
||||||
|
@ -14,21 +14,24 @@ def list_localizations(dirname):
|
|||||||
if ext.lower() != ".json":
|
if ext.lower() != ".json":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
localizations[fn] = os.path.join(dirname, file)
|
localizations[fn] = [os.path.join(dirname, file)]
|
||||||
|
|
||||||
for file in scripts.list_scripts("localizations", ".json"):
|
for file in scripts.list_scripts("localizations", ".json"):
|
||||||
fn, ext = os.path.splitext(file.filename)
|
fn, ext = os.path.splitext(file.filename)
|
||||||
localizations[fn] = file.path
|
if fn not in localizations:
|
||||||
|
localizations[fn] = []
|
||||||
|
localizations[fn].append(file.path)
|
||||||
|
|
||||||
|
|
||||||
def localization_js(current_localization_name: str) -> str:
|
def localization_js(current_localization_name: str) -> str:
|
||||||
fn = localizations.get(current_localization_name, None)
|
fns = localizations.get(current_localization_name, None)
|
||||||
data = {}
|
data = {}
|
||||||
if fn is not None:
|
if fns is not None:
|
||||||
try:
|
for fn in fns:
|
||||||
with open(fn, "r", encoding="utf8") as file:
|
try:
|
||||||
data = json.load(file)
|
with open(fn, "r", encoding="utf8") as file:
|
||||||
except Exception:
|
data.update(json.load(file))
|
||||||
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
except Exception:
|
||||||
|
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
||||||
|
|
||||||
return f"window.localization = {json.dumps(data)}"
|
return f"window.localization = {json.dumps(data)}"
|
||||||
|
@ -1,16 +1,41 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
class TqdmLoggingHandler(logging.Handler):
|
||||||
|
def __init__(self, level=logging.INFO):
|
||||||
|
super().__init__(level)
|
||||||
|
|
||||||
|
def emit(self, record):
|
||||||
|
try:
|
||||||
|
msg = self.format(record)
|
||||||
|
tqdm.write(msg)
|
||||||
|
self.flush()
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
TQDM_IMPORTED = True
|
||||||
|
except ImportError:
|
||||||
|
# tqdm does not exist before first launch
|
||||||
|
# I will import once the UI finishes seting up the enviroment and reloads.
|
||||||
|
TQDM_IMPORTED = False
|
||||||
|
|
||||||
def setup_logging(loglevel):
|
def setup_logging(loglevel):
|
||||||
if loglevel is None:
|
if loglevel is None:
|
||||||
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
|
||||||
|
|
||||||
|
loghandlers = []
|
||||||
|
|
||||||
|
if TQDM_IMPORTED:
|
||||||
|
loghandlers.append(TqdmLoggingHandler())
|
||||||
|
|
||||||
if loglevel:
|
if loglevel:
|
||||||
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=log_level,
|
level=log_level,
|
||||||
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S',
|
datefmt='%Y-%m-%d %H:%M:%S',
|
||||||
|
handlers=loghandlers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
import platform
|
import platform
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
return cumsum_func(input, *args, **kwargs)
|
return cumsum_func(input, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||||
|
def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
|
||||||
|
try:
|
||||||
|
return orig_func(*args, **kwargs)
|
||||||
|
except RuntimeError as e:
|
||||||
|
if "not implemented for" in str(e) and "Half" in str(e):
|
||||||
|
input_tensor = args[0]
|
||||||
|
return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
|
||||||
|
else:
|
||||||
|
print(f"An unexpected RuntimeError occurred: {str(e)}")
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
if platform.mac_ver()[0].startswith("13.2."):
|
if platform.mac_ver()[0].startswith("13.2."):
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||||
@ -77,6 +89,9 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
|
||||||
|
CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
|
||||||
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||||
if platform.processor() == 'i386':
|
if platform.processor() == 'i386':
|
||||||
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||||
|
@ -24,10 +24,15 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
|
|||||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||||
from ldm.modules.ema import LitEma
|
from ldm.modules.ema import LitEma
|
||||||
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
||||||
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
|
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
|
||||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ldm.models.autoencoder import VQModelInterface
|
||||||
|
except Exception:
|
||||||
|
class VQModelInterface:
|
||||||
|
pass
|
||||||
|
|
||||||
__conditioning_keys__ = {'concat': 'c_concat',
|
__conditioning_keys__ = {'concat': 'c_concat',
|
||||||
'crossattn': 'c_crossattn',
|
'crossattn': 'c_crossattn',
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@ -8,13 +9,14 @@ from modules.shared_cmd_options import cmd_opts
|
|||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
self.component_args = component_args
|
self.component_args = component_args
|
||||||
self.onchange = onchange
|
self.onchange = onchange
|
||||||
self.section = section
|
self.section = section
|
||||||
|
self.category_id = category_id
|
||||||
self.refresh = refresh
|
self.refresh = refresh
|
||||||
self.do_not_save = False
|
self.do_not_save = False
|
||||||
|
|
||||||
@ -63,7 +65,11 @@ class OptionHTML(OptionInfo):
|
|||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
def options_section(section_identifier, options_dict):
|
||||||
for v in options_dict.values():
|
for v in options_dict.values():
|
||||||
v.section = section_identifier
|
if len(section_identifier) == 2:
|
||||||
|
v.section = section_identifier
|
||||||
|
elif len(section_identifier) == 3:
|
||||||
|
v.section = section_identifier[0:2]
|
||||||
|
v.category_id = section_identifier[2]
|
||||||
|
|
||||||
return options_dict
|
return options_dict
|
||||||
|
|
||||||
@ -76,7 +82,7 @@ class Options:
|
|||||||
|
|
||||||
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):
|
||||||
self.data_labels = data_labels
|
self.data_labels = data_labels
|
||||||
self.data = {k: v.default for k, v in self.data_labels.items()}
|
self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}
|
||||||
self.restricted_opts = restricted_opts
|
self.restricted_opts = restricted_opts
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
@ -158,7 +164,7 @@ class Options:
|
|||||||
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
assert not cmd_opts.freeze_settings, "saving settings is disabled"
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf8") as file:
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
json.dump(self.data, file, indent=4)
|
json.dump(self.data, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def same_type(self, x, y):
|
def same_type(self, x, y):
|
||||||
if x is None or y is None:
|
if x is None or y is None:
|
||||||
@ -206,21 +212,59 @@ class Options:
|
|||||||
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
|
|
||||||
|
item_categories = {}
|
||||||
|
for item in self.data_labels.values():
|
||||||
|
category = categories.mapping.get(item.category_id)
|
||||||
|
category = "Uncategorized" if category is None else category.label
|
||||||
|
if category not in item_categories:
|
||||||
|
item_categories[category] = item.section[1]
|
||||||
|
|
||||||
|
# _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.
|
||||||
|
d["_categories"] = [[v, k] for k, v in item_categories.items()] + [["Defaults", "Other"]]
|
||||||
|
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
|
||||||
def add_option(self, key, info):
|
def add_option(self, key, info):
|
||||||
self.data_labels[key] = info
|
self.data_labels[key] = info
|
||||||
|
if key not in self.data and not info.do_not_save:
|
||||||
|
self.data[key] = info.default
|
||||||
|
|
||||||
def reorder(self):
|
def reorder(self):
|
||||||
"""reorder settings so that all items related to section always go together"""
|
"""Reorder settings so that:
|
||||||
|
- all items related to section always go together
|
||||||
|
- all sections belonging to a category go together
|
||||||
|
- sections inside a category are ordered alphabetically
|
||||||
|
- categories are ordered by creation order
|
||||||
|
|
||||||
|
Category is a superset of sections: for category "postprocessing" there could be multiple sections: "face restoration", "upscaling".
|
||||||
|
|
||||||
|
This function also changes items' category_id so that all items belonging to a section have the same category_id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
category_ids = {}
|
||||||
|
section_categories = {}
|
||||||
|
|
||||||
section_ids = {}
|
|
||||||
settings_items = self.data_labels.items()
|
settings_items = self.data_labels.items()
|
||||||
for _, item in settings_items:
|
for _, item in settings_items:
|
||||||
if item.section not in section_ids:
|
if item.section not in section_categories:
|
||||||
section_ids[item.section] = len(section_ids)
|
section_categories[item.section] = item.category_id
|
||||||
|
|
||||||
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
for _, item in settings_items:
|
||||||
|
item.category_id = section_categories.get(item.section)
|
||||||
|
|
||||||
|
for category_id in categories.mapping:
|
||||||
|
if category_id not in category_ids:
|
||||||
|
category_ids[category_id] = len(category_ids)
|
||||||
|
|
||||||
|
def sort_key(x):
|
||||||
|
item: OptionInfo = x[1]
|
||||||
|
category_order = category_ids.get(item.category_id, len(category_ids))
|
||||||
|
section_order = item.section[1]
|
||||||
|
|
||||||
|
return category_order, section_order
|
||||||
|
|
||||||
|
self.data_labels = dict(sorted(settings_items, key=sort_key))
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
def cast_value(self, key, value):
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
@ -243,3 +287,22 @@ class Options:
|
|||||||
value = expected_type(value)
|
value = expected_type(value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OptionsCategory:
|
||||||
|
id: str
|
||||||
|
label: str
|
||||||
|
|
||||||
|
class OptionsCategories:
|
||||||
|
def __init__(self):
|
||||||
|
self.mapping = {}
|
||||||
|
|
||||||
|
def register_category(self, category_id, label):
|
||||||
|
if category_id in self.mapping:
|
||||||
|
return category_id
|
||||||
|
|
||||||
|
self.mapping[category_id] = OptionsCategory(category_id, label)
|
||||||
|
|
||||||
|
|
||||||
|
categories = OptionsCategories()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
||||||
|
|
||||||
import modules.safe # noqa: F401
|
import modules.safe # noqa: F401
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ import shlex
|
|||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
sys.argv += shlex.split(commandline_args)
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
|
cwd = os.getcwd()
|
||||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
script_path = os.path.dirname(modules_path)
|
script_path = os.path.dirname(modules_path)
|
||||||
|
|
||||||
|
@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
image_list = shared.listfiles(input_dir)
|
image_list = shared.listfiles(input_dir)
|
||||||
for filename in image_list:
|
for filename in image_list:
|
||||||
try:
|
yield filename, filename
|
||||||
image = Image.open(filename)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
yield image, filename
|
|
||||||
else:
|
else:
|
||||||
assert image, 'image not selected'
|
assert image, 'image not selected'
|
||||||
yield image, None
|
yield image, None
|
||||||
@ -45,43 +41,97 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
infotext = ''
|
infotext = ''
|
||||||
|
|
||||||
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
|
||||||
|
shared.state.job_count = len(data_to_process)
|
||||||
|
|
||||||
|
for image_placeholder, name in data_to_process:
|
||||||
image_data: Image.Image
|
image_data: Image.Image
|
||||||
|
|
||||||
|
shared.state.nextjob()
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
shared.state.skipped = False
|
||||||
|
|
||||||
|
if shared.state.interrupted:
|
||||||
|
break
|
||||||
|
|
||||||
|
if isinstance(image_placeholder, str):
|
||||||
|
try:
|
||||||
|
image_data = Image.open(image_placeholder)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
image_data = image_placeholder
|
||||||
|
|
||||||
|
shared.state.assign_current_image(image_data)
|
||||||
|
|
||||||
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(pp, args)
|
scripts.scripts_postproc.run(initial_pp, args)
|
||||||
|
|
||||||
if opts.use_original_name_batch and name is not None:
|
if shared.state.skipped:
|
||||||
basename = os.path.splitext(os.path.basename(name))[0]
|
continue
|
||||||
else:
|
|
||||||
basename = ''
|
|
||||||
|
|
||||||
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
used_suffixes = {}
|
||||||
|
for pp in [initial_pp, *initial_pp.extra_images]:
|
||||||
|
suffix = pp.get_suffix(used_suffixes)
|
||||||
|
|
||||||
if opts.enable_pnginfo:
|
if opts.use_original_name_batch and name is not None:
|
||||||
pp.image.info = existing_pnginfo
|
basename = os.path.splitext(os.path.basename(name))[0]
|
||||||
pp.image.info["postprocessing"] = infotext
|
forced_filename = basename + suffix
|
||||||
|
else:
|
||||||
|
basename = ''
|
||||||
|
forced_filename = None
|
||||||
|
|
||||||
if save_output:
|
infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
|
||||||
images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
|
||||||
|
|
||||||
if extras_mode != 2 or show_extras_results:
|
if opts.enable_pnginfo:
|
||||||
outputs.append(pp.image)
|
pp.image.info = existing_pnginfo
|
||||||
|
pp.image.info["postprocessing"] = infotext
|
||||||
|
|
||||||
|
if save_output:
|
||||||
|
fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
|
||||||
|
|
||||||
|
if pp.caption:
|
||||||
|
caption_filename = os.path.splitext(fullfn)[0] + ".txt"
|
||||||
|
if os.path.isfile(caption_filename):
|
||||||
|
with open(caption_filename, encoding="utf8") as file:
|
||||||
|
existing_caption = file.read().strip()
|
||||||
|
else:
|
||||||
|
existing_caption = ""
|
||||||
|
|
||||||
|
action = shared.opts.postprocessing_existing_caption_action
|
||||||
|
if action == 'Prepend' and existing_caption:
|
||||||
|
caption = f"{existing_caption} {pp.caption}"
|
||||||
|
elif action == 'Append' and existing_caption:
|
||||||
|
caption = f"{pp.caption} {existing_caption}"
|
||||||
|
elif action == 'Keep' and existing_caption:
|
||||||
|
caption = existing_caption
|
||||||
|
else:
|
||||||
|
caption = pp.caption
|
||||||
|
|
||||||
|
caption = caption.strip()
|
||||||
|
if caption:
|
||||||
|
with open(caption_filename, "w", encoding="utf8") as file:
|
||||||
|
file.write(caption)
|
||||||
|
|
||||||
|
if extras_mode != 2 or show_extras_results:
|
||||||
|
outputs.append(pp.image)
|
||||||
|
|
||||||
image_data.close()
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
shared.state.end()
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
|
||||||
|
|
||||||
|
def run_postprocessing_webui(id_task, *args, **kwargs):
|
||||||
|
return run_postprocessing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
|
||||||
"""old handler for API"""
|
"""old handler for API"""
|
||||||
|
|
||||||
@ -97,9 +147,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||||||
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
"upscaler_2_visibility": extras_upscaler_2_visibility,
|
||||||
},
|
},
|
||||||
"GFPGAN": {
|
"GFPGAN": {
|
||||||
|
"enable": True,
|
||||||
"gfpgan_visibility": gfpgan_visibility,
|
"gfpgan_visibility": gfpgan_visibility,
|
||||||
},
|
},
|
||||||
"CodeFormer": {
|
"CodeFormer": {
|
||||||
|
"enable": True,
|
||||||
"codeformer_visibility": codeformer_visibility,
|
"codeformer_visibility": codeformer_visibility,
|
||||||
"codeformer_weight": codeformer_weight,
|
"codeformer_weight": codeformer_weight,
|
||||||
},
|
},
|
||||||
|
@ -142,7 +142,7 @@ class StableDiffusionProcessing:
|
|||||||
overlay_images: list = None
|
overlay_images: list = None
|
||||||
eta: float = None
|
eta: float = None
|
||||||
do_not_reload_embeddings: bool = False
|
do_not_reload_embeddings: bool = False
|
||||||
denoising_strength: float = 0
|
denoising_strength: float = None
|
||||||
ddim_discretize: str = None
|
ddim_discretize: str = None
|
||||||
s_min_uncond: float = None
|
s_min_uncond: float = None
|
||||||
s_churn: float = None
|
s_churn: float = None
|
||||||
@ -296,7 +296,7 @@ class StableDiffusionProcessing:
|
|||||||
return conditioning
|
return conditioning
|
||||||
|
|
||||||
def edit_image_conditioning(self, source_image):
|
def edit_image_conditioning(self, source_image):
|
||||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||||
|
|
||||||
return conditioning_image
|
return conditioning_image
|
||||||
|
|
||||||
@ -533,6 +533,7 @@ class Processed:
|
|||||||
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||||
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||||
self.infotexts = infotexts or [info]
|
self.infotexts = infotexts or [info]
|
||||||
|
self.version = program_version()
|
||||||
|
|
||||||
def js(self):
|
def js(self):
|
||||||
obj = {
|
obj = {
|
||||||
@ -567,6 +568,7 @@ class Processed:
|
|||||||
"job_timestamp": self.job_timestamp,
|
"job_timestamp": self.job_timestamp,
|
||||||
"clip_skip": self.clip_skip,
|
"clip_skip": self.clip_skip,
|
||||||
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
||||||
|
"version": self.version,
|
||||||
}
|
}
|
||||||
|
|
||||||
return json.dumps(obj)
|
return json.dumps(obj)
|
||||||
@ -677,8 +679,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||||
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||||
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
|
"VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
|
||||||
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
|
"VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.before_process(p)
|
p.scripts.before_process(p)
|
||||||
|
|
||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
@ -797,7 +799,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
with torch.no_grad(), p.sd_model.ema_scope():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
@ -871,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
if opts.sd_vae_decode_method != 'Full':
|
if opts.sd_vae_decode_method != 'Full':
|
||||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||||
|
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
@ -884,6 +884,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
state.nextjob()
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
|
|
||||||
@ -936,27 +938,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.enable_pnginfo:
|
if opts.enable_pnginfo:
|
||||||
image.info["parameters"] = text
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
||||||
image_mask = p.mask_for_overlay.convert('RGB')
|
if opts.return_mask or opts.save_mask:
|
||||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
image_mask = p.mask_for_overlay.convert('RGB')
|
||||||
|
if save_samples and opts.save_mask:
|
||||||
|
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
|
||||||
|
if opts.return_mask:
|
||||||
|
output_images.append(image_mask)
|
||||||
|
|
||||||
if opts.save_mask:
|
if opts.return_mask_composite or opts.save_mask_composite:
|
||||||
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
|
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||||
|
if save_samples and opts.save_mask_composite:
|
||||||
if opts.save_mask_composite:
|
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
|
||||||
images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
|
if opts.return_mask_composite:
|
||||||
|
output_images.append(image_mask_composite)
|
||||||
if opts.return_mask:
|
|
||||||
output_images.append(image_mask)
|
|
||||||
|
|
||||||
if opts.return_mask_composite:
|
|
||||||
output_images.append(image_mask_composite)
|
|
||||||
|
|
||||||
del x_samples_ddim
|
del x_samples_ddim
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
state.nextjob()
|
if not infotexts:
|
||||||
|
infotexts.append(Processed(p, []).infotext(p, 0))
|
||||||
|
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
|
|
||||||
@ -1142,6 +1144,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if self.latent_scale_mode is None:
|
if self.latent_scale_mode is None:
|
||||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||||
@ -1151,8 +1154,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||||
|
|
||||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||||
@ -1160,7 +1161,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
self.is_hr_pass = True
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
@ -1249,7 +1249,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return decoded_samples
|
return decoded_samples
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -29,8 +29,8 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
|||||||
else:
|
else:
|
||||||
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||||
|
|
||||||
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
|
||||||
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
|
||||||
|
|
||||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||||
|
|
||||||
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List
|
|
||||||
import lark
|
import lark
|
||||||
|
|
||||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
|
||||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||||
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
||||||
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
||||||
@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
|
|||||||
|
|
||||||
class ComposableScheduledPromptConditioning:
|
class ComposableScheduledPromptConditioning:
|
||||||
def __init__(self, schedules, weight=1.0):
|
def __init__(self, schedules, weight=1.0):
|
||||||
self.schedules: List[ScheduledPromptConditioning] = schedules
|
self.schedules: list[ScheduledPromptConditioning] = schedules
|
||||||
self.weight: float = weight
|
self.weight: float = weight
|
||||||
|
|
||||||
|
|
||||||
class MulticondLearnedConditioning:
|
class MulticondLearnedConditioning:
|
||||||
def __init__(self, shape, batch):
|
def __init__(self, shape, batch):
|
||||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
|
||||||
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||||
@ -278,7 +277,7 @@ class DictWithShape(dict):
|
|||||||
return self["crossattn"].shape
|
return self["crossattn"].shape
|
||||||
|
|
||||||
|
|
||||||
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||||
param = c[0][0].cond
|
param = c[0][0].cond
|
||||||
is_dict = isinstance(param, dict)
|
is_dict = isinstance(param, dict)
|
||||||
|
|
||||||
|
@ -14,7 +14,9 @@ def is_restartable() -> bool:
|
|||||||
def restart_program() -> None:
|
def restart_program() -> None:
|
||||||
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
|
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
|
||||||
|
|
||||||
(Path(script_path) / "tmp" / "restart").touch()
|
tmpdir = Path(script_path) / "tmp"
|
||||||
|
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(tmpdir / "restart").touch()
|
||||||
|
|
||||||
stop_program()
|
stop_program()
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class ImageRNG:
|
|||||||
self.is_first = True
|
self.is_first = True
|
||||||
|
|
||||||
def first(self):
|
def first(self):
|
||||||
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
|
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))
|
||||||
|
|
||||||
xs = []
|
xs = []
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams):
|
|||||||
report_exception(c, 'image_grid')
|
report_exception(c, 'image_grid')
|
||||||
|
|
||||||
|
|
||||||
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
||||||
for c in callback_map['callbacks_infotext_pasted']:
|
for c in callback_map['callbacks_infotext_pasted']:
|
||||||
try:
|
try:
|
||||||
c.callback(infotext, params)
|
c.callback(infotext, params)
|
||||||
@ -449,7 +449,7 @@ def on_infotext_pasted(callback):
|
|||||||
"""register a function to be called before applying an infotext.
|
"""register a function to be called before applying an infotext.
|
||||||
The callback is called with two arguments:
|
The callback is called with two arguments:
|
||||||
- infotext: str - raw infotext.
|
- infotext: str - raw infotext.
|
||||||
- result: Dict[str, any] - parsed infotext parameters.
|
- result: dict[str, any] - parsed infotext parameters.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
||||||
|
|
||||||
|
@ -311,20 +311,113 @@ scripts_data = []
|
|||||||
postprocessing_scripts_data = []
|
postprocessing_scripts_data = []
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
def topological_sort(dependencies):
|
||||||
|
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
||||||
|
Ignores errors relating to missing dependeencies or circular dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
visited = {}
|
||||||
|
result = []
|
||||||
|
|
||||||
|
def inner(name):
|
||||||
|
visited[name] = True
|
||||||
|
|
||||||
|
for dep in dependencies.get(name, []):
|
||||||
|
if dep in dependencies and dep not in visited:
|
||||||
|
inner(dep)
|
||||||
|
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for depname in dependencies:
|
||||||
|
if depname not in visited:
|
||||||
|
inner(depname)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScriptWithDependencies:
|
||||||
|
script_canonical_name: str
|
||||||
|
file: ScriptFile
|
||||||
|
requires: list
|
||||||
|
load_before: list
|
||||||
|
load_after: list
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts = {}
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
|
||||||
if os.path.exists(basedir):
|
loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}
|
||||||
for filename in sorted(os.listdir(basedir)):
|
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
# build script dependency map
|
||||||
|
root_script_basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
|
if os.path.exists(root_script_basedir):
|
||||||
|
for filename in sorted(os.listdir(root_script_basedir)):
|
||||||
|
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if os.path.splitext(filename)[1].lower() != extension:
|
||||||
|
continue
|
||||||
|
|
||||||
|
script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
|
||||||
|
scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])
|
||||||
|
|
||||||
if include_extensions:
|
if include_extensions:
|
||||||
for ext in extensions.active():
|
for ext in extensions.active():
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
extension_scripts_list = ext.list_files(scriptdirname, extension)
|
||||||
|
for extension_script in extension_scripts_list:
|
||||||
|
if not os.path.isfile(extension_script.path):
|
||||||
|
continue
|
||||||
|
|
||||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
|
||||||
|
relative_path = scriptdirname + "/" + extension_script.filename
|
||||||
|
|
||||||
|
script = ScriptWithDependencies(
|
||||||
|
script_canonical_name=script_canonical_name,
|
||||||
|
file=extension_script,
|
||||||
|
requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
|
||||||
|
load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
|
||||||
|
load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
|
||||||
|
)
|
||||||
|
|
||||||
|
scripts[script_canonical_name] = script
|
||||||
|
loaded_extensions_scripts[ext.canonical_name].append(script)
|
||||||
|
|
||||||
|
for script_canonical_name, script in scripts.items():
|
||||||
|
# load before requires inverse dependency
|
||||||
|
# in this case, append the script name into the load_after list of the specified script
|
||||||
|
for load_before in script.load_before:
|
||||||
|
# if this requires an individual script to be loaded before
|
||||||
|
other_script = scripts.get(load_before)
|
||||||
|
if other_script:
|
||||||
|
other_script.load_after.append(script_canonical_name)
|
||||||
|
|
||||||
|
# if this requires an extension
|
||||||
|
other_extension_scripts = loaded_extensions_scripts.get(load_before)
|
||||||
|
if other_extension_scripts:
|
||||||
|
for other_script in other_extension_scripts:
|
||||||
|
other_script.load_after.append(script_canonical_name)
|
||||||
|
|
||||||
|
# if After mentions an extension, remove it and instead add all of its scripts
|
||||||
|
for load_after in list(script.load_after):
|
||||||
|
if load_after not in scripts and load_after in loaded_extensions_scripts:
|
||||||
|
script.load_after.remove(load_after)
|
||||||
|
|
||||||
|
for other_script in loaded_extensions_scripts.get(load_after, []):
|
||||||
|
script.load_after.append(other_script.script_canonical_name)
|
||||||
|
|
||||||
|
dependencies = {}
|
||||||
|
|
||||||
|
for script_canonical_name, script in scripts.items():
|
||||||
|
for required_script in script.requires:
|
||||||
|
if required_script not in scripts and required_script not in loaded_extensions:
|
||||||
|
errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)
|
||||||
|
|
||||||
|
dependencies[script_canonical_name] = script.load_after
|
||||||
|
|
||||||
|
ordered_scripts = topological_sort(dependencies)
|
||||||
|
scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]
|
||||||
|
|
||||||
return scripts_list
|
return scripts_list
|
||||||
|
|
||||||
@ -365,15 +458,9 @@ def load_scripts():
|
|||||||
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
|
||||||
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
|
||||||
|
|
||||||
def orderby(basedir):
|
# here the scripts_list is already ordered
|
||||||
# 1st webui, 2nd extensions-builtin, 3rd extensions
|
# processing_script is not considered though
|
||||||
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
|
for scriptfile in scripts_list:
|
||||||
for key in priority:
|
|
||||||
if basedir.startswith(key):
|
|
||||||
return priority[key]
|
|
||||||
return 9999
|
|
||||||
|
|
||||||
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
|
|
||||||
try:
|
try:
|
||||||
if scriptfile.basedir != paths.script_path:
|
if scriptfile.basedir != paths.script_path:
|
||||||
sys.path = [scriptfile.basedir] + sys.path
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
@ -473,17 +560,25 @@ class ScriptRunner:
|
|||||||
on_after.clear()
|
on_after.clear()
|
||||||
|
|
||||||
def create_script_ui(self, script):
|
def create_script_ui(self, script):
|
||||||
import modules.api.models as api_models
|
|
||||||
|
|
||||||
script.args_from = len(self.inputs)
|
script.args_from = len(self.inputs)
|
||||||
script.args_to = len(self.inputs)
|
script.args_to = len(self.inputs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.create_script_ui_inner(script)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error creating UI for {script.name}: ", exc_info=True)
|
||||||
|
|
||||||
|
def create_script_ui_inner(self, script):
|
||||||
|
import modules.api.models as api_models
|
||||||
|
|
||||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||||
|
|
||||||
if controls is None:
|
if controls is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||||
|
|
||||||
api_args = []
|
api_args = []
|
||||||
|
|
||||||
for control in controls:
|
for control in controls:
|
||||||
@ -491,11 +586,15 @@ class ScriptRunner:
|
|||||||
|
|
||||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||||
|
|
||||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
for field in ("value", "minimum", "maximum", "step"):
|
||||||
v = getattr(control, field, None)
|
v = getattr(control, field, None)
|
||||||
if v is not None:
|
if v is not None:
|
||||||
setattr(arg_info, field, v)
|
setattr(arg_info, field, v)
|
||||||
|
|
||||||
|
choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
|
||||||
|
if choices is not None:
|
||||||
|
arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
|
||||||
|
|
||||||
api_args.append(arg_info)
|
api_args.append(arg_info)
|
||||||
|
|
||||||
script.api_info = api_models.ScriptInfo(
|
script.api_info = api_models.ScriptInfo(
|
||||||
|
@ -1,13 +1,56 @@
|
|||||||
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import errors, shared
|
from modules import errors, shared
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class PostprocessedImageSharedInfo:
|
||||||
|
target_width: int = None
|
||||||
|
target_height: int = None
|
||||||
|
|
||||||
|
|
||||||
class PostprocessedImage:
|
class PostprocessedImage:
|
||||||
def __init__(self, image):
|
def __init__(self, image):
|
||||||
self.image = image
|
self.image = image
|
||||||
self.info = {}
|
self.info = {}
|
||||||
|
self.shared = PostprocessedImageSharedInfo()
|
||||||
|
self.extra_images = []
|
||||||
|
self.nametags = []
|
||||||
|
self.disable_processing = False
|
||||||
|
self.caption = None
|
||||||
|
|
||||||
|
def get_suffix(self, used_suffixes=None):
|
||||||
|
used_suffixes = {} if used_suffixes is None else used_suffixes
|
||||||
|
suffix = "-".join(self.nametags)
|
||||||
|
if suffix:
|
||||||
|
suffix = "-" + suffix
|
||||||
|
|
||||||
|
if suffix not in used_suffixes:
|
||||||
|
used_suffixes[suffix] = 1
|
||||||
|
return suffix
|
||||||
|
|
||||||
|
for i in range(1, 100):
|
||||||
|
proposed_suffix = suffix + "-" + str(i)
|
||||||
|
|
||||||
|
if proposed_suffix not in used_suffixes:
|
||||||
|
used_suffixes[proposed_suffix] = 1
|
||||||
|
return proposed_suffix
|
||||||
|
|
||||||
|
return suffix
|
||||||
|
|
||||||
|
def create_copy(self, new_image, *, nametags=None, disable_processing=False):
|
||||||
|
pp = PostprocessedImage(new_image)
|
||||||
|
pp.shared = self.shared
|
||||||
|
pp.nametags = self.nametags.copy()
|
||||||
|
pp.info = self.info.copy()
|
||||||
|
pp.disable_processing = disable_processing
|
||||||
|
|
||||||
|
if nametags is not None:
|
||||||
|
pp.nametags += nametags
|
||||||
|
|
||||||
|
return pp
|
||||||
|
|
||||||
|
|
||||||
class ScriptPostprocessing:
|
class ScriptPostprocessing:
|
||||||
@ -42,10 +85,17 @@ class ScriptPostprocessing:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def image_changed(self):
|
def process_firstpass(self, pp: PostprocessedImage, **args):
|
||||||
|
"""
|
||||||
|
Called for all scripts before calling process(). Scripts can examine the image here and set fields
|
||||||
|
of the pp object to communicate things to other scripts.
|
||||||
|
args contains a dictionary with all values returned by components from ui()
|
||||||
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def image_changed(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
|
|||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def run(self, pp: PostprocessedImage, args):
|
def run(self, pp: PostprocessedImage, args):
|
||||||
for script in self.scripts_in_preferred_order():
|
scripts = []
|
||||||
shared.state.job = script.name
|
|
||||||
|
|
||||||
|
for script in self.scripts_in_preferred_order():
|
||||||
script_args = args[script.args_from:script.args_to]
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
process_args = {}
|
process_args = {}
|
||||||
for (name, _component), value in zip(script.controls.items(), script_args):
|
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||||
process_args[name] = value
|
process_args[name] = value
|
||||||
|
|
||||||
script.process(pp, **process_args)
|
scripts.append((script, process_args))
|
||||||
|
|
||||||
|
for script, process_args in scripts:
|
||||||
|
script.process_firstpass(pp, **process_args)
|
||||||
|
|
||||||
|
all_images = [pp]
|
||||||
|
|
||||||
|
for script, process_args in scripts:
|
||||||
|
if shared.state.skipped:
|
||||||
|
break
|
||||||
|
|
||||||
|
shared.state.job = script.name
|
||||||
|
|
||||||
|
for single_image in all_images.copy():
|
||||||
|
|
||||||
|
if not single_image.disable_processing:
|
||||||
|
script.process(single_image, **process_args)
|
||||||
|
|
||||||
|
for extra_image in single_image.extra_images:
|
||||||
|
if not isinstance(extra_image, PostprocessedImage):
|
||||||
|
extra_image = single_image.create_copy(extra_image)
|
||||||
|
|
||||||
|
all_images.append(extra_image)
|
||||||
|
|
||||||
|
single_image.extra_images.clear()
|
||||||
|
|
||||||
|
pp.extra_images = all_images[1:]
|
||||||
|
|
||||||
def create_args_for_run(self, scripts_args):
|
def create_args_for_run(self, scripts_args):
|
||||||
if not self.ui_created:
|
if not self.ui_created:
|
||||||
|
@ -215,7 +215,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
|
|||||||
would be on the meta device.
|
would be on the meta device.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if state_dict == sd:
|
if state_dict is sd:
|
||||||
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
|
||||||
|
|
||||||
original(module, state_dict, strict=strict)
|
original(module, state_dict, strict=strict)
|
||||||
|
@ -2,14 +2,15 @@ import torch
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
import ldm.models.diffusion.ddpm
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
import ldm.modules.encoders.modules
|
import ldm.modules.encoders.modules
|
||||||
@ -37,6 +38,12 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
|
|||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
|
||||||
|
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
|
||||||
|
|
||||||
|
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
|
||||||
|
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
|
||||||
|
|
||||||
|
|
||||||
def list_optimizers():
|
def list_optimizers():
|
||||||
new_optimizers = script_callbacks.list_optimizers_callback()
|
new_optimizers = script_callbacks.list_optimizers_callback()
|
||||||
@ -181,6 +188,20 @@ class StableDiffusionModelHijack:
|
|||||||
errors.display(e, "applying cross attention optimization")
|
errors.display(e, "applying cross attention optimization")
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
|
def convert_sdxl_to_ssd(self, m):
|
||||||
|
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
|
||||||
|
|
||||||
|
delattr(m.model.diffusion_model.middle_block, '1')
|
||||||
|
delattr(m.model.diffusion_model.middle_block, '2')
|
||||||
|
for i in ['9', '8', '7', '6', '5', '4']:
|
||||||
|
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
|
||||||
|
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
conditioner = getattr(m, 'conditioner', None)
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
if conditioner:
|
if conditioner:
|
||||||
@ -208,7 +229,7 @@ class StableDiffusionModelHijack:
|
|||||||
else:
|
else:
|
||||||
m.cond_stage_model = conditioner
|
m.cond_stage_model = conditioner
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
@ -239,10 +260,17 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
import modules.models.diffusion.ddpm_edit
|
||||||
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
|
||||||
|
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
|
||||||
|
sd_unet.original_forward = ldm_original_forward
|
||||||
|
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
|
||||||
|
sd_unet.original_forward = ldm_original_forward
|
||||||
|
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
|
||||||
|
sd_unet.original_forward = sgm_original_forward
|
||||||
|
else:
|
||||||
|
sd_unet.original_forward = None
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
conditioner = getattr(m, 'conditioner', None)
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
@ -279,7 +307,6 @@ class StableDiffusionModelHijack:
|
|||||||
self.layers = None
|
self.layers = None
|
||||||
self.clip = None
|
self.clip = None
|
||||||
|
|
||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
|
||||||
|
|
||||||
def apply_circular(self, enable):
|
def apply_circular(self, enable):
|
||||||
if self.circular_enabled == enable:
|
if self.circular_enabled == enable:
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
import collections
|
import collections
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf, ListConfig
|
||||||
from os import mkdir
|
from os import mkdir
|
||||||
from urllib import request
|
from urllib import request
|
||||||
import ldm.modules.midas as midas
|
import ldm.modules.midas as midas
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -49,11 +49,12 @@ class CheckpointInfo:
|
|||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
abspath = os.path.abspath(filename)
|
abspath = os.path.abspath(filename)
|
||||||
|
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
|
||||||
|
|
||||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||||
|
|
||||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
|
||||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
name = abspath.replace(abs_ckpt_dir, '')
|
||||||
elif abspath.startswith(model_path):
|
elif abspath.startswith(model_path):
|
||||||
name = abspath.replace(model_path, '')
|
name = abspath.replace(model_path, '')
|
||||||
else:
|
else:
|
||||||
@ -129,9 +130,12 @@ except Exception:
|
|||||||
|
|
||||||
|
|
||||||
def setup_model():
|
def setup_model():
|
||||||
|
"""called once at startup to do various one-time tasks related to SD models"""
|
||||||
|
|
||||||
os.makedirs(model_path, exist_ok=True)
|
os.makedirs(model_path, exist_ok=True)
|
||||||
|
|
||||||
enable_midas_autodownload()
|
enable_midas_autodownload()
|
||||||
|
patch_given_betas()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles(use_short=False):
|
def checkpoint_tiles(use_short=False):
|
||||||
@ -226,15 +230,19 @@ def select_checkpoint():
|
|||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
checkpoint_dict_replacements = {
|
checkpoint_dict_replacements_sd1 = {
|
||||||
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
|
||||||
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
|
||||||
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
checkpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.
|
||||||
|
'conditioner.embedders.0.': 'cond_stage_model.',
|
||||||
|
}
|
||||||
|
|
||||||
def transform_checkpoint_dict_key(k):
|
|
||||||
for text, replacement in checkpoint_dict_replacements.items():
|
def transform_checkpoint_dict_key(k, replacements):
|
||||||
|
for text, replacement in replacements.items():
|
||||||
if k.startswith(text):
|
if k.startswith(text):
|
||||||
k = replacement + k[len(text):]
|
k = replacement + k[len(text):]
|
||||||
|
|
||||||
@ -245,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||||||
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
pl_sd = pl_sd.pop("state_dict", pl_sd)
|
||||||
pl_sd.pop("state_dict", None)
|
pl_sd.pop("state_dict", None)
|
||||||
|
|
||||||
|
is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
|
||||||
|
|
||||||
sd = {}
|
sd = {}
|
||||||
for k, v in pl_sd.items():
|
for k, v in pl_sd.items():
|
||||||
new_key = transform_checkpoint_dict_key(k)
|
if is_sd2_turbo:
|
||||||
|
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
|
||||||
|
else:
|
||||||
|
new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
|
||||||
|
|
||||||
if new_key is not None:
|
if new_key is not None:
|
||||||
sd[new_key] = v
|
sd[new_key] = v
|
||||||
@ -309,6 +322,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
|||||||
if checkpoint_info in checkpoints_loaded:
|
if checkpoint_info in checkpoints_loaded:
|
||||||
# use checkpoint cache
|
# use checkpoint cache
|
||||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||||
|
# move to end as latest
|
||||||
|
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||||
return checkpoints_loaded[checkpoint_info]
|
return checkpoints_loaded[checkpoint_info]
|
||||||
|
|
||||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||||
@ -346,16 +361,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||||||
model.is_sdxl = hasattr(model, 'conditioner')
|
model.is_sdxl = hasattr(model, 'conditioner')
|
||||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||||
|
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
||||||
if model.is_sdxl:
|
if model.is_sdxl:
|
||||||
sd_models_xl.extend_sdxl(model)
|
sd_models_xl.extend_sdxl(model)
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
if model.is_ssd:
|
||||||
timer.record("apply weights to model")
|
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
|
||||||
|
|
||||||
if shared.opts.sd_checkpoint_cache > 0:
|
if shared.opts.sd_checkpoint_cache > 0:
|
||||||
# cache newly loaded model
|
# cache newly loaded model
|
||||||
checkpoints_loaded[checkpoint_info] = state_dict
|
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
timer.record("apply weights to model")
|
||||||
|
|
||||||
del state_dict
|
del state_dict
|
||||||
|
|
||||||
@ -453,6 +471,20 @@ def enable_midas_autodownload():
|
|||||||
midas.api.load_model = load_model_wrapper
|
midas.api.load_model = load_model_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def patch_given_betas():
|
||||||
|
import ldm.models.diffusion.ddpm
|
||||||
|
|
||||||
|
def patched_register_schedule(*args, **kwargs):
|
||||||
|
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
||||||
|
|
||||||
|
if isinstance(args[1], ListConfig):
|
||||||
|
args = (args[0], np.array(args[1]), *args[2:])
|
||||||
|
|
||||||
|
original_register_schedule(*args, **kwargs)
|
||||||
|
|
||||||
|
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||||
|
|
||||||
|
|
||||||
def repair_config(sd_config):
|
def repair_config(sd_config):
|
||||||
|
|
||||||
if not hasattr(sd_config.model.params, "use_ema"):
|
if not hasattr(sd_config.model.params, "use_ema"):
|
||||||
@ -777,17 +809,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
timer = Timer()
|
send_model_to_cpu(sd_model or shared.sd_model)
|
||||||
|
|
||||||
if model_data.sd_model:
|
|
||||||
model_data.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
|
||||||
model_data.sd_model = None
|
|
||||||
sd_model = None
|
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
|
|||||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||||
|
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||||
|
|
||||||
def is_using_v_parameterization_for_sd2(state_dict):
|
def is_using_v_parameterization_for_sd2(state_dict):
|
||||||
"""
|
"""
|
||||||
@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
|
|||||||
if diffusion_model_input.shape[1] == 8:
|
if diffusion_model_input.shape[1] == 8:
|
||||||
return config_instruct_pix2pix
|
return config_instruct_pix2pix
|
||||||
|
|
||||||
|
|
||||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||||
|
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||||
|
return config_alt_diffusion_m18
|
||||||
return config_alt_diffusion
|
return config_alt_diffusion
|
||||||
|
|
||||||
return config_default
|
return config_default
|
||||||
|
@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
|
|||||||
"""structure with additional information about the file with model's weights"""
|
"""structure with additional information about the file with model's weights"""
|
||||||
|
|
||||||
is_sdxl: bool
|
is_sdxl: bool
|
||||||
"""True if the model's architecture is SDXL"""
|
"""True if the model's architecture is SDXL or SSD"""
|
||||||
|
|
||||||
|
is_ssd: bool
|
||||||
|
"""True if the model is SSD"""
|
||||||
|
|
||||||
is_sd2: bool
|
is_sd2: bool
|
||||||
"""True if the model's architecture is SD 2.x"""
|
"""True if the model's architecture is SD 2.x"""
|
||||||
|
@ -60,7 +60,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
|
||||||
while restart_times > 0:
|
while restart_times > 0:
|
||||||
restart_times -= 1
|
restart_times -= 1
|
||||||
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
|
step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
|
||||||
|
|
||||||
last_sigma = None
|
last_sigma = None
|
||||||
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
|
||||||
|
@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
|
|||||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
|||||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import torch.nn
|
import torch.nn
|
||||||
import ldm.modules.diffusionmodules.openaimodel
|
|
||||||
|
|
||||||
from modules import script_callbacks, shared, devices
|
from modules import script_callbacks, shared, devices
|
||||||
|
|
||||||
unet_options = []
|
unet_options = []
|
||||||
current_unet_option = None
|
current_unet_option = None
|
||||||
current_unet = None
|
current_unet = None
|
||||||
|
original_forward = None # not used, only left temporarily for compatibility
|
||||||
|
|
||||||
def list_unets():
|
def list_unets():
|
||||||
new_unets = script_callbacks.list_unets_callback()
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
def create_unet_forward(original_forward):
|
||||||
if current_unet is not None:
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
if current_unet is not None:
|
||||||
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return UNetModel_forward
|
||||||
|
|
||||||
|
@ -14,5 +14,5 @@ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
|||||||
else:
|
else:
|
||||||
cmd_opts, _ = parser.parse_known_args()
|
cmd_opts, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
|
||||||
cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
|
cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access
|
||||||
|
@ -44,9 +44,9 @@ def refresh_unet_list():
|
|||||||
modules.sd_unet.list_unets()
|
modules.sd_unet.list_unets()
|
||||||
|
|
||||||
|
|
||||||
def list_checkpoint_tiles():
|
def list_checkpoint_tiles(use_short=False):
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
return modules.sd_models.checkpoint_tiles()
|
return modules.sd_models.checkpoint_tiles(use_short)
|
||||||
|
|
||||||
|
|
||||||
def refresh_checkpoints():
|
def refresh_checkpoints():
|
||||||
@ -66,7 +66,25 @@ def reload_hypernetworks():
|
|||||||
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def get_infotext_names():
|
||||||
|
from modules import generation_parameters_copypaste, shared
|
||||||
|
res = {}
|
||||||
|
|
||||||
|
for info in shared.opts.data_labels.values():
|
||||||
|
if info.infotext:
|
||||||
|
res[info.infotext] = 1
|
||||||
|
|
||||||
|
for tab_data in generation_parameters_copypaste.paste_fields.values():
|
||||||
|
for _, name in tab_data.get("fields") or []:
|
||||||
|
if isinstance(name, str):
|
||||||
|
res[name] = 1
|
||||||
|
|
||||||
|
return list(res)
|
||||||
|
|
||||||
|
|
||||||
ui_reorder_categories_builtin_items = [
|
ui_reorder_categories_builtin_items = [
|
||||||
|
"prompt",
|
||||||
|
"image",
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
"accordions",
|
"accordions",
|
||||||
|
@ -3,7 +3,7 @@ import gradio as gr
|
|||||||
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from modules.shared_cmd_options import cmd_opts
|
from modules.shared_cmd_options import cmd_opts
|
||||||
from modules.options import options_section, OptionInfo, OptionHTML
|
from modules.options import options_section, OptionInfo, OptionHTML, categories
|
||||||
|
|
||||||
options_templates = {}
|
options_templates = {}
|
||||||
hide_dirs = shared.hide_dirs
|
hide_dirs = shared.hide_dirs
|
||||||
@ -21,12 +21,19 @@ restricted_opts = {
|
|||||||
"outdir_init_images"
|
"outdir_init_images"
|
||||||
}
|
}
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
categories.register_category("saving", "Saving images")
|
||||||
|
categories.register_category("sd", "Stable Diffusion")
|
||||||
|
categories.register_category("ui", "User Interface")
|
||||||
|
categories.register_category("system", "System")
|
||||||
|
categories.register_category("postprocessing", "Postprocessing")
|
||||||
|
categories.register_category("training", "Training")
|
||||||
|
|
||||||
|
options_templates.update(options_section(('saving-images', "Saving images/grids", "saving"), {
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||||
|
"save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
|
||||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
@ -39,8 +46,6 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
|
||||||
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
|
||||||
|
|
||||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
|
||||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
|
||||||
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
|
||||||
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
|
||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
@ -62,9 +67,12 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||||
|
|
||||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||||
|
|
||||||
|
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
|
||||||
|
"notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
|
||||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
|
||||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||||
@ -76,7 +84,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||||||
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
|
||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
@ -84,22 +92,23 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling", "postprocessing"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
|
||||||
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
"face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"),
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System", "system"), {
|
||||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||||
|
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||||
@ -109,15 +118,16 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
|
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('API', "API"), {
|
options_templates.update(options_section(('API', "API", "system"), {
|
||||||
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
|
||||||
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),
|
||||||
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
"api_useragent": OptionInfo("", "User agent for requests", restrict_api=True),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training", "training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
|
||||||
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
"pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
|
||||||
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
|
||||||
@ -132,8 +142,8 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
"training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion", "sd"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||||
@ -149,14 +159,14 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
"hires_fix_refiner_pass": OptionInfo("second pass", "Hires fix: which pass to enable refiner for", gr.Radio, {"choices": ["first pass", "second pass", "both passes"]}, infotext="Hires refiner"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
|
||||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('vae', "VAE"), {
|
options_templates.update(options_section(('vae', "VAE", "sd"), {
|
||||||
"sd_vae_explanation": OptionHTML("""
|
"sd_vae_explanation": OptionHTML("""
|
||||||
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
|
||||||
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling
|
||||||
@ -171,7 +181,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
|||||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('img2img', "img2img"), {
|
options_templates.update(options_section(('img2img', "img2img", "sd"), {
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'),
|
||||||
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
"img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"),
|
||||||
@ -184,9 +194,10 @@ options_templates.update(options_section(('img2img', "img2img"), {
|
|||||||
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
"img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(),
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
|
"img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
|
||||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
@ -197,7 +208,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
|
|||||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||||
@ -222,14 +233,17 @@ options_templates.update(options_section(('interrogate', "Interrogate"), {
|
|||||||
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
options_templates.update(options_section(('extra_networks', "Extra Networks", "sd"), {
|
||||||
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||||
|
"extra_networks_dir_button_function": OptionInfo(False, "Add a '/' to the beginning of directory buttons").info("Buttons will display the contents of the selected directory without acting as a search filter."),
|
||||||
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}),
|
||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||||
|
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||||
|
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||||
@ -237,42 +251,66 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {
|
||||||
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
"keyedit_precision_attention": OptionInfo(0.1, "Precision for (attention:1.1) when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
"keyedit_precision_extra": OptionInfo(0.05, "Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Word delimiters when editing the prompt with Ctrl+up/down"),
|
||||||
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(),
|
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
|
||||||
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
|
|
||||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
|
||||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
|
||||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
|
||||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
|
||||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
|
||||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
|
||||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
|
||||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
|
||||||
|
"return_grid": OptionInfo(True, "Show grid in gallery"),
|
||||||
|
"do_not_show_images": OptionInfo(False, "Do not show any images in gallery"),
|
||||||
|
"js_modal_lightbox": OptionInfo(True, "Full page image viewer: enable"),
|
||||||
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"),
|
||||||
|
"gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
|
||||||
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('ui_alternatives', "UI alternatives", "ui"), {
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(),
|
||||||
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||||
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
"sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
|
||||||
|
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||||
|
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||||
|
"txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
|
||||||
|
"img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('ui', "User interface", "ui"), {
|
||||||
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),
|
||||||
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||||
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||||
|
"ui_reorder_list": OptionInfo([], "UI item order for txt2img/img2img tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(),
|
||||||
|
"gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"),
|
||||||
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
options_templates.update(options_section(('infotext', "Infotext", "ui"), {
|
||||||
|
"infotext_explanation": OptionHTML("""
|
||||||
|
Infotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again.
|
||||||
|
It is displayed in UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.
|
||||||
|
"""),
|
||||||
|
"enable_pnginfo": OptionInfo(True, "Write infotext to metadata of the generated image"),
|
||||||
|
"save_txt": OptionInfo(False, "Create a text file with infotext next to every generated image"),
|
||||||
|
|
||||||
|
"add_model_name_to_info": OptionInfo(True, "Add model name to infotext"),
|
||||||
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to infotext"),
|
||||||
|
"add_vae_name_to_info": OptionInfo(True, "Add VAE name to infotext"),
|
||||||
|
"add_vae_hash_to_info": OptionInfo(True, "Add VAE hash to infotext"),
|
||||||
|
"add_user_name_to_info": OptionInfo(False, "Add user name to infotext when authenticated"),
|
||||||
|
"add_version_to_infotext": OptionInfo(True, "Add program version to infotext"),
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
|
||||||
|
"infotext_skip_pasting": OptionInfo([], "Disregard fields from pasted infotext", ui_components.DropdownMulti, lambda: {"choices": shared_items.get_infotext_names()}),
|
||||||
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
|
||||||
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
<li>Ignore: keep prompt and styles dropdown as it is.</li>
|
||||||
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>
|
||||||
@ -282,7 +320,7 @@ options_templates.update(options_section(('infotext', "Infotext"), {
|
|||||||
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews", "ui"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
@ -293,9 +331,10 @@ options_templates.update(options_section(('ui', "Live previews"), {
|
|||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
"live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"),
|
||||||
|
"js_live_preview_in_modal_lightbox": OptionInfo(False, "Show Live preview in full page image viewer"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters", "sd"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"),
|
||||||
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||||
@ -305,8 +344,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
@ -317,10 +356,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
|
||||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
|
'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section((None, "Hidden options"), {
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
@ -329,4 +369,3 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -103,6 +103,7 @@ class State:
|
|||||||
|
|
||||||
def begin(self, job: str = "(unknown)"):
|
def begin(self, job: str = "(unknown)"):
|
||||||
self.sampling_step = 0
|
self.sampling_step = 0
|
||||||
|
self.time_start = time.time()
|
||||||
self.job_count = -1
|
self.job_count = -1
|
||||||
self.processing_has_refined_job_count = False
|
self.processing_has_refined_job_count = False
|
||||||
self.job_no = 0
|
self.job_no = 0
|
||||||
@ -114,7 +115,6 @@ class State:
|
|||||||
self.skipped = False
|
self.skipped = False
|
||||||
self.interrupted = False
|
self.interrupted = False
|
||||||
self.textinfo = None
|
self.textinfo = None
|
||||||
self.time_start = time.time()
|
|
||||||
self.job = job
|
self.job = job
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
log.info("Starting job %s", job)
|
log.info("Starting job %s", job)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import csv
|
import csv
|
||||||
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
|
||||||
import typing
|
import typing
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
@ -10,6 +10,7 @@ class PromptStyle(typing.NamedTuple):
|
|||||||
name: str
|
name: str
|
||||||
prompt: str
|
prompt: str
|
||||||
negative_prompt: str
|
negative_prompt: str
|
||||||
|
path: str = None
|
||||||
|
|
||||||
|
|
||||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||||
@ -29,38 +30,61 @@ def apply_styles_to_prompt(prompt, styles):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
re_spaces = re.compile(" +")
|
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||||
|
"""
|
||||||
|
Checks the prompt to see if the style text is wrapped around it. If so,
|
||||||
|
returns True plus the prompt text without the style text. Otherwise, returns
|
||||||
|
False with the original prompt.
|
||||||
|
|
||||||
|
Note that the "cleaned" version of the style text is only used for matching
|
||||||
def extract_style_text_from_prompt(style_text, prompt):
|
purposes here. It isn't returned; the original style text is not modified.
|
||||||
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
|
"""
|
||||||
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
|
stripped_prompt = prompt
|
||||||
|
stripped_style_text = style_text
|
||||||
if "{prompt}" in stripped_style_text:
|
if "{prompt}" in stripped_style_text:
|
||||||
left, right = stripped_style_text.split("{prompt}", 2)
|
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||||
|
# return True and the "inner" prompt text that isn't part of the style.
|
||||||
|
try:
|
||||||
|
left, right = stripped_style_text.split("{prompt}", 2)
|
||||||
|
except ValueError as e:
|
||||||
|
# If the style text has multple "{prompt}"s, we can't split it into
|
||||||
|
# two parts. This is an error, but we can't do anything about it.
|
||||||
|
print(f"Unable to compare style text to prompt:\n{style_text}")
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return False, prompt
|
||||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||||
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
|
||||||
return True, prompt
|
return True, prompt
|
||||||
else:
|
else:
|
||||||
|
# Work out whether the given prompt ends with the style text. If so, we
|
||||||
|
# return True and the prompt text up to where the style text starts.
|
||||||
if stripped_prompt.endswith(stripped_style_text):
|
if stripped_prompt.endswith(stripped_style_text):
|
||||||
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
|
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
|
||||||
|
if prompt.endswith(", "):
|
||||||
if prompt.endswith(', '):
|
|
||||||
prompt = prompt[:-2]
|
prompt = prompt[:-2]
|
||||||
|
|
||||||
return True, prompt
|
return True, prompt
|
||||||
|
|
||||||
return False, prompt
|
return False, prompt
|
||||||
|
|
||||||
|
|
||||||
def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
|
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
||||||
|
"""
|
||||||
|
Takes a style and compares it to the prompt and negative prompt. If the style
|
||||||
|
matches, returns True plus the prompt and negative prompt with the style text
|
||||||
|
removed. Otherwise, returns False with the original prompt and negative prompt.
|
||||||
|
"""
|
||||||
if not style.prompt and not style.negative_prompt:
|
if not style.prompt and not style.negative_prompt:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
|
match_positive, extracted_positive = unwrap_style_text_from_prompt(
|
||||||
|
style.prompt, prompt
|
||||||
|
)
|
||||||
if not match_positive:
|
if not match_positive:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
|
match_negative, extracted_negative = unwrap_style_text_from_prompt(
|
||||||
|
style.negative_prompt, negative_prompt
|
||||||
|
)
|
||||||
if not match_negative:
|
if not match_negative:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
@ -69,25 +93,84 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
|
|||||||
|
|
||||||
class StyleDatabase:
|
class StyleDatabase:
|
||||||
def __init__(self, path: str):
|
def __init__(self, path: str):
|
||||||
self.no_style = PromptStyle("None", "", "")
|
self.no_style = PromptStyle("None", "", "", None)
|
||||||
self.styles = {}
|
self.styles = {}
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
|
folder, file = os.path.split(self.path)
|
||||||
|
filename, _, ext = file.partition('*')
|
||||||
|
self.default_path = os.path.join(folder, filename + ext)
|
||||||
|
|
||||||
|
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||||
|
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
def reload(self):
|
def reload(self):
|
||||||
|
"""
|
||||||
|
Clears the style database and reloads the styles from the CSV file(s)
|
||||||
|
matching the path used to initialize the database.
|
||||||
|
"""
|
||||||
self.styles.clear()
|
self.styles.clear()
|
||||||
|
|
||||||
if not os.path.exists(self.path):
|
path, filename = os.path.split(self.path)
|
||||||
return
|
|
||||||
|
|
||||||
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
|
if "*" in filename:
|
||||||
|
fileglob = filename.split("*")[0] + "*.csv"
|
||||||
|
filelist = []
|
||||||
|
for file in os.listdir(path):
|
||||||
|
if fnmatch.fnmatch(file, fileglob):
|
||||||
|
filelist.append(file)
|
||||||
|
# Add a visible divider to the style list
|
||||||
|
half_len = round(len(file) / 2)
|
||||||
|
divider = f"{'-' * (20 - half_len)} {file.upper()}"
|
||||||
|
divider = f"{divider} {'-' * (40 - len(divider))}"
|
||||||
|
self.styles[divider] = PromptStyle(
|
||||||
|
f"{divider}", None, None, "do_not_save"
|
||||||
|
)
|
||||||
|
# Add styles from this CSV file
|
||||||
|
self.load_from_csv(os.path.join(path, file))
|
||||||
|
if len(filelist) == 0:
|
||||||
|
print(f"No styles found in {path} matching {fileglob}")
|
||||||
|
return
|
||||||
|
elif not os.path.exists(self.path):
|
||||||
|
print(f"Style database not found: {self.path}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.load_from_csv(self.path)
|
||||||
|
|
||||||
|
def load_from_csv(self, path: str):
|
||||||
|
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||||
reader = csv.DictReader(file, skipinitialspace=True)
|
reader = csv.DictReader(file, skipinitialspace=True)
|
||||||
for row in reader:
|
for row in reader:
|
||||||
|
# Ignore empty rows or rows starting with a comment
|
||||||
|
if not row or row["name"].startswith("#"):
|
||||||
|
continue
|
||||||
# Support loading old CSV format with "name, text"-columns
|
# Support loading old CSV format with "name, text"-columns
|
||||||
prompt = row["prompt"] if "prompt" in row else row["text"]
|
prompt = row["prompt"] if "prompt" in row else row["text"]
|
||||||
negative_prompt = row.get("negative_prompt", "")
|
negative_prompt = row.get("negative_prompt", "")
|
||||||
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
|
# Add style to database
|
||||||
|
self.styles[row["name"]] = PromptStyle(
|
||||||
|
row["name"], prompt, negative_prompt, path
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_style_paths(self) -> set:
|
||||||
|
"""Returns a set of all distinct paths of files that styles are loaded from."""
|
||||||
|
# Update any styles without a path to the default path
|
||||||
|
for style in list(self.styles.values()):
|
||||||
|
if not style.path:
|
||||||
|
self.styles[style.name] = style._replace(path=self.default_path)
|
||||||
|
|
||||||
|
# Create a list of all distinct paths, including the default path
|
||||||
|
style_paths = set()
|
||||||
|
style_paths.add(self.default_path)
|
||||||
|
for _, style in self.styles.items():
|
||||||
|
if style.path:
|
||||||
|
style_paths.add(style.path)
|
||||||
|
|
||||||
|
# Remove any paths for styles that are just list dividers
|
||||||
|
style_paths.discard("do_not_save")
|
||||||
|
|
||||||
|
return style_paths
|
||||||
|
|
||||||
def get_style_prompts(self, styles):
|
def get_style_prompts(self, styles):
|
||||||
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||||
@ -96,20 +179,40 @@ class StyleDatabase:
|
|||||||
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
||||||
|
|
||||||
def apply_styles_to_prompt(self, prompt, styles):
|
def apply_styles_to_prompt(self, prompt, styles):
|
||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
|
return apply_styles_to_prompt(
|
||||||
|
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||||
|
)
|
||||||
|
|
||||||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
return apply_styles_to_prompt(
|
||||||
|
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
||||||
|
)
|
||||||
|
|
||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str = None) -> None:
|
||||||
# Always keep a backup file around
|
# The path argument is deprecated, but kept for backwards compatibility
|
||||||
if os.path.exists(path):
|
_ = path
|
||||||
shutil.copy(path, f"{path}.bak")
|
|
||||||
|
|
||||||
with open(path, "w", encoding="utf-8-sig", newline='') as file:
|
style_paths = self.get_style_paths()
|
||||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
|
||||||
writer.writeheader()
|
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
||||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
|
||||||
|
for style_path in style_paths:
|
||||||
|
# Always keep a backup file around
|
||||||
|
if os.path.exists(style_path):
|
||||||
|
shutil.copy(style_path, f"{style_path}.bak")
|
||||||
|
|
||||||
|
# Write the styles to the CSV file
|
||||||
|
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
|
||||||
|
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
for style in (s for s in self.styles.values() if s.path == style_path):
|
||||||
|
# Skip style list dividers, e.g. "STYLES.CSV"
|
||||||
|
if style.name.lower().strip("# ") in csv_names:
|
||||||
|
continue
|
||||||
|
# Write style fields, ignoring the path field
|
||||||
|
writer.writerow(
|
||||||
|
{k: v for k, v in style._asdict().items() if k != "path"}
|
||||||
|
)
|
||||||
|
|
||||||
def extract_styles_from_prompt(self, prompt, negative_prompt):
|
def extract_styles_from_prompt(self, prompt, negative_prompt):
|
||||||
extracted = []
|
extracted = []
|
||||||
@ -120,7 +223,9 @@ class StyleDatabase:
|
|||||||
found_style = None
|
found_style = None
|
||||||
|
|
||||||
for style in applicable_styles:
|
for style in applicable_styles:
|
||||||
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
|
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
|
||||||
|
style, prompt, negative_prompt
|
||||||
|
)
|
||||||
if is_match:
|
if is_match:
|
||||||
found_style = style
|
found_style = style
|
||||||
prompt = new_prompt
|
prompt = new_prompt
|
||||||
|
@ -15,7 +15,7 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
import math
|
import math
|
||||||
from typing import Optional, NamedTuple, List
|
from typing import Optional, NamedTuple
|
||||||
|
|
||||||
|
|
||||||
def narrow_trunc(
|
def narrow_trunc(
|
||||||
@ -97,7 +97,7 @@ def _query_chunk_attention(
|
|||||||
)
|
)
|
||||||
return summarize_chunk(query, key_chunk, value_chunk)
|
return summarize_chunk(query, key_chunk, value_chunk)
|
||||||
|
|
||||||
chunks: List[AttnChunk] = [
|
chunks: list[AttnChunk] = [
|
||||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||||
]
|
]
|
||||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -84,7 +83,7 @@ def get_dict():
|
|||||||
"Checksum": checksum_token,
|
"Checksum": checksum_token,
|
||||||
"Commandline": get_argv(),
|
"Commandline": get_argv(),
|
||||||
"Torch env info": get_torch_sysinfo(),
|
"Torch env info": get_torch_sysinfo(),
|
||||||
"Exceptions": get_exceptions(),
|
"Exceptions": errors.get_exceptions(),
|
||||||
"CPU": {
|
"CPU": {
|
||||||
"model": platform.processor(),
|
"model": platform.processor(),
|
||||||
"count logical": psutil.cpu_count(logical=True),
|
"count logical": psutil.cpu_count(logical=True),
|
||||||
@ -104,21 +103,6 @@ def get_dict():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def format_traceback(tb):
|
|
||||||
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
|
|
||||||
|
|
||||||
|
|
||||||
def format_exception(e, tb):
|
|
||||||
return {"exception": str(e), "traceback": format_traceback(tb)}
|
|
||||||
|
|
||||||
|
|
||||||
def get_exceptions():
|
|
||||||
try:
|
|
||||||
return list(reversed(errors.exception_records))
|
|
||||||
except Exception as e:
|
|
||||||
return str(e)
|
|
||||||
|
|
||||||
|
|
||||||
def get_environment():
|
def get_environment():
|
||||||
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}
|
||||||
|
|
||||||
|
@ -3,6 +3,8 @@ import requests
|
|||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
|
from modules import paths_internal
|
||||||
|
from pkg_resources import parse_version
|
||||||
|
|
||||||
GREEN = "#0F0"
|
GREEN = "#0F0"
|
||||||
BLUE = "#00F"
|
BLUE = "#00F"
|
||||||
@ -25,7 +27,6 @@ def crop_image(im, settings):
|
|||||||
elif is_portrait(settings.crop_width, settings.crop_height):
|
elif is_portrait(settings.crop_width, settings.crop_height):
|
||||||
scale_by = settings.crop_height / im.height
|
scale_by = settings.crop_height / im.height
|
||||||
|
|
||||||
|
|
||||||
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
|
||||||
im_debug = im.copy()
|
im_debug = im.copy()
|
||||||
|
|
||||||
@ -69,6 +70,7 @@ def crop_image(im, settings):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def focal_point(im, settings):
|
def focal_point(im, settings):
|
||||||
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
|
||||||
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []
|
||||||
@ -78,118 +80,120 @@ def focal_point(im, settings):
|
|||||||
|
|
||||||
weight_pref_total = 0
|
weight_pref_total = 0
|
||||||
if corner_points:
|
if corner_points:
|
||||||
weight_pref_total += settings.corner_points_weight
|
weight_pref_total += settings.corner_points_weight
|
||||||
if entropy_points:
|
if entropy_points:
|
||||||
weight_pref_total += settings.entropy_points_weight
|
weight_pref_total += settings.entropy_points_weight
|
||||||
if face_points:
|
if face_points:
|
||||||
weight_pref_total += settings.face_points_weight
|
weight_pref_total += settings.face_points_weight
|
||||||
|
|
||||||
corner_centroid = None
|
corner_centroid = None
|
||||||
if corner_points:
|
if corner_points:
|
||||||
corner_centroid = centroid(corner_points)
|
corner_centroid = centroid(corner_points)
|
||||||
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
|
||||||
pois.append(corner_centroid)
|
pois.append(corner_centroid)
|
||||||
|
|
||||||
entropy_centroid = None
|
entropy_centroid = None
|
||||||
if entropy_points:
|
if entropy_points:
|
||||||
entropy_centroid = centroid(entropy_points)
|
entropy_centroid = centroid(entropy_points)
|
||||||
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total
|
||||||
pois.append(entropy_centroid)
|
pois.append(entropy_centroid)
|
||||||
|
|
||||||
face_centroid = None
|
face_centroid = None
|
||||||
if face_points:
|
if face_points:
|
||||||
face_centroid = centroid(face_points)
|
face_centroid = centroid(face_points)
|
||||||
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
face_centroid.weight = settings.face_points_weight / weight_pref_total
|
||||||
pois.append(face_centroid)
|
pois.append(face_centroid)
|
||||||
|
|
||||||
average_point = poi_average(pois, settings)
|
average_point = poi_average(pois, settings)
|
||||||
|
|
||||||
if settings.annotate_image:
|
if settings.annotate_image:
|
||||||
d = ImageDraw.Draw(im)
|
d = ImageDraw.Draw(im)
|
||||||
max_size = min(im.width, im.height) * 0.07
|
max_size = min(im.width, im.height) * 0.07
|
||||||
if corner_centroid is not None:
|
if corner_centroid is not None:
|
||||||
color = BLUE
|
color = BLUE
|
||||||
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
box = corner_centroid.bounding(max_size * corner_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
d.text((box[0], box[1] - 15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(corner_points) > 1:
|
if len(corner_points) > 1:
|
||||||
for f in corner_points:
|
for f in corner_points:
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
d.rectangle(f.bounding(4), outline=color)
|
||||||
if entropy_centroid is not None:
|
if entropy_centroid is not None:
|
||||||
color = "#ff0"
|
color = "#ff0"
|
||||||
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
d.text((box[0], box[1] - 15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(entropy_points) > 1:
|
if len(entropy_points) > 1:
|
||||||
for f in entropy_points:
|
for f in entropy_points:
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
d.rectangle(f.bounding(4), outline=color)
|
||||||
if face_centroid is not None:
|
if face_centroid is not None:
|
||||||
color = RED
|
color = RED
|
||||||
box = face_centroid.bounding(max_size * face_centroid.weight)
|
box = face_centroid.bounding(max_size * face_centroid.weight)
|
||||||
d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
d.text((box[0], box[1] - 15), f"Face: {face_centroid.weight:.02f}", fill=color)
|
||||||
d.ellipse(box, outline=color)
|
d.ellipse(box, outline=color)
|
||||||
if len(face_points) > 1:
|
if len(face_points) > 1:
|
||||||
for f in face_points:
|
for f in face_points:
|
||||||
d.rectangle(f.bounding(4), outline=color)
|
d.rectangle(f.bounding(4), outline=color)
|
||||||
|
|
||||||
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
d.ellipse(average_point.bounding(max_size), outline=GREEN)
|
||||||
|
|
||||||
return average_point
|
return average_point
|
||||||
|
|
||||||
|
|
||||||
def image_face_points(im, settings):
|
def image_face_points(im, settings):
|
||||||
if settings.dnn_model_path is not None:
|
if settings.dnn_model_path is not None:
|
||||||
detector = cv2.FaceDetectorYN.create(
|
detector = cv2.FaceDetectorYN.create(
|
||||||
settings.dnn_model_path,
|
settings.dnn_model_path,
|
||||||
"",
|
"",
|
||||||
(im.width, im.height),
|
(im.width, im.height),
|
||||||
0.9, # score threshold
|
0.9, # score threshold
|
||||||
0.3, # nms threshold
|
0.3, # nms threshold
|
||||||
5000 # keep top k before nms
|
5000 # keep top k before nms
|
||||||
)
|
)
|
||||||
faces = detector.detect(np.array(im))
|
faces = detector.detect(np.array(im))
|
||||||
results = []
|
results = []
|
||||||
if faces[1] is not None:
|
if faces[1] is not None:
|
||||||
for face in faces[1]:
|
for face in faces[1]:
|
||||||
x = face[0]
|
x = face[0]
|
||||||
y = face[1]
|
y = face[1]
|
||||||
w = face[2]
|
w = face[2]
|
||||||
h = face[3]
|
h = face[3]
|
||||||
results.append(
|
results.append(
|
||||||
PointOfInterest(
|
PointOfInterest(
|
||||||
int(x + (w * 0.5)), # face focus left/right is center
|
int(x + (w * 0.5)), # face focus left/right is center
|
||||||
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
int(y + (h * 0.33)), # face focus up/down is close to the top of the head
|
||||||
size = w,
|
size=w,
|
||||||
weight = 1/len(faces[1])
|
weight=1 / len(faces[1])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
np_im = np.array(im)
|
np_im = np.array(im)
|
||||||
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
tries = [
|
tries = [
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01 ],
|
[f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05 ],
|
[f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],
|
||||||
[ f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05 ]
|
[f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]
|
||||||
]
|
]
|
||||||
for t in tries:
|
for t in tries:
|
||||||
classifier = cv2.CascadeClassifier(t[0])
|
classifier = cv2.CascadeClassifier(t[0])
|
||||||
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
minsize = int(min(im.width, im.height) * t[1]) # at least N percent of the smallest side
|
||||||
try:
|
try:
|
||||||
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
|
||||||
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
|
minNeighbors=7, minSize=(minsize, minsize),
|
||||||
except Exception:
|
flags=cv2.CASCADE_SCALE_IMAGE)
|
||||||
continue
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
if faces:
|
if faces:
|
||||||
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]
|
||||||
return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects]
|
return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),
|
||||||
|
weight=1 / len(rects)) for r in rects]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@ -198,7 +202,7 @@ def image_corner_points(im, settings):
|
|||||||
|
|
||||||
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
# naive attempt at preventing focal points from collecting at watermarks near the bottom
|
||||||
gd = ImageDraw.Draw(grayscale)
|
gd = ImageDraw.Draw(grayscale)
|
||||||
gd.rectangle([0, im.height*.9, im.width, im.height], fill="#999")
|
gd.rectangle([0, im.height * .9, im.width, im.height], fill="#999")
|
||||||
|
|
||||||
np_im = np.array(grayscale)
|
np_im = np.array(grayscale)
|
||||||
|
|
||||||
@ -206,7 +210,7 @@ def image_corner_points(im, settings):
|
|||||||
np_im,
|
np_im,
|
||||||
maxCorners=100,
|
maxCorners=100,
|
||||||
qualityLevel=0.04,
|
qualityLevel=0.04,
|
||||||
minDistance=min(grayscale.width, grayscale.height)*0.06,
|
minDistance=min(grayscale.width, grayscale.height) * 0.06,
|
||||||
useHarrisDetector=False,
|
useHarrisDetector=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,8 +219,8 @@ def image_corner_points(im, settings):
|
|||||||
|
|
||||||
focal_points = []
|
focal_points = []
|
||||||
for point in points:
|
for point in points:
|
||||||
x, y = point.ravel()
|
x, y = point.ravel()
|
||||||
focal_points.append(PointOfInterest(x, y, size=4, weight=1/len(points)))
|
focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))
|
||||||
|
|
||||||
return focal_points
|
return focal_points
|
||||||
|
|
||||||
@ -225,13 +229,13 @@ def image_entropy_points(im, settings):
|
|||||||
landscape = im.height < im.width
|
landscape = im.height < im.width
|
||||||
portrait = im.height > im.width
|
portrait = im.height > im.width
|
||||||
if landscape:
|
if landscape:
|
||||||
move_idx = [0, 2]
|
move_idx = [0, 2]
|
||||||
move_max = im.size[0]
|
move_max = im.size[0]
|
||||||
elif portrait:
|
elif portrait:
|
||||||
move_idx = [1, 3]
|
move_idx = [1, 3]
|
||||||
move_max = im.size[1]
|
move_max = im.size[1]
|
||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
e_max = 0
|
e_max = 0
|
||||||
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
crop_current = [0, 0, settings.crop_width, settings.crop_height]
|
||||||
@ -241,14 +245,14 @@ def image_entropy_points(im, settings):
|
|||||||
e = image_entropy(crop)
|
e = image_entropy(crop)
|
||||||
|
|
||||||
if (e > e_max):
|
if (e > e_max):
|
||||||
e_max = e
|
e_max = e
|
||||||
crop_best = list(crop_current)
|
crop_best = list(crop_current)
|
||||||
|
|
||||||
crop_current[move_idx[0]] += 4
|
crop_current[move_idx[0]] += 4
|
||||||
crop_current[move_idx[1]] += 4
|
crop_current[move_idx[1]] += 4
|
||||||
|
|
||||||
x_mid = int(crop_best[0] + settings.crop_width/2)
|
x_mid = int(crop_best[0] + settings.crop_width / 2)
|
||||||
y_mid = int(crop_best[1] + settings.crop_height/2)
|
y_mid = int(crop_best[1] + settings.crop_height / 2)
|
||||||
|
|
||||||
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]
|
||||||
|
|
||||||
@ -294,22 +298,23 @@ def is_square(w, h):
|
|||||||
return w == h
|
return w == h
|
||||||
|
|
||||||
|
|
||||||
def download_and_cache_models(dirname):
|
model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
|
||||||
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
if parse_version(cv2.__version__) >= parse_version('4.8'):
|
||||||
model_file_name = 'face_detection_yunet.onnx'
|
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
|
||||||
|
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
|
||||||
|
else:
|
||||||
|
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
|
||||||
|
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
|
||||||
|
|
||||||
os.makedirs(dirname, exist_ok=True)
|
|
||||||
|
|
||||||
cache_file = os.path.join(dirname, model_file_name)
|
def download_and_cache_models():
|
||||||
if not os.path.exists(cache_file):
|
if not os.path.exists(model_file_path):
|
||||||
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
|
os.makedirs(model_dir_opencv, exist_ok=True)
|
||||||
response = requests.get(download_url)
|
print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
|
||||||
with open(cache_file, "wb") as f:
|
response = requests.get(model_url)
|
||||||
|
with open(model_file_path, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
|
return model_file_path
|
||||||
if os.path.exists(cache_file):
|
|
||||||
return cache_file
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class PointOfInterest:
|
class PointOfInterest:
|
||||||
|
@ -1,232 +0,0 @@
|
|||||||
import os
|
|
||||||
from PIL import Image, ImageOps
|
|
||||||
import math
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from modules import paths, shared, images, deepbooru
|
|
||||||
from modules.textual_inversion import autocrop
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
|
||||||
try:
|
|
||||||
if process_caption:
|
|
||||||
shared.interrogator.load()
|
|
||||||
|
|
||||||
if process_caption_deepbooru:
|
|
||||||
deepbooru.model.start()
|
|
||||||
|
|
||||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
|
|
||||||
if process_caption:
|
|
||||||
shared.interrogator.send_blip_to_ram()
|
|
||||||
|
|
||||||
if process_caption_deepbooru:
|
|
||||||
deepbooru.model.stop()
|
|
||||||
|
|
||||||
|
|
||||||
def listfiles(dirname):
|
|
||||||
return os.listdir(dirname)
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessParams:
|
|
||||||
src = None
|
|
||||||
dstdir = None
|
|
||||||
subindex = 0
|
|
||||||
flip = False
|
|
||||||
process_caption = False
|
|
||||||
process_caption_deepbooru = False
|
|
||||||
preprocess_txt_action = None
|
|
||||||
|
|
||||||
|
|
||||||
def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
|
|
||||||
caption = ""
|
|
||||||
|
|
||||||
if params.process_caption:
|
|
||||||
caption += shared.interrogator.generate_caption(image)
|
|
||||||
|
|
||||||
if params.process_caption_deepbooru:
|
|
||||||
if caption:
|
|
||||||
caption += ", "
|
|
||||||
caption += deepbooru.model.tag_multi(image)
|
|
||||||
|
|
||||||
filename_part = params.src
|
|
||||||
filename_part = os.path.splitext(filename_part)[0]
|
|
||||||
filename_part = os.path.basename(filename_part)
|
|
||||||
|
|
||||||
basename = f"{index:05}-{params.subindex}-{filename_part}"
|
|
||||||
image.save(os.path.join(params.dstdir, f"{basename}.png"))
|
|
||||||
|
|
||||||
if params.preprocess_txt_action == 'prepend' and existing_caption:
|
|
||||||
caption = f"{existing_caption} {caption}"
|
|
||||||
elif params.preprocess_txt_action == 'append' and existing_caption:
|
|
||||||
caption = f"{caption} {existing_caption}"
|
|
||||||
elif params.preprocess_txt_action == 'copy' and existing_caption:
|
|
||||||
caption = existing_caption
|
|
||||||
|
|
||||||
caption = caption.strip()
|
|
||||||
|
|
||||||
if caption:
|
|
||||||
with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
|
|
||||||
file.write(caption)
|
|
||||||
|
|
||||||
params.subindex += 1
|
|
||||||
|
|
||||||
|
|
||||||
def save_pic(image, index, params, existing_caption=None):
|
|
||||||
save_pic_with_caption(image, index, params, existing_caption=existing_caption)
|
|
||||||
|
|
||||||
if params.flip:
|
|
||||||
save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
|
|
||||||
|
|
||||||
|
|
||||||
def split_pic(image, inverse_xy, width, height, overlap_ratio):
|
|
||||||
if inverse_xy:
|
|
||||||
from_w, from_h = image.height, image.width
|
|
||||||
to_w, to_h = height, width
|
|
||||||
else:
|
|
||||||
from_w, from_h = image.width, image.height
|
|
||||||
to_w, to_h = width, height
|
|
||||||
h = from_h * to_w // from_w
|
|
||||||
if inverse_xy:
|
|
||||||
image = image.resize((h, to_w))
|
|
||||||
else:
|
|
||||||
image = image.resize((to_w, h))
|
|
||||||
|
|
||||||
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
|
||||||
y_step = (h - to_h) / (split_count - 1)
|
|
||||||
for i in range(split_count):
|
|
||||||
y = int(y_step * i)
|
|
||||||
if inverse_xy:
|
|
||||||
splitted = image.crop((y, 0, y + to_h, to_w))
|
|
||||||
else:
|
|
||||||
splitted = image.crop((0, y, to_w, y + to_h))
|
|
||||||
yield splitted
|
|
||||||
|
|
||||||
# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
|
|
||||||
def center_crop(image: Image, w: int, h: int):
|
|
||||||
iw, ih = image.size
|
|
||||||
if ih / h < iw / w:
|
|
||||||
sw = w * ih / h
|
|
||||||
box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
|
|
||||||
else:
|
|
||||||
sh = h * iw / w
|
|
||||||
box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
|
|
||||||
return image.resize((w, h), Image.Resampling.LANCZOS, box)
|
|
||||||
|
|
||||||
|
|
||||||
def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
|
|
||||||
iw, ih = image.size
|
|
||||||
err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
|
|
||||||
wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
|
|
||||||
if minarea <= w * h <= maxarea and err(w, h) <= threshold),
|
|
||||||
key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
|
|
||||||
default=None
|
|
||||||
)
|
|
||||||
return wh and center_crop(image, *wh)
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
|
||||||
width = process_width
|
|
||||||
height = process_height
|
|
||||||
src = os.path.abspath(process_src)
|
|
||||||
dst = os.path.abspath(process_dst)
|
|
||||||
split_threshold = max(0.0, min(1.0, split_threshold))
|
|
||||||
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
|
|
||||||
|
|
||||||
assert src != dst, 'same directory specified as source and destination'
|
|
||||||
|
|
||||||
os.makedirs(dst, exist_ok=True)
|
|
||||||
|
|
||||||
files = listfiles(src)
|
|
||||||
|
|
||||||
shared.state.job = "preprocess"
|
|
||||||
shared.state.textinfo = "Preprocessing..."
|
|
||||||
shared.state.job_count = len(files)
|
|
||||||
|
|
||||||
params = PreprocessParams()
|
|
||||||
params.dstdir = dst
|
|
||||||
params.flip = process_flip
|
|
||||||
params.process_caption = process_caption
|
|
||||||
params.process_caption_deepbooru = process_caption_deepbooru
|
|
||||||
params.preprocess_txt_action = preprocess_txt_action
|
|
||||||
|
|
||||||
pbar = tqdm.tqdm(files)
|
|
||||||
for index, imagefile in enumerate(pbar):
|
|
||||||
params.subindex = 0
|
|
||||||
filename = os.path.join(src, imagefile)
|
|
||||||
try:
|
|
||||||
img = Image.open(filename)
|
|
||||||
img = ImageOps.exif_transpose(img)
|
|
||||||
img = img.convert("RGB")
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
description = f"Preprocessing [Image {index}/{len(files)}]"
|
|
||||||
pbar.set_description(description)
|
|
||||||
shared.state.textinfo = description
|
|
||||||
|
|
||||||
params.src = filename
|
|
||||||
|
|
||||||
existing_caption = None
|
|
||||||
existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
|
|
||||||
if os.path.exists(existing_caption_filename):
|
|
||||||
with open(existing_caption_filename, 'r', encoding="utf8") as file:
|
|
||||||
existing_caption = file.read()
|
|
||||||
|
|
||||||
if shared.state.interrupted:
|
|
||||||
break
|
|
||||||
|
|
||||||
if img.height > img.width:
|
|
||||||
ratio = (img.width * height) / (img.height * width)
|
|
||||||
inverse_xy = False
|
|
||||||
else:
|
|
||||||
ratio = (img.height * width) / (img.width * height)
|
|
||||||
inverse_xy = True
|
|
||||||
|
|
||||||
process_default_resize = True
|
|
||||||
|
|
||||||
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
|
||||||
for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
|
|
||||||
save_pic(splitted, index, params, existing_caption=existing_caption)
|
|
||||||
process_default_resize = False
|
|
||||||
|
|
||||||
if process_focal_crop and img.height != img.width:
|
|
||||||
|
|
||||||
dnn_model_path = None
|
|
||||||
try:
|
|
||||||
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv"))
|
|
||||||
except Exception as e:
|
|
||||||
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
|
|
||||||
|
|
||||||
autocrop_settings = autocrop.Settings(
|
|
||||||
crop_width = width,
|
|
||||||
crop_height = height,
|
|
||||||
face_points_weight = process_focal_crop_face_weight,
|
|
||||||
entropy_points_weight = process_focal_crop_entropy_weight,
|
|
||||||
corner_points_weight = process_focal_crop_edges_weight,
|
|
||||||
annotate_image = process_focal_crop_debug,
|
|
||||||
dnn_model_path = dnn_model_path,
|
|
||||||
)
|
|
||||||
for focal in autocrop.crop_image(img, autocrop_settings):
|
|
||||||
save_pic(focal, index, params, existing_caption=existing_caption)
|
|
||||||
process_default_resize = False
|
|
||||||
|
|
||||||
if process_multicrop:
|
|
||||||
cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
|
||||||
if cropped is not None:
|
|
||||||
save_pic(cropped, index, params, existing_caption=existing_caption)
|
|
||||||
else:
|
|
||||||
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
|
||||||
process_default_resize = False
|
|
||||||
|
|
||||||
if process_keep_original_size:
|
|
||||||
save_pic(img, index, params, existing_caption=existing_caption)
|
|
||||||
process_default_resize = False
|
|
||||||
|
|
||||||
if process_default_resize:
|
|
||||||
img = images.resize_image(1, img, width, height)
|
|
||||||
save_pic(img, index, params, existing_caption=existing_caption)
|
|
||||||
|
|
||||||
shared.state.nextjob()
|
|
@ -181,40 +181,7 @@ class EmbeddingDatabase:
|
|||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
||||||
# textual inversion embeddings
|
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
|
||||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
|
||||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
|
||||||
vectors = data['clip_g'].shape[0]
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
|
||||||
shape = vec.shape[-1]
|
|
||||||
vectors = vec.shape[0]
|
|
||||||
else:
|
|
||||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
|
||||||
|
|
||||||
embedding = Embedding(vec, name)
|
|
||||||
embedding.step = data.get('step', None)
|
|
||||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
|
||||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
|
||||||
embedding.vectors = vectors
|
|
||||||
embedding.shape = shape
|
|
||||||
embedding.filename = path
|
|
||||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
|
||||||
|
|
||||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||||
self.register_embedding(embedding, shared.sd_model)
|
self.register_embedding(embedding, shared.sd_model)
|
||||||
@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
|
||||||
|
if 'string_to_param' in data: # textual inversion embeddings
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1]
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
shape = vec.shape[-1]
|
||||||
|
vectors = vec.shape[0]
|
||||||
|
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||||
|
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||||
|
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||||
|
vectors = data['clip_g'].shape[0]
|
||||||
|
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||||
|
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if len(emb.shape) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||||
|
shape = vec.shape[-1]
|
||||||
|
vectors = vec.shape[0]
|
||||||
|
else:
|
||||||
|
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||||
|
|
||||||
|
embedding = Embedding(vec, name)
|
||||||
|
embedding.step = data.get('step', None)
|
||||||
|
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||||
|
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||||
|
embedding.vectors = vectors
|
||||||
|
embedding.shape = shape
|
||||||
|
|
||||||
|
if filepath:
|
||||||
|
embedding.filename = filepath
|
||||||
|
embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||||
if shared.opts.training_write_csv_every == 0:
|
if shared.opts.training_write_csv_every == 0:
|
||||||
return
|
return
|
||||||
@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
|||||||
assert log_directory, "Log directory is empty"
|
assert log_directory, "Log directory is empty"
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
from modules import processing
|
from modules import processing
|
||||||
|
|
||||||
save_embedding_every = save_embedding_every or 0
|
save_embedding_every = save_embedding_every or 0
|
||||||
@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
|||||||
p.prompt = preview_prompt
|
p.prompt = preview_prompt
|
||||||
p.negative_prompt = preview_negative_prompt
|
p.negative_prompt = preview_negative_prompt
|
||||||
p.steps = preview_steps
|
p.steps = preview_steps
|
||||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
||||||
p.cfg_scale = preview_cfg_scale
|
p.cfg_scale = preview_cfg_scale
|
||||||
p.seed = preview_seed
|
p.seed = preview_seed
|
||||||
p.width = preview_width
|
p.width = preview_width
|
||||||
|
@ -3,7 +3,6 @@ import html
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
import modules.textual_inversion.preprocess
|
|
||||||
from modules import sd_hijack, shared
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
|||||||
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
def preprocess(*args):
|
|
||||||
modules.textual_inversion.preprocess.preprocess(*args)
|
|
||||||
|
|
||||||
return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
|
|
||||||
|
|
||||||
|
|
||||||
def train_embedding(*args):
|
def train_embedding(*args):
|
||||||
|
|
||||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||||
|
@ -3,7 +3,7 @@ from contextlib import closing
|
|||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules import processing
|
from modules import processing
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
|
|
||||||
p.user = request.username
|
p.user = request.username
|
||||||
|
|
||||||
if cmd_opts.enable_console_prompts:
|
if shared.opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
with closing(p):
|
with closing(p):
|
||||||
|
420
modules/ui.py
420
modules/ui.py
@ -4,6 +4,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import warnings
|
import warnings
|
||||||
|
from contextlib import ExitStack
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import gradio.utils
|
import gradio.utils
|
||||||
@ -12,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
|||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensons # noqa: F401
|
from modules import gradio_extensons # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui_common import create_refresh_button
|
from modules.ui_common import create_refresh_button
|
||||||
@ -25,7 +26,6 @@ import modules.hypernetworks.ui as hypernetworks_ui
|
|||||||
import modules.textual_inversion.ui as textual_inversion_ui
|
import modules.textual_inversion.ui as textual_inversion_ui
|
||||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.images
|
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.generation_parameters_copypaste import image_from_url_text
|
from modules.generation_parameters_copypaste import image_from_url_text
|
||||||
@ -151,11 +151,15 @@ def connect_clear_prompt(button):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_token_counter(text, steps):
|
def update_token_counter(text, steps, *, is_positive=True):
|
||||||
try:
|
try:
|
||||||
text, _ = extra_networks.parse_prompt(text)
|
text, _ = extra_networks.parse_prompt(text)
|
||||||
|
|
||||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
if is_positive:
|
||||||
|
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||||
|
else:
|
||||||
|
prompt_flat_list = [text]
|
||||||
|
|
||||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -169,76 +173,9 @@ def update_token_counter(text, steps):
|
|||||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
|
|
||||||
class Toprow:
|
def update_negative_prompt_token_counter(text, steps):
|
||||||
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
return update_token_counter(text, steps, is_positive=False)
|
||||||
|
|
||||||
def __init__(self, is_img2img):
|
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
|
||||||
self.id_part = id_part
|
|
||||||
|
|
||||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
|
||||||
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=80):
|
|
||||||
with gr.Row():
|
|
||||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
|
||||||
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=80):
|
|
||||||
with gr.Row():
|
|
||||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
|
||||||
|
|
||||||
self.button_interrogate = None
|
|
||||||
self.button_deepbooru = None
|
|
||||||
if is_img2img:
|
|
||||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
|
||||||
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
|
||||||
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
|
||||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
|
||||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
|
||||||
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
|
||||||
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
|
||||||
|
|
||||||
self.skip.click(
|
|
||||||
fn=lambda: shared.state.skip(),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.interrupt.click(
|
|
||||||
fn=lambda: shared.state.interrupt(),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
|
||||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
|
||||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
|
||||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
|
||||||
|
|
||||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
|
||||||
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
|
||||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
|
||||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
|
||||||
|
|
||||||
self.clear_prompt_button.click(
|
|
||||||
fn=lambda *x: x,
|
|
||||||
_js="confirm_clear_prompt",
|
|
||||||
inputs=[self.prompt, self.negative_prompt],
|
|
||||||
outputs=[self.prompt, self.negative_prompt],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
|
|
||||||
|
|
||||||
self.prompt_img.change(
|
|
||||||
fn=modules.images.image_data,
|
|
||||||
inputs=[self.prompt_img],
|
|
||||||
outputs=[self.prompt, self.prompt_img],
|
|
||||||
show_progress=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(*args, **kwargs):
|
def setup_progressbar(*args, **kwargs):
|
||||||
@ -278,8 +215,8 @@ def apply_setting(key, value):
|
|||||||
return getattr(opts, key)
|
return getattr(opts, key)
|
||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir, toprow=None):
|
||||||
return ui_common.create_output_panel(tabname, outdir)
|
return ui_common.create_output_panel(tabname, outdir, toprow)
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_and_steps_selection(choices, tabname):
|
def create_sampler_and_steps_selection(choices, tabname):
|
||||||
@ -326,7 +263,7 @@ def create_ui():
|
|||||||
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
toprow = Toprow(is_img2img=False)
|
toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
|
||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
@ -334,10 +271,17 @@ def create_ui():
|
|||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.txt2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="txt2img_settings"))
|
||||||
|
|
||||||
scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
|
if category == "prompt":
|
||||||
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||||
|
|
||||||
@ -348,7 +292,7 @@ def create_ui():
|
|||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||||
|
|
||||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
|
||||||
|
|
||||||
if opts.dimensions_and_batch_together:
|
if opts.dimensions_and_batch_together:
|
||||||
with gr.Column(elem_id="txt2img_column_batch"):
|
with gr.Column(elem_id="txt2img_column_batch"):
|
||||||
@ -432,7 +376,7 @@ def create_ui():
|
|||||||
show_progress=False,
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||||
@ -533,7 +477,7 @@ def create_ui():
|
|||||||
]
|
]
|
||||||
|
|
||||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||||
|
|
||||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||||
@ -544,13 +488,17 @@ def create_ui():
|
|||||||
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
toprow = Toprow(is_img2img=True)
|
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
|
||||||
|
|
||||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="img2img_settings"):
|
with ExitStack() as stack:
|
||||||
|
if shared.opts.img2img_settings_accordion:
|
||||||
|
stack.enter_context(gr.Accordion("Open for Settings", open=False))
|
||||||
|
stack.enter_context(gr.Column(variant='compact', elem_id="img2img_settings"))
|
||||||
|
|
||||||
copy_image_buttons = []
|
copy_image_buttons = []
|
||||||
copy_image_destinations = {}
|
copy_image_destinations = {}
|
||||||
|
|
||||||
@ -567,85 +515,89 @@ def create_ui():
|
|||||||
button = gr.Button(title)
|
button = gr.Button(title)
|
||||||
copy_image_buttons.append((button, name, elem))
|
copy_image_buttons.append((button, name, elem))
|
||||||
|
|
||||||
with gr.Tabs(elem_id="mode_img2img"):
|
|
||||||
img2img_selected_tab = gr.State(0)
|
|
||||||
|
|
||||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
|
||||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
|
||||||
add_copy_image_controls('img2img', init_img)
|
|
||||||
|
|
||||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
|
||||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
|
||||||
add_copy_image_controls('sketch', sketch)
|
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
|
||||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
|
||||||
|
|
||||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
|
||||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
|
||||||
inpaint_color_sketch_orig = gr.State(None)
|
|
||||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
|
||||||
|
|
||||||
def update_orig(image, state):
|
|
||||||
if image is not None:
|
|
||||||
same_size = state is not None and state.size == image.size
|
|
||||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
|
||||||
edited = same_size and has_exact_match
|
|
||||||
return image if not edited or state is None else state
|
|
||||||
|
|
||||||
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
|
||||||
|
|
||||||
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
|
||||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
|
||||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
|
||||||
|
|
||||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
|
||||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
|
||||||
gr.HTML(
|
|
||||||
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
|
||||||
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
|
||||||
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
|
||||||
f"{hidden}</p>"
|
|
||||||
)
|
|
||||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
|
||||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
|
||||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
|
||||||
with gr.Accordion("PNG info", open=False):
|
|
||||||
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
|
||||||
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
|
||||||
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
|
||||||
|
|
||||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
|
||||||
|
|
||||||
for i, tab in enumerate(img2img_tabs):
|
|
||||||
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
|
||||||
|
|
||||||
def copy_image(img):
|
|
||||||
if isinstance(img, dict) and 'image' in img:
|
|
||||||
return img['image']
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
for button, name, elem in copy_image_buttons:
|
|
||||||
button.click(
|
|
||||||
fn=copy_image,
|
|
||||||
inputs=[elem],
|
|
||||||
outputs=[copy_image_destinations[name]],
|
|
||||||
)
|
|
||||||
button.click(
|
|
||||||
fn=lambda: None,
|
|
||||||
_js=f"switch_to_{name.replace(' ', '_')}",
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
with FormRow():
|
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
|
||||||
|
|
||||||
scripts.scripts_img2img.prepare_ui()
|
scripts.scripts_img2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
|
if category == "prompt":
|
||||||
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
|
if category == "image":
|
||||||
|
with gr.Tabs(elem_id="mode_img2img"):
|
||||||
|
img2img_selected_tab = gr.State(0)
|
||||||
|
|
||||||
|
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||||
|
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
||||||
|
add_copy_image_controls('img2img', init_img)
|
||||||
|
|
||||||
|
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||||
|
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||||
|
add_copy_image_controls('sketch', sketch)
|
||||||
|
|
||||||
|
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||||
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
||||||
|
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||||
|
|
||||||
|
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||||
|
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||||
|
inpaint_color_sketch_orig = gr.State(None)
|
||||||
|
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||||
|
|
||||||
|
def update_orig(image, state):
|
||||||
|
if image is not None:
|
||||||
|
same_size = state is not None and state.size == image.size
|
||||||
|
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||||
|
edited = same_size and has_exact_match
|
||||||
|
return image if not edited or state is None else state
|
||||||
|
|
||||||
|
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
||||||
|
|
||||||
|
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||||
|
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||||
|
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
||||||
|
|
||||||
|
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||||
|
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||||
|
gr.HTML(
|
||||||
|
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||||
|
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||||
|
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
||||||
|
f"{hidden}</p>"
|
||||||
|
)
|
||||||
|
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||||
|
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||||
|
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||||
|
with gr.Accordion("PNG info", open=False):
|
||||||
|
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
||||||
|
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
||||||
|
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
||||||
|
|
||||||
|
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||||
|
|
||||||
|
for i, tab in enumerate(img2img_tabs):
|
||||||
|
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||||
|
|
||||||
|
def copy_image(img):
|
||||||
|
if isinstance(img, dict) and 'image' in img:
|
||||||
|
return img['image']
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
for button, name, elem in copy_image_buttons:
|
||||||
|
button.click(
|
||||||
|
fn=copy_image,
|
||||||
|
inputs=[elem],
|
||||||
|
outputs=[copy_image_destinations[name]],
|
||||||
|
)
|
||||||
|
button.click(
|
||||||
|
fn=lambda: None,
|
||||||
|
_js=f"switch_to_{name.replace(' ', '_')}",
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
||||||
|
|
||||||
@ -661,8 +613,8 @@ def create_ui():
|
|||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
|
||||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
|
||||||
|
|
||||||
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
||||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||||
@ -683,12 +635,6 @@ def create_ui():
|
|||||||
scale_by.release(**on_change_args)
|
scale_by.release(**on_change_args)
|
||||||
button_update_resize_to.click(**on_change_args)
|
button_update_resize_to.click(**on_change_args)
|
||||||
|
|
||||||
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
|
||||||
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
|
||||||
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
|
||||||
for component in [init_img, sketch]:
|
|
||||||
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
|
||||||
|
|
||||||
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
||||||
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
||||||
|
|
||||||
@ -746,20 +692,26 @@ def create_ui():
|
|||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||||
|
|
||||||
def select_img2img_tab(tab):
|
|
||||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
|
||||||
|
|
||||||
for i, elem in enumerate(img2img_tabs):
|
|
||||||
elem.select(
|
|
||||||
fn=lambda tab=i: select_img2img_tab(tab),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[inpaint_controls, mask_alpha],
|
|
||||||
)
|
|
||||||
|
|
||||||
if category not in {"accordions"}:
|
if category not in {"accordions"}:
|
||||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||||
|
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||||
|
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||||
|
for component in [init_img, sketch]:
|
||||||
|
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||||
|
|
||||||
|
def select_img2img_tab(tab):
|
||||||
|
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||||
|
|
||||||
|
for i, elem in enumerate(img2img_tabs):
|
||||||
|
elem.select(
|
||||||
|
fn=lambda tab=i: select_img2img_tab(tab),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
|
)
|
||||||
|
|
||||||
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||||
@ -960,71 +912,6 @@ def create_ui():
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
||||||
|
|
||||||
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
|
||||||
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
|
||||||
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
|
||||||
process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
|
|
||||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
|
|
||||||
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
|
|
||||||
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
|
|
||||||
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
|
|
||||||
process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
|
|
||||||
process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
|
|
||||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
|
|
||||||
|
|
||||||
with gr.Row(visible=False) as process_split_extra_row:
|
|
||||||
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
|
|
||||||
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
|
|
||||||
|
|
||||||
with gr.Row(visible=False) as process_focal_crop_row:
|
|
||||||
process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
|
|
||||||
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
|
|
||||||
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
|
|
||||||
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
|
|
||||||
|
|
||||||
with gr.Column(visible=False) as process_multicrop_col:
|
|
||||||
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
|
|
||||||
with gr.Row():
|
|
||||||
process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
|
|
||||||
process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
|
|
||||||
with gr.Row():
|
|
||||||
process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
|
|
||||||
process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
|
|
||||||
with gr.Row():
|
|
||||||
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
|
|
||||||
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
|
|
||||||
|
|
||||||
with gr.Row():
|
|
||||||
with gr.Column(scale=3):
|
|
||||||
gr.HTML(value="")
|
|
||||||
|
|
||||||
with gr.Column():
|
|
||||||
with gr.Row():
|
|
||||||
interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
|
|
||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
|
|
||||||
|
|
||||||
process_split.change(
|
|
||||||
fn=lambda show: gr_show(show),
|
|
||||||
inputs=[process_split],
|
|
||||||
outputs=[process_split_extra_row],
|
|
||||||
)
|
|
||||||
|
|
||||||
process_focal_crop.change(
|
|
||||||
fn=lambda show: gr_show(show),
|
|
||||||
inputs=[process_focal_crop],
|
|
||||||
outputs=[process_focal_crop_row],
|
|
||||||
)
|
|
||||||
|
|
||||||
process_multicrop.change(
|
|
||||||
fn=lambda show: gr_show(show),
|
|
||||||
inputs=[process_multicrop],
|
|
||||||
outputs=[process_multicrop_col],
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_textual_inversion_template_names():
|
def get_textual_inversion_template_names():
|
||||||
return sorted(textual_inversion.textual_inversion_templates)
|
return sorted(textual_inversion.textual_inversion_templates)
|
||||||
|
|
||||||
@ -1125,42 +1012,6 @@ def create_ui():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
run_preprocess.click(
|
|
||||||
fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
|
|
||||||
_js="start_training_textual_inversion",
|
|
||||||
inputs=[
|
|
||||||
dummy_component,
|
|
||||||
process_src,
|
|
||||||
process_dst,
|
|
||||||
process_width,
|
|
||||||
process_height,
|
|
||||||
preprocess_txt_action,
|
|
||||||
process_keep_original_size,
|
|
||||||
process_flip,
|
|
||||||
process_split,
|
|
||||||
process_caption,
|
|
||||||
process_caption_deepbooru,
|
|
||||||
process_split_threshold,
|
|
||||||
process_overlap_ratio,
|
|
||||||
process_focal_crop,
|
|
||||||
process_focal_crop_face_weight,
|
|
||||||
process_focal_crop_entropy_weight,
|
|
||||||
process_focal_crop_edges_weight,
|
|
||||||
process_focal_crop_debug,
|
|
||||||
process_multicrop,
|
|
||||||
process_multicrop_mindim,
|
|
||||||
process_multicrop_maxdim,
|
|
||||||
process_multicrop_minarea,
|
|
||||||
process_multicrop_maxarea,
|
|
||||||
process_multicrop_objective,
|
|
||||||
process_multicrop_threshold,
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
ti_output,
|
|
||||||
ti_outcome,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
train_embedding.click(
|
train_embedding.click(
|
||||||
fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
|
fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
|
||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
@ -1234,12 +1085,6 @@ def create_ui():
|
|||||||
outputs=[],
|
outputs=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
interrupt_preprocessing.click(
|
|
||||||
fn=lambda: shared.state.interrupt(),
|
|
||||||
inputs=[],
|
|
||||||
outputs=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
||||||
|
|
||||||
settings = ui_settings.UiSettings()
|
settings = ui_settings.UiSettings()
|
||||||
@ -1286,7 +1131,7 @@ def create_ui():
|
|||||||
|
|
||||||
loadsave.setup_ui()
|
loadsave.setup_ui()
|
||||||
|
|
||||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
|
||||||
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||||
|
|
||||||
footer = shared.html("footer.html")
|
footer = shared.html("footer.html")
|
||||||
@ -1338,7 +1183,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
|||||||
|
|
||||||
def setup_ui_api(app):
|
def setup_ui_api(app):
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List
|
|
||||||
|
|
||||||
class QuicksettingsHint(BaseModel):
|
class QuicksettingsHint(BaseModel):
|
||||||
name: str = Field(title="Name of the quicksettings field")
|
name: str = Field(title="Name of the quicksettings field")
|
||||||
@ -1347,7 +1191,7 @@ def setup_ui_api(app):
|
|||||||
def quicksettings_hint():
|
def quicksettings_hint():
|
||||||
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
|
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
|
||||||
|
|
||||||
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
|
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
|
||||||
|
|
||||||
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
||||||
|
|
||||||
@ -1357,7 +1201,7 @@ def setup_ui_api(app):
|
|||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
text = sysinfo.get()
|
text = sysinfo.get()
|
||||||
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt"
|
filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json"
|
||||||
|
|
||||||
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'})
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def create_output_panel(tabname, outdir):
|
def create_output_panel(tabname, outdir, toprow=None):
|
||||||
|
|
||||||
def open_folder(f):
|
def open_folder(f):
|
||||||
if not os.path.exists(f):
|
if not os.path.exists(f):
|
||||||
@ -130,12 +130,15 @@ Requested path was: {f}
|
|||||||
else:
|
else:
|
||||||
sp.Popen(["xdg-open", path])
|
sp.Popen(["xdg-open", path])
|
||||||
|
|
||||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
with gr.Column(elem_id=f"{tabname}_results"):
|
||||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
if toprow:
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
toprow.create_inline_toprow_image()
|
||||||
|
|
||||||
generation_info = None
|
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
|
||||||
with gr.Column():
|
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||||
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||||
|
|
||||||
|
generation_info = None
|
||||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||||
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ def save_config_state(name):
|
|||||||
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
|
||||||
print(f"Saving backup of webui/extension state to {filename}.")
|
print(f"Saving backup of webui/extension state to {filename}.")
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
json.dump(current_config_state, f, indent=4)
|
json.dump(current_config_state, f, indent=4, ensure_ascii=False)
|
||||||
config_states.list_config_states()
|
config_states.list_config_states()
|
||||||
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||||
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||||
@ -197,7 +197,7 @@ def update_config_states_table(state_name):
|
|||||||
config_state = config_states.all_config_states[state_name]
|
config_state = config_states.all_config_states[state_name]
|
||||||
|
|
||||||
config_name = config_state.get("name", "Config")
|
config_name = config_state.get("name", "Config")
|
||||||
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
filepath = config_state.get("filepath", "<unknown>")
|
filepath = config_state.get("filepath", "<unknown>")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -335,6 +335,11 @@ def normalize_git_url(url):
|
|||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def get_extension_dirname_from_url(url):
|
||||||
|
*parts, last_part = url.split('/')
|
||||||
|
return normalize_git_url(last_part)
|
||||||
|
|
||||||
|
|
||||||
def install_extension_from_url(dirname, url, branch_name=None):
|
def install_extension_from_url(dirname, url, branch_name=None):
|
||||||
check_access()
|
check_access()
|
||||||
|
|
||||||
@ -346,10 +351,7 @@ def install_extension_from_url(dirname, url, branch_name=None):
|
|||||||
assert url, 'No URL specified'
|
assert url, 'No URL specified'
|
||||||
|
|
||||||
if dirname is None or dirname == "":
|
if dirname is None or dirname == "":
|
||||||
*parts, last_part = url.split('/')
|
dirname = get_extension_dirname_from_url(url)
|
||||||
last_part = normalize_git_url(last_part)
|
|
||||||
|
|
||||||
dirname = last_part
|
|
||||||
|
|
||||||
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
target_dir = os.path.join(extensions.extensions_dir, dirname)
|
||||||
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
|
||||||
@ -449,7 +451,8 @@ def get_date(info: dict, key):
|
|||||||
|
|
||||||
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
|
||||||
extlist = available_extensions["extensions"]
|
extlist = available_extensions["extensions"]
|
||||||
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
|
installed_extensions = {extension.name for extension in extensions.extensions}
|
||||||
|
installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None}
|
||||||
|
|
||||||
tags = available_extensions.get("tags", {})
|
tags = available_extensions.get("tags", {})
|
||||||
tags_to_hide = set(hide_tags)
|
tags_to_hide = set(hide_tags)
|
||||||
@ -482,7 +485,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
if url is None:
|
if url is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls
|
||||||
extension_tags = extension_tags + ["installed"] if existing else extension_tags
|
extension_tags = extension_tags + ["installed"] if existing else extension_tags
|
||||||
|
|
||||||
if any(x for x in extension_tags if x in tags_to_hide):
|
if any(x for x in extension_tags if x in tags_to_hide):
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import functools
|
||||||
import os.path
|
import os.path
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -15,6 +16,17 @@ from modules.ui_components import ToolButton
|
|||||||
extra_pages = []
|
extra_pages = []
|
||||||
allowed_dirs = set()
|
allowed_dirs = set()
|
||||||
|
|
||||||
|
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def allowed_preview_extensions_with_extra(extra_extensions=None):
|
||||||
|
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
|
||||||
|
|
||||||
|
|
||||||
|
def allowed_preview_extensions():
|
||||||
|
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
|
||||||
|
|
||||||
|
|
||||||
def register_page(page):
|
def register_page(page):
|
||||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||||
@ -33,9 +45,9 @@ def fetch_file(filename: str = ""):
|
|||||||
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||||
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()[1:]
|
||||||
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
if ext not in allowed_preview_extensions():
|
||||||
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
|
||||||
|
|
||||||
# would profit from returning 304
|
# would profit from returning 304
|
||||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||||
@ -91,6 +103,7 @@ class ExtraNetworksPage:
|
|||||||
self.name = title.lower()
|
self.name = title.lower()
|
||||||
self.id_page = self.name.replace(" ", "_")
|
self.id_page = self.name.replace(" ", "_")
|
||||||
self.card_page = shared.html("extra-networks-card.html")
|
self.card_page = shared.html("extra-networks-card.html")
|
||||||
|
self.allow_prompt = True
|
||||||
self.allow_negative_prompt = False
|
self.allow_negative_prompt = False
|
||||||
self.metadata = {}
|
self.metadata = {}
|
||||||
self.items = {}
|
self.items = {}
|
||||||
@ -138,8 +151,13 @@ class ExtraNetworksPage:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
||||||
while subdir.startswith("/"):
|
|
||||||
subdir = subdir[1:]
|
if shared.opts.extra_networks_dir_button_function:
|
||||||
|
if not subdir.startswith("/"):
|
||||||
|
subdir = "/" + subdir
|
||||||
|
else:
|
||||||
|
while subdir.startswith("/"):
|
||||||
|
subdir = subdir[1:]
|
||||||
|
|
||||||
is_empty = len(os.listdir(x)) == 0
|
is_empty = len(os.listdir(x)) == 0
|
||||||
if not is_empty and not subdir.endswith("/"):
|
if not is_empty and not subdir.endswith("/"):
|
||||||
@ -213,9 +231,9 @@ class ExtraNetworksPage:
|
|||||||
metadata_button = ""
|
metadata_button = ""
|
||||||
metadata = item.get("metadata")
|
metadata = item.get("metadata")
|
||||||
if metadata:
|
if metadata:
|
||||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
|
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||||
|
|
||||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
|
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||||
|
|
||||||
local_path = ""
|
local_path = ""
|
||||||
filename = item.get("filename", "")
|
filename = item.get("filename", "")
|
||||||
@ -235,7 +253,7 @@ class ExtraNetworksPage:
|
|||||||
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
|
sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"background_image": background_image,
|
"background_image": background_image,
|
||||||
@ -266,6 +284,7 @@ class ExtraNetworksPage:
|
|||||||
"date_created": int(stat.st_ctime or 0),
|
"date_created": int(stat.st_ctime or 0),
|
||||||
"date_modified": int(stat.st_mtime or 0),
|
"date_modified": int(stat.st_mtime or 0),
|
||||||
"name": pth.name.lower(),
|
"name": pth.name.lower(),
|
||||||
|
"path": str(pth.parent).lower(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_preview(self, path):
|
def find_preview(self, path):
|
||||||
@ -273,11 +292,7 @@ class ExtraNetworksPage:
|
|||||||
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
preview_extensions = ["png", "jpg", "jpeg", "webp"]
|
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
|
||||||
if shared.opts.samples_format not in preview_extensions:
|
|
||||||
preview_extensions.append(shared.opts.samples_format)
|
|
||||||
|
|
||||||
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
|
|
||||||
|
|
||||||
for file in potential_files:
|
for file in potential_files:
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
@ -359,7 +374,10 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
related_tabs = []
|
related_tabs = []
|
||||||
|
|
||||||
for page in ui.stored_extra_pages:
|
for page in ui.stored_extra_pages:
|
||||||
with gr.Tab(page.title, id=page.id_page) as tab:
|
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
|
||||||
|
with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
|
||||||
|
pass
|
||||||
|
|
||||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||||
ui.pages.append(page_elem)
|
ui.pages.append(page_elem)
|
||||||
@ -373,19 +391,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
|||||||
related_tabs.append(tab)
|
related_tabs.append(tab)
|
||||||
|
|
||||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||||
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||||
|
|
||||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||||
|
|
||||||
for tab in unrelated_tabs:
|
tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
|
||||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
|
||||||
|
|
||||||
for tab in related_tabs:
|
for tab in unrelated_tabs:
|
||||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||||
|
|
||||||
|
for page, tab in zip(ui.stored_extra_pages, related_tabs):
|
||||||
|
allow_prompt = "true" if page.allow_prompt else "false"
|
||||||
|
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
|
||||||
|
|
||||||
|
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||||
|
|
||||||
|
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||||
|
|
||||||
|
dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
|
||||||
|
|
||||||
def pages_html():
|
def pages_html():
|
||||||
if not ui.pages_contents:
|
if not ui.pages_contents:
|
||||||
|
@ -10,11 +10,16 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('Checkpoints')
|
super().__init__('Checkpoints')
|
||||||
|
|
||||||
|
self.allow_prompt = False
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
shared.refresh_checkpoints()
|
shared.refresh_checkpoints()
|
||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(checkpoint.filename)
|
path, ext = os.path.splitext(checkpoint.filename)
|
||||||
return {
|
return {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
@ -30,9 +35,12 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
|
# instantiate a list to protect against concurrent modification
|
||||||
names = list(sd_models.checkpoints_list)
|
names = list(sd_models.checkpoints_list)
|
||||||
for index, name in enumerate(names):
|
for index, name in enumerate(names):
|
||||||
yield self.create_item(name, index)
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
|
||||||
|
@ -13,7 +13,10 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
full_path = shared.hypernetworks[name]
|
full_path = shared.hypernetworks.get(name)
|
||||||
|
if full_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(full_path)
|
path, ext = os.path.splitext(full_path)
|
||||||
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||||
shorthash = sha256[0:10] if sha256 else None
|
shorthash = sha256[0:10] if sha256 else None
|
||||||
@ -31,8 +34,12 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(shared.hypernetworks):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(shared.hypernetworks)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return [shared.cmd_opts.hypernetwork_dir]
|
return [shared.cmd_opts.hypernetwork_dir]
|
||||||
|
@ -14,6 +14,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
|
|
||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
|
||||||
|
if embedding is None:
|
||||||
|
return
|
||||||
|
|
||||||
path, ext = os.path.splitext(embedding.filename)
|
path, ext = os.path.splitext(embedding.filename)
|
||||||
return {
|
return {
|
||||||
@ -29,8 +31,12 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
for index, name in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings):
|
# instantiate a list to protect against concurrent modification
|
||||||
yield self.create_item(name, index)
|
names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)
|
||||||
|
for index, name in enumerate(names):
|
||||||
|
item = self.create_item(name, index)
|
||||||
|
if item is not None:
|
||||||
|
yield item
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
def allowed_directories_for_previews(self):
|
||||||
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
|
||||||
|
@ -134,7 +134,7 @@ class UserMetadataEditor:
|
|||||||
basename, ext = os.path.splitext(filename)
|
basename, ext = os.path.splitext(filename)
|
||||||
|
|
||||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
with open(basename + '.json', "w", encoding="utf8") as file:
|
||||||
json.dump(metadata, file, indent=4)
|
json.dump(metadata, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def save_user_metadata(self, name, desc, notes):
|
def save_user_metadata(self, name, desc, notes):
|
||||||
user_metadata = self.get_user_metadata(name)
|
user_metadata = self.get_user_metadata(name)
|
||||||
|
@ -2,12 +2,12 @@ import os
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import localization, shared, scripts
|
from modules import localization, shared, scripts
|
||||||
from modules.paths import script_path, data_path
|
from modules.paths import script_path, data_path, cwd
|
||||||
|
|
||||||
|
|
||||||
def webpath(fn):
|
def webpath(fn):
|
||||||
if fn.startswith(script_path):
|
if fn.startswith(cwd):
|
||||||
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
web_path = os.path.relpath(fn, cwd)
|
||||||
else:
|
else:
|
||||||
web_path = os.path.abspath(fn)
|
web_path = os.path.abspath(fn)
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import os
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors
|
||||||
from modules.ui_components import ToolButton
|
from modules.ui_components import ToolButton, InputAccordion
|
||||||
|
|
||||||
|
|
||||||
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
|
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
|
||||||
@ -32,8 +32,6 @@ class UiLoadsave:
|
|||||||
self.error_loading = True
|
self.error_loading = True
|
||||||
errors.display(e, "loading settings")
|
errors.display(e, "loading settings")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def add_component(self, path, x):
|
def add_component(self, path, x):
|
||||||
"""adds component to the registry of tracked components"""
|
"""adds component to the registry of tracked components"""
|
||||||
|
|
||||||
@ -43,20 +41,24 @@ class UiLoadsave:
|
|||||||
key = f"{path}/{field}"
|
key = f"{path}/{field}"
|
||||||
|
|
||||||
if getattr(obj, 'custom_script_source', None) is not None:
|
if getattr(obj, 'custom_script_source', None) is not None:
|
||||||
key = f"customscript/{obj.custom_script_source}/{key}"
|
key = f"customscript/{obj.custom_script_source}/{key}"
|
||||||
|
|
||||||
if getattr(obj, 'do_not_save_to_config', False):
|
if getattr(obj, 'do_not_save_to_config', False):
|
||||||
return
|
return
|
||||||
|
|
||||||
saved_value = self.ui_settings.get(key, None)
|
saved_value = self.ui_settings.get(key, None)
|
||||||
|
|
||||||
|
if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':
|
||||||
|
field = 'open'
|
||||||
|
|
||||||
if saved_value is None:
|
if saved_value is None:
|
||||||
self.ui_settings[key] = getattr(obj, field)
|
self.ui_settings[key] = getattr(obj, field)
|
||||||
elif condition and not condition(saved_value):
|
elif condition and not condition(saved_value):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
||||||
saved_value = str(saved_value)
|
saved_value = str(saved_value)
|
||||||
elif isinstance(x, gr.Number) and field == 'value':
|
elif isinstance(obj, gr.Number) and field == 'value':
|
||||||
try:
|
try:
|
||||||
saved_value = float(saved_value)
|
saved_value = float(saved_value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -67,7 +69,7 @@ class UiLoadsave:
|
|||||||
init_field(saved_value)
|
init_field(saved_value)
|
||||||
|
|
||||||
if field == 'value' and key not in self.component_mapping:
|
if field == 'value' and key not in self.component_mapping:
|
||||||
self.component_mapping[key] = x
|
self.component_mapping[key] = obj
|
||||||
|
|
||||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
|
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
|
||||||
apply_field(x, 'visible')
|
apply_field(x, 'visible')
|
||||||
@ -100,6 +102,12 @@ class UiLoadsave:
|
|||||||
|
|
||||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||||
|
|
||||||
|
if type(x) == InputAccordion:
|
||||||
|
if x.accordion.visible:
|
||||||
|
apply_field(x.accordion, 'visible')
|
||||||
|
apply_field(x, 'value')
|
||||||
|
apply_field(x.accordion, 'value')
|
||||||
|
|
||||||
def check_tab_id(tab_id):
|
def check_tab_id(tab_id):
|
||||||
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
||||||
if type(tab_id) == str:
|
if type(tab_id) == str:
|
||||||
@ -133,7 +141,7 @@ class UiLoadsave:
|
|||||||
|
|
||||||
def write_to_file(self, current_ui_settings):
|
def write_to_file(self, current_ui_settings):
|
||||||
with open(self.filename, "w", encoding="utf8") as file:
|
with open(self.filename, "w", encoding="utf8") as file:
|
||||||
json.dump(current_ui_settings, file, indent=4)
|
json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
def dump_defaults(self):
|
def dump_defaults(self):
|
||||||
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
"""saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import scripts, shared, ui_common, postprocessing, call_queue
|
from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
|
||||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||||
|
|
||||||
|
|
||||||
def create_ui():
|
def create_ui():
|
||||||
|
dummy_component = gr.Label(visible=False)
|
||||||
tab_index = gr.State(value=0)
|
tab_index = gr.State(value=0)
|
||||||
|
|
||||||
with gr.Row(equal_height=False, variant='compact'):
|
with gr.Row(equal_height=False, variant='compact'):
|
||||||
@ -20,11 +21,13 @@ def create_ui():
|
|||||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||||
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
||||||
|
|
||||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
|
||||||
|
|
||||||
script_inputs = scripts.scripts_postproc.setup_ui()
|
script_inputs = scripts.scripts_postproc.setup_ui()
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras")
|
||||||
|
toprow.create_inline_toprow_image()
|
||||||
|
submit = toprow.submit
|
||||||
|
|
||||||
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
|
||||||
|
|
||||||
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
|
||||||
@ -32,8 +35,10 @@ def create_ui():
|
|||||||
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
|
fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),
|
||||||
|
_js="submit_extras",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
dummy_component,
|
||||||
tab_index,
|
tab_index,
|
||||||
extras_image,
|
extras_image,
|
||||||
image_batch,
|
image_batch,
|
||||||
@ -45,8 +50,9 @@ def create_ui():
|
|||||||
outputs=[
|
outputs=[
|
||||||
result_images,
|
result_images,
|
||||||
html_info_x,
|
html_info_x,
|
||||||
html_info,
|
html_log,
|
||||||
]
|
],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
parameters_copypaste.add_paste_fields("extras", extras_image, None)
|
||||||
|
@ -4,6 +4,7 @@ from modules import shared, ui_common, ui_components, styles
|
|||||||
|
|
||||||
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
|
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
|
||||||
styles_materialize_symbol = '\U0001f4cb' # 📋
|
styles_materialize_symbol = '\U0001f4cb' # 📋
|
||||||
|
styles_copy_symbol = '\U0001f4dd' # 📝
|
||||||
|
|
||||||
|
|
||||||
def select_style(name):
|
def select_style(name):
|
||||||
@ -52,6 +53,8 @@ def refresh_styles():
|
|||||||
class UiPromptStyles:
|
class UiPromptStyles:
|
||||||
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
|
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
|
self.main_ui_prompt = main_ui_prompt
|
||||||
|
self.main_ui_negative_prompt = main_ui_negative_prompt
|
||||||
|
|
||||||
with gr.Row(elem_id=f"{tabname}_styles_row"):
|
with gr.Row(elem_id=f"{tabname}_styles_row"):
|
||||||
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
|
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
|
||||||
@ -61,13 +64,14 @@ class UiPromptStyles:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
||||||
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
||||||
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
||||||
|
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
|
self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3, elem_classes=["prompt"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
|
||||||
@ -96,15 +100,21 @@ class UiPromptStyles:
|
|||||||
show_progress=False,
|
show_progress=False,
|
||||||
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
||||||
|
|
||||||
self.materialize.click(
|
self.setup_apply_button(self.materialize)
|
||||||
fn=materialize_styles,
|
|
||||||
inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
self.copy.click(
|
||||||
outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
fn=lambda p, n: (p, n),
|
||||||
|
inputs=[main_ui_prompt, main_ui_negative_prompt],
|
||||||
|
outputs=[self.prompt, self.neg_prompt],
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
|
)
|
||||||
|
|
||||||
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
|
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
|
||||||
|
|
||||||
|
def setup_apply_button(self, button):
|
||||||
|
button.click(
|
||||||
|
fn=materialize_styles,
|
||||||
|
inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
|
||||||
|
outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
|
||||||
|
show_progress=False,
|
||||||
|
).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False)
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
|
||||||
from modules.call_queue import wrap_gradio_call
|
from modules.call_queue import wrap_gradio_call
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.ui_components import FormRow
|
from modules.ui_components import FormRow
|
||||||
from modules.ui_gradio_extensions import reload_javascript
|
from modules.ui_gradio_extensions import reload_javascript
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
|
||||||
def get_value_for_setting(key):
|
def get_value_for_setting(key):
|
||||||
@ -63,6 +64,9 @@ class UiSettings:
|
|||||||
quicksettings_list = None
|
quicksettings_list = None
|
||||||
quicksettings_names = None
|
quicksettings_names = None
|
||||||
text_settings = None
|
text_settings = None
|
||||||
|
show_all_pages = None
|
||||||
|
show_one_page = None
|
||||||
|
search_input = None
|
||||||
|
|
||||||
def run_settings(self, *args):
|
def run_settings(self, *args):
|
||||||
changed = []
|
changed = []
|
||||||
@ -135,7 +139,7 @@ class UiSettings:
|
|||||||
gr.Group()
|
gr.Group()
|
||||||
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
||||||
current_tab.__enter__()
|
current_tab.__enter__()
|
||||||
current_row = gr.Column(variant='compact')
|
current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact')
|
||||||
current_row.__enter__()
|
current_row.__enter__()
|
||||||
|
|
||||||
previous_section = item.section
|
previous_section = item.section
|
||||||
@ -173,26 +177,43 @@ class UiSettings:
|
|||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
|
||||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
|
||||||
|
with gr.Row():
|
||||||
|
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
||||||
|
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
||||||
|
|
||||||
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
||||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||||
|
|
||||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||||
|
self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False)
|
||||||
|
self.show_one_page.click(lambda: None)
|
||||||
|
|
||||||
|
self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False)
|
||||||
|
|
||||||
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||||
|
|
||||||
|
def call_func_and_return_text(func, text):
|
||||||
|
def handler():
|
||||||
|
t = timer.Timer()
|
||||||
|
func()
|
||||||
|
t.record(text)
|
||||||
|
|
||||||
|
return f'{text} in {t.total:.1f}s'
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
unload_sd_model.click(
|
unload_sd_model.click(
|
||||||
fn=sd_models.unload_model_weights,
|
fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
reload_sd_model.click(
|
reload_sd_model.click(
|
||||||
fn=sd_models.reload_model_weights,
|
fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[]
|
outputs=[self.result]
|
||||||
)
|
)
|
||||||
|
|
||||||
request_notifications.click(
|
request_notifications.click(
|
||||||
@ -241,6 +262,21 @@ class UiSettings:
|
|||||||
outputs=[sysinfo_check_output],
|
outputs=[sysinfo_check_output],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def calculate_all_checkpoint_hash_fn(max_thread):
|
||||||
|
checkpoints_list = sd_models.checkpoints_list.values()
|
||||||
|
with ThreadPoolExecutor(max_workers=max_thread) as executor:
|
||||||
|
futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
|
||||||
|
completed = 0
|
||||||
|
for _ in as_completed(futures):
|
||||||
|
completed += 1
|
||||||
|
print(f"{completed} / {len(checkpoints_list)} ")
|
||||||
|
print("Finish calculating hash for all checkpoints")
|
||||||
|
|
||||||
|
calculate_all_checkpoint_hash.click(
|
||||||
|
fn=calculate_all_checkpoint_hash_fn,
|
||||||
|
inputs=[calculate_all_checkpoint_hash_threads],
|
||||||
|
)
|
||||||
|
|
||||||
self.interface = settings_interface
|
self.interface = settings_interface
|
||||||
|
|
||||||
def add_quicksettings(self):
|
def add_quicksettings(self):
|
||||||
@ -294,3 +330,8 @@ class UiSettings:
|
|||||||
outputs=[self.component_dict[k] for k in component_keys],
|
outputs=[self.component_dict[k] for k in component_keys],
|
||||||
queue=False,
|
queue=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def search(self, text):
|
||||||
|
print(text)
|
||||||
|
|
||||||
|
return [gr.update(visible=text in (comp.label or "")) for comp in self.components]
|
||||||
|
143
modules/ui_toprow.py
Normal file
143
modules/ui_toprow.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import shared, ui_prompt_styles
|
||||||
|
import modules.images
|
||||||
|
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
|
class Toprow:
|
||||||
|
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
||||||
|
|
||||||
|
prompt = None
|
||||||
|
prompt_img = None
|
||||||
|
negative_prompt = None
|
||||||
|
|
||||||
|
button_interrogate = None
|
||||||
|
button_deepbooru = None
|
||||||
|
|
||||||
|
interrupt = None
|
||||||
|
skip = None
|
||||||
|
submit = None
|
||||||
|
|
||||||
|
paste = None
|
||||||
|
clear_prompt_button = None
|
||||||
|
apply_styles = None
|
||||||
|
restore_progress_button = None
|
||||||
|
|
||||||
|
token_counter = None
|
||||||
|
token_button = None
|
||||||
|
negative_token_counter = None
|
||||||
|
negative_token_button = None
|
||||||
|
|
||||||
|
ui_styles = None
|
||||||
|
|
||||||
|
submit_box = None
|
||||||
|
|
||||||
|
def __init__(self, is_img2img, is_compact=False, id_part=None):
|
||||||
|
if id_part is None:
|
||||||
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
|
self.id_part = id_part
|
||||||
|
self.is_img2img = is_img2img
|
||||||
|
self.is_compact = is_compact
|
||||||
|
|
||||||
|
if not is_compact:
|
||||||
|
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||||
|
self.create_classic_toprow()
|
||||||
|
else:
|
||||||
|
self.create_submit_box()
|
||||||
|
|
||||||
|
def create_classic_toprow(self):
|
||||||
|
self.create_prompts()
|
||||||
|
|
||||||
|
with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
|
||||||
|
self.create_submit_box()
|
||||||
|
|
||||||
|
self.create_tools_row()
|
||||||
|
|
||||||
|
self.create_styles_ui()
|
||||||
|
|
||||||
|
def create_inline_toprow_prompts(self):
|
||||||
|
if not self.is_compact:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.create_prompts()
|
||||||
|
|
||||||
|
with gr.Row(elem_classes=["toprow-compact-stylerow"]):
|
||||||
|
with gr.Column(elem_classes=["toprow-compact-tools"]):
|
||||||
|
self.create_tools_row()
|
||||||
|
with gr.Column():
|
||||||
|
self.create_styles_ui()
|
||||||
|
|
||||||
|
def create_inline_toprow_image(self):
|
||||||
|
if not self.is_compact:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.submit_box.render()
|
||||||
|
|
||||||
|
def create_prompts(self):
|
||||||
|
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
|
||||||
|
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
|
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
|
||||||
|
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||||
|
|
||||||
|
self.prompt_img.change(
|
||||||
|
fn=modules.images.image_data,
|
||||||
|
inputs=[self.prompt_img],
|
||||||
|
outputs=[self.prompt, self.prompt_img],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_submit_box(self):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
|
||||||
|
self.submit_box = submit_box
|
||||||
|
|
||||||
|
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||||
|
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
|
||||||
|
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
|
||||||
|
|
||||||
|
self.skip.click(
|
||||||
|
fn=lambda: shared.state.skip(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.interrupt.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tools_row(self):
|
||||||
|
with gr.Row(elem_id=f"{self.id_part}_tools"):
|
||||||
|
from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
|
||||||
|
|
||||||
|
self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
|
||||||
|
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
|
||||||
|
self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
|
||||||
|
|
||||||
|
if self.is_img2img:
|
||||||
|
self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
|
||||||
|
self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
|
||||||
|
|
||||||
|
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
|
||||||
|
|
||||||
|
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
|
||||||
|
self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
|
||||||
|
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||||
|
self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
|
||||||
|
|
||||||
|
self.clear_prompt_button.click(
|
||||||
|
fn=lambda *x: x,
|
||||||
|
_js="confirm_clear_prompt",
|
||||||
|
inputs=[self.prompt, self.negative_prompt],
|
||||||
|
outputs=[self.prompt, self.negative_prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_styles_ui(self):
|
||||||
|
self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
|
||||||
|
self.ui_styles.setup_apply_button(self.apply_styles)
|
@ -57,6 +57,9 @@ class Upscaler:
|
|||||||
dest_h = int((img.height * scale) // 8 * 8)
|
dest_h = int((img.height * scale) // 8 * 8)
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
|
if img.width >= dest_w and img.height >= dest_h:
|
||||||
|
break
|
||||||
|
|
||||||
shape = (img.width, img.height)
|
shape = (img.width, img.height)
|
||||||
|
|
||||||
img = self.do_upscale(img, selected_model)
|
img = self.do_upscale(img, selected_model)
|
||||||
@ -64,9 +67,6 @@ class Upscaler:
|
|||||||
if shape == (img.width, img.height):
|
if shape == (img.width, img.height):
|
||||||
break
|
break
|
||||||
|
|
||||||
if img.width >= dest_w and img.height >= dest_h:
|
|
||||||
break
|
|
||||||
|
|
||||||
if img.width != dest_w or img.height != dest_h:
|
if img.width != dest_w or img.height != dest_h:
|
||||||
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
|
||||||
|
|
||||||
|
164
modules/xlmr_m18.py
Normal file
164
modules/xlmr_m18.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
from transformers import BertPreTrainedModel,BertConfig
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||||
|
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class BertSeriesConfig(BertConfig):
|
||||||
|
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||||
|
|
||||||
|
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||||
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
self.project_dim = project_dim
|
||||||
|
self.pooler_fn = pooler_fn
|
||||||
|
self.learn_encoder = learn_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||||
|
|
||||||
|
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||||
|
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||||
|
config_class = BertSeriesConfig
|
||||||
|
|
||||||
|
def __init__(self, config=None, **kargs):
|
||||||
|
# modify initialization for autoloading
|
||||||
|
if config is None:
|
||||||
|
config = XLMRobertaConfig()
|
||||||
|
config.attention_probs_dropout_prob= 0.1
|
||||||
|
config.bos_token_id=0
|
||||||
|
config.eos_token_id=2
|
||||||
|
config.hidden_act='gelu'
|
||||||
|
config.hidden_dropout_prob=0.1
|
||||||
|
config.hidden_size=1024
|
||||||
|
config.initializer_range=0.02
|
||||||
|
config.intermediate_size=4096
|
||||||
|
config.layer_norm_eps=1e-05
|
||||||
|
config.max_position_embeddings=514
|
||||||
|
|
||||||
|
config.num_attention_heads=16
|
||||||
|
config.num_hidden_layers=24
|
||||||
|
config.output_past=True
|
||||||
|
config.pad_token_id=1
|
||||||
|
config.position_embedding_type= "absolute"
|
||||||
|
|
||||||
|
config.type_vocab_size= 1
|
||||||
|
config.use_cache=True
|
||||||
|
config.vocab_size= 250002
|
||||||
|
config.project_dim = 1024
|
||||||
|
config.learn_encoder = False
|
||||||
|
super().__init__(config)
|
||||||
|
self.roberta = XLMRobertaModel(config)
|
||||||
|
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||||
|
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||||
|
# self.pooler = lambda x: x[:,0]
|
||||||
|
# self.post_init()
|
||||||
|
|
||||||
|
self.has_pre_transformation = True
|
||||||
|
if self.has_pre_transformation:
|
||||||
|
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||||
|
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def encode(self,c):
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
text = self.tokenizer(c,
|
||||||
|
truncation=True,
|
||||||
|
max_length=77,
|
||||||
|
return_length=False,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt")
|
||||||
|
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||||
|
text["attention_mask"] = torch.tensor(
|
||||||
|
text['attention_mask']).to(device)
|
||||||
|
features = self(**text)
|
||||||
|
return features['projection_state']
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) :
|
||||||
|
r"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
|
||||||
|
outputs = self.roberta(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=True,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# # last module outputs
|
||||||
|
# sequence_output = outputs[0]
|
||||||
|
|
||||||
|
|
||||||
|
# # project every module
|
||||||
|
# sequence_output_ln = self.pre_LN(sequence_output)
|
||||||
|
|
||||||
|
# # pooler
|
||||||
|
# pooler_output = self.pooler(sequence_output_ln)
|
||||||
|
# pooler_output = self.transformation(pooler_output)
|
||||||
|
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
|
||||||
|
if self.has_pre_transformation:
|
||||||
|
sequence_output2 = outputs["hidden_states"][-2]
|
||||||
|
sequence_output2 = self.pre_LN(sequence_output2)
|
||||||
|
projection_state2 = self.transformation_pre(sequence_output2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"projection_state": projection_state2,
|
||||||
|
"last_hidden_state": outputs.last_hidden_state,
|
||||||
|
"hidden_states": outputs.hidden_states,
|
||||||
|
"attentions": outputs.attentions,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
projection_state = self.transformation(outputs.last_hidden_state)
|
||||||
|
return {
|
||||||
|
"projection_state": projection_state,
|
||||||
|
"last_hidden_state": outputs.last_hidden_state,
|
||||||
|
"hidden_states": outputs.hidden_states,
|
||||||
|
"attentions": outputs.attentions,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# return {
|
||||||
|
# 'pooler_output':pooler_output,
|
||||||
|
# 'last_hidden_state':outputs.last_hidden_state,
|
||||||
|
# 'hidden_states':outputs.hidden_states,
|
||||||
|
# 'attentions':outputs.attentions,
|
||||||
|
# 'projection_state':projection_state,
|
||||||
|
# 'sequence_out': sequence_output
|
||||||
|
# }
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||||
|
base_model_prefix = 'roberta'
|
||||||
|
config_class= RobertaSeriesConfig
|
59
modules/xpu_specific.py
Normal file
59
modules/xpu_specific.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from modules import shared
|
||||||
|
from modules.sd_hijack_utils import CondFunc
|
||||||
|
|
||||||
|
has_ipex = False
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||||
|
has_ipex = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def check_for_xpu():
|
||||||
|
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()
|
||||||
|
|
||||||
|
|
||||||
|
def get_xpu_device_string():
|
||||||
|
if shared.cmd_opts.device_id is not None:
|
||||||
|
return f"xpu:{shared.cmd_opts.device_id}"
|
||||||
|
return "xpu"
|
||||||
|
|
||||||
|
|
||||||
|
def torch_xpu_gc():
|
||||||
|
with torch.xpu.device(get_xpu_device_string()):
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
has_xpu = check_for_xpu()
|
||||||
|
|
||||||
|
if has_xpu:
|
||||||
|
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
|
||||||
|
CondFunc('torch.Generator',
|
||||||
|
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||||
|
lambda orig_func, device=None: device is not None and device.type == "xpu")
|
||||||
|
|
||||||
|
# W/A for some OPs that could not handle different input dtypes
|
||||||
|
CondFunc('torch.nn.functional.layer_norm',
|
||||||
|
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||||
|
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||||
|
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||||
|
weight is not None and input.dtype != weight.data.dtype)
|
||||||
|
CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||||
|
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||||
|
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||||
|
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||||
|
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||||
|
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||||
|
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||||
|
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||||
|
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||||
|
CondFunc('torch.bmm',
|
||||||
|
lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),
|
||||||
|
lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)
|
||||||
|
CondFunc('torch.cat',
|
||||||
|
lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),
|
||||||
|
lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))
|
||||||
|
CondFunc('torch.nn.functional.scaled_dot_product_attention',
|
||||||
|
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: orig_func(query, key.to(query.dtype), value.to(query.dtype), attn_mask, dropout_p, is_causal),
|
||||||
|
lambda orig_func, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False: query.dtype != key.dtype or query.dtype != value.dtype)
|
@ -16,6 +16,7 @@ exclude = [
|
|||||||
|
|
||||||
ignore = [
|
ignore = [
|
||||||
"E501", # Line too long
|
"E501", # Line too long
|
||||||
|
"E721", # Do not compare types, use `isinstance`
|
||||||
"E731", # Do not assign a `lambda` expression, use a `def`
|
"E731", # Do not assign a `lambda` expression, use a `def`
|
||||||
|
|
||||||
"I001", # Import block is un-sorted or un-formatted
|
"I001", # Import block is un-sorted or un-formatted
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user