mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-20 05:10:15 +08:00
Merge branch 'dev' into gradio4
This commit is contained in:
commit
25f636cb3a
@ -78,6 +78,8 @@ module.exports = {
|
|||||||
//extraNetworks.js
|
//extraNetworks.js
|
||||||
requestGet: "readonly",
|
requestGet: "readonly",
|
||||||
popup: "readonly",
|
popup: "readonly",
|
||||||
|
// profilerVisualization.js
|
||||||
|
createVisualizationTable: "readonly",
|
||||||
// from python
|
// from python
|
||||||
localization: "readonly",
|
localization: "readonly",
|
||||||
// progrssbar.js
|
// progrssbar.js
|
||||||
|
10
.github/workflows/on_pull_request.yaml
vendored
10
.github/workflows/on_pull_request.yaml
vendored
@ -11,8 +11,8 @@ jobs:
|
|||||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.11
|
python-version: 3.11
|
||||||
# NB: there's no cache: pip here since we're not installing anything
|
# NB: there's no cache: pip here since we're not installing anything
|
||||||
@ -20,7 +20,7 @@ jobs:
|
|||||||
# not to have GHA download an (at the time of writing) 4 GB cache
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
# of PyTorch and other dependencies.
|
# of PyTorch and other dependencies.
|
||||||
- name: Install Ruff
|
- name: Install Ruff
|
||||||
run: pip install ruff==0.1.6
|
run: pip install ruff==0.3.3
|
||||||
- name: Run Ruff
|
- name: Run Ruff
|
||||||
run: ruff .
|
run: ruff .
|
||||||
lint-js:
|
lint-js:
|
||||||
@ -29,9 +29,9 @@ jobs:
|
|||||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Install Node.js
|
- name: Install Node.js
|
||||||
uses: actions/setup-node@v3
|
uses: actions/setup-node@v4
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
node-version: 18
|
||||||
- run: npm i --ci
|
- run: npm i --ci
|
||||||
|
10
.github/workflows/run_tests.yaml
vendored
10
.github/workflows/run_tests.yaml
vendored
@ -11,9 +11,9 @@ jobs:
|
|||||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.10
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: 3.10.6
|
python-version: 3.10.6
|
||||||
cache: pip
|
cache: pip
|
||||||
@ -22,7 +22,7 @@ jobs:
|
|||||||
launch.py
|
launch.py
|
||||||
- name: Cache models
|
- name: Cache models
|
||||||
id: cache-models
|
id: cache-models
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: models
|
path: models
|
||||||
key: "2023-12-30"
|
key: "2023-12-30"
|
||||||
@ -68,13 +68,13 @@ jobs:
|
|||||||
python -m coverage report -i
|
python -m coverage report -i
|
||||||
python -m coverage html -i
|
python -m coverage html -i
|
||||||
- name: Upload main app output
|
- name: Upload main app output
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v4
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
name: output
|
name: output
|
||||||
path: output.txt
|
path: output.txt
|
||||||
- name: Upload coverage HTML
|
- name: Upload coverage HTML
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v4
|
||||||
if: always()
|
if: always()
|
||||||
with:
|
with:
|
||||||
name: htmlcov
|
name: htmlcov
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -38,3 +38,4 @@ notification.mp3
|
|||||||
/package-lock.json
|
/package-lock.json
|
||||||
/.coverage*
|
/.coverage*
|
||||||
/test/test_outputs
|
/test/test_outputs
|
||||||
|
/cache
|
||||||
|
18
CHANGELOG.md
18
CHANGELOG.md
@ -14,7 +14,7 @@
|
|||||||
* Add support for DAT upscaler models ([#14690](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14690), [#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039))
|
* Add support for DAT upscaler models ([#14690](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14690), [#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039))
|
||||||
* Extra Networks Tree View ([#14588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14588), [#14900](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14900))
|
* Extra Networks Tree View ([#14588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14588), [#14900](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14900))
|
||||||
* NPU Support ([#14801](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14801))
|
* NPU Support ([#14801](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14801))
|
||||||
* Propmpt comments support
|
* Prompt comments support
|
||||||
|
|
||||||
### Minor:
|
### Minor:
|
||||||
* Allow pasting in WIDTHxHEIGHT strings into the width/height fields ([#14296](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14296))
|
* Allow pasting in WIDTHxHEIGHT strings into the width/height fields ([#14296](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14296))
|
||||||
@ -59,7 +59,7 @@
|
|||||||
* modules/api/api.py: add api endpoint to refresh embeddings list ([#14715](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14715))
|
* modules/api/api.py: add api endpoint to refresh embeddings list ([#14715](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14715))
|
||||||
* set_named_arg ([#14773](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14773))
|
* set_named_arg ([#14773](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14773))
|
||||||
* add before_token_counter callback and use it for prompt comments
|
* add before_token_counter callback and use it for prompt comments
|
||||||
* ResizeHandleRow - allow overriden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004))
|
* ResizeHandleRow - allow overridden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004))
|
||||||
|
|
||||||
### Performance
|
### Performance
|
||||||
* Massive performance improvement for extra networks directories with a huge number of files in them in an attempt to tackle #14507 ([#14528](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14528))
|
* Massive performance improvement for extra networks directories with a huge number of files in them in an attempt to tackle #14507 ([#14528](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14528))
|
||||||
@ -101,7 +101,7 @@
|
|||||||
* Gracefully handle mtime read exception from cache ([#14933](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14933))
|
* Gracefully handle mtime read exception from cache ([#14933](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14933))
|
||||||
* Only trigger interrupt on `Esc` when interrupt button visible ([#14932](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14932))
|
* Only trigger interrupt on `Esc` when interrupt button visible ([#14932](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14932))
|
||||||
* Disable prompt token counters option actually disables token counting rather than just hiding results.
|
* Disable prompt token counters option actually disables token counting rather than just hiding results.
|
||||||
* avoid doble upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966))
|
* avoid double upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966))
|
||||||
* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995))
|
* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995))
|
||||||
* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006))
|
* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006))
|
||||||
* Fix resize-handle for mobile ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010), [#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065))
|
* Fix resize-handle for mobile ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010), [#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065))
|
||||||
@ -171,7 +171,7 @@
|
|||||||
* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
|
* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page
|
||||||
* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
|
* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))
|
||||||
* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
|
* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))
|
||||||
* allow use of mutiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
|
* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))
|
||||||
* make extra network card description plaintext by default, with an option (Treat card description as HTML) to re-enable HTML as it was (originally by [#13241](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13241))
|
* make extra network card description plaintext by default, with an option (Treat card description as HTML) to re-enable HTML as it was (originally by [#13241](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13241))
|
||||||
|
|
||||||
### Extensions and API:
|
### Extensions and API:
|
||||||
@ -308,7 +308,7 @@
|
|||||||
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
* new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))
|
||||||
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
* rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:
|
||||||
* makes all of them work with img2img
|
* makes all of them work with img2img
|
||||||
* makes prompt composition posssible (AND)
|
* makes prompt composition possible (AND)
|
||||||
* makes them available for SDXL
|
* makes them available for SDXL
|
||||||
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
* always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))
|
||||||
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
* use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))
|
||||||
@ -484,7 +484,7 @@
|
|||||||
* user metadata system for custom networks
|
* user metadata system for custom networks
|
||||||
* extended Lora metadata editor: set activation text, default weight, view tags, training info
|
* extended Lora metadata editor: set activation text, default weight, view tags, training info
|
||||||
* Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
|
* Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
|
||||||
* show github stars for extenstions
|
* show github stars for extensions
|
||||||
* img2img batch mode can read extra stuff from png info
|
* img2img batch mode can read extra stuff from png info
|
||||||
* img2img batch works with subdirectories
|
* img2img batch works with subdirectories
|
||||||
* hotkeys to move prompt elements: alt+left/right
|
* hotkeys to move prompt elements: alt+left/right
|
||||||
@ -703,7 +703,7 @@
|
|||||||
* do not wait for Stable Diffusion model to load at startup
|
* do not wait for Stable Diffusion model to load at startup
|
||||||
* add filename patterns: `[denoising]`
|
* add filename patterns: `[denoising]`
|
||||||
* directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for
|
* directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for
|
||||||
* LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metdata of the file, if present, instead of filename (both can be used to activate LoRA)
|
* LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metadata of the file, if present, instead of filename (both can be used to activate LoRA)
|
||||||
* LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
|
* LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
|
||||||
* LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer)
|
* LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer)
|
||||||
* LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss)
|
* LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss)
|
||||||
@ -733,7 +733,7 @@
|
|||||||
* fix gamepad navigation
|
* fix gamepad navigation
|
||||||
* make the lightbox fullscreen image function properly
|
* make the lightbox fullscreen image function properly
|
||||||
* fix squished thumbnails in extras tab
|
* fix squished thumbnails in extras tab
|
||||||
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
|
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed)
|
||||||
* fix webui showing the same image if you configure the generation to always save results into same file
|
* fix webui showing the same image if you configure the generation to always save results into same file
|
||||||
* fix bug with upscalers not working properly
|
* fix bug with upscalers not working properly
|
||||||
* fix MPS on PyTorch 2.0.1, Intel Macs
|
* fix MPS on PyTorch 2.0.1, Intel Macs
|
||||||
@ -751,7 +751,7 @@
|
|||||||
* switch to PyTorch 2.0.0 (except for AMD GPUs)
|
* switch to PyTorch 2.0.0 (except for AMD GPUs)
|
||||||
* visual improvements to custom code scripts
|
* visual improvements to custom code scripts
|
||||||
* add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]`
|
* add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]`
|
||||||
* add support for saving init images in img2img, and record their hashes in infotext for reproducability
|
* add support for saving init images in img2img, and record their hashes in infotext for reproducibility
|
||||||
* automatically select current word when adjusting weight with ctrl+up/down
|
* automatically select current word when adjusting weight with ctrl+up/down
|
||||||
* add dropdowns for X/Y/Z plot
|
* add dropdowns for X/Y/Z plot
|
||||||
* add setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
|
* add setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
|
||||||
|
@ -98,6 +98,7 @@ Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-di
|
|||||||
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
|
- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)
|
||||||
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
|
- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)
|
||||||
|
- [Ascend NPUs](https://github.com/wangshuai09/stable-diffusion-webui/wiki/Install-and-run-on-Ascend-NPUs) (external wiki page)
|
||||||
|
|
||||||
Alternatively, use online services (like Google Colab):
|
Alternatively, use online services (like Google Colab):
|
||||||
|
|
||||||
|
5
_typos.toml
Normal file
5
_typos.toml
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
[default.extend-words]
|
||||||
|
# Part of "RGBa" (Pillow's pre-multiplied alpha RGB mode)
|
||||||
|
Ba = "Ba"
|
||||||
|
# HSA is something AMD uses for their GPUs
|
||||||
|
HSA = "HSA"
|
@ -301,7 +301,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
target = x_start
|
target = x_start
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
||||||
|
|
||||||
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
||||||
|
|
||||||
@ -880,7 +880,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
# hybrid case, cond is exptected to be a dict
|
# hybrid case, cond is expected to be a dict
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if not isinstance(cond, list):
|
if not isinstance(cond, list):
|
||||||
@ -916,7 +916,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
||||||
|
|
||||||
elif self.cond_stage_key == 'coordinates_bbox':
|
elif self.cond_stage_key == 'coordinates_bbox':
|
||||||
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
|
||||||
|
|
||||||
# assuming padding of unfold is always 0 and its dilation is always 1
|
# assuming padding of unfold is always 0 and its dilation is always 1
|
||||||
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
||||||
@ -926,7 +926,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||||
rescale_latent = 2 ** (num_downs)
|
rescale_latent = 2 ** (num_downs)
|
||||||
|
|
||||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||||
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
||||||
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
||||||
|
@ -30,7 +30,7 @@ def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
|
|||||||
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
In LoRA with Kroneckor Product, first value is a value for weight scale.
|
||||||
secon value is a value for weight.
|
secon value is a value for weight.
|
||||||
|
|
||||||
Becuase of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
Because of non-commutative property, A⊗B ≠ B⊗A. Meaning of two matrices is slightly different.
|
||||||
|
|
||||||
examples)
|
examples)
|
||||||
factor
|
factor
|
||||||
|
@ -29,7 +29,6 @@ class NetworkOnDisk:
|
|||||||
|
|
||||||
def read_metadata():
|
def read_metadata():
|
||||||
metadata = sd_models.read_metadata_from_safetensors(filename)
|
metadata = sd_models.read_metadata_from_safetensors(filename)
|
||||||
metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text
|
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
@ -117,6 +116,12 @@ class NetworkModule:
|
|||||||
|
|
||||||
if hasattr(self.sd_module, 'weight'):
|
if hasattr(self.sd_module, 'weight'):
|
||||||
self.shape = self.sd_module.weight.shape
|
self.shape = self.sd_module.weight.shape
|
||||||
|
elif isinstance(self.sd_module, nn.MultiheadAttention):
|
||||||
|
# For now, only self-attn use Pytorch's MHA
|
||||||
|
# So assume all qkvo proj have same shape
|
||||||
|
self.shape = self.sd_module.out_proj.weight.shape
|
||||||
|
else:
|
||||||
|
self.shape = None
|
||||||
|
|
||||||
self.ops = None
|
self.ops = None
|
||||||
self.extra_kwargs = {}
|
self.extra_kwargs = {}
|
||||||
@ -146,6 +151,9 @@ class NetworkModule:
|
|||||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||||
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
self.scale = weights.w["scale"].item() if "scale" in weights.w else None
|
||||||
|
|
||||||
|
self.dora_scale = weights.w.get("dora_scale", None)
|
||||||
|
self.dora_norm_dims = len(self.shape) - 1
|
||||||
|
|
||||||
def multiplier(self):
|
def multiplier(self):
|
||||||
if 'transformer' in self.sd_key[:20]:
|
if 'transformer' in self.sd_key[:20]:
|
||||||
return self.network.te_multiplier
|
return self.network.te_multiplier
|
||||||
@ -160,6 +168,27 @@ class NetworkModule:
|
|||||||
|
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
|
def apply_weight_decompose(self, updown, orig_weight):
|
||||||
|
# Match the device/dtype
|
||||||
|
orig_weight = orig_weight.to(updown.dtype)
|
||||||
|
dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
|
||||||
|
updown = updown.to(orig_weight.device)
|
||||||
|
|
||||||
|
merged_scale1 = updown + orig_weight
|
||||||
|
merged_scale1_norm = (
|
||||||
|
merged_scale1.transpose(0, 1)
|
||||||
|
.reshape(merged_scale1.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
dora_merged = (
|
||||||
|
merged_scale1 * (dora_scale / merged_scale1_norm)
|
||||||
|
)
|
||||||
|
final_updown = dora_merged - orig_weight
|
||||||
|
return final_updown
|
||||||
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
updown = updown.reshape(self.bias.shape)
|
updown = updown.reshape(self.bias.shape)
|
||||||
@ -175,6 +204,9 @@ class NetworkModule:
|
|||||||
if ex_bias is not None:
|
if ex_bias is not None:
|
||||||
ex_bias = ex_bias * self.multiplier()
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
if self.dora_scale is not None:
|
||||||
|
updown = self.apply_weight_decompose(updown, orig_weight)
|
||||||
|
|
||||||
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
||||||
|
|
||||||
def calc_updown(self, target):
|
def calc_updown(self, target):
|
||||||
|
@ -36,13 +36,6 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
# self.alpha is unused
|
# self.alpha is unused
|
||||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||||
|
|
||||||
# LyCORIS BOFT
|
|
||||||
if self.oft_blocks.dim() == 4:
|
|
||||||
self.is_boft = True
|
|
||||||
self.rescale = weights.w.get('rescale', None)
|
|
||||||
if self.rescale is not None:
|
|
||||||
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
|
||||||
|
|
||||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||||
@ -54,6 +47,13 @@ class NetworkModuleOFT(network.NetworkModule):
|
|||||||
elif is_other_linear:
|
elif is_other_linear:
|
||||||
self.out_dim = self.sd_module.embed_dim
|
self.out_dim = self.sd_module.embed_dim
|
||||||
|
|
||||||
|
# LyCORIS BOFT
|
||||||
|
if self.oft_blocks.dim() == 4:
|
||||||
|
self.is_boft = True
|
||||||
|
self.rescale = weights.w.get('rescale', None)
|
||||||
|
if self.rescale is not None and not is_other_linear:
|
||||||
|
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
||||||
|
|
||||||
self.num_blocks = self.dim
|
self.num_blocks = self.dim
|
||||||
self.block_size = self.out_dim // self.dim
|
self.block_size = self.out_dim // self.dim
|
||||||
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
|
||||||
|
@ -355,7 +355,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
"""
|
"""
|
||||||
Applies the currently selected set of networks to the weights of torch layer self.
|
Applies the currently selected set of networks to the weights of torch layer self.
|
||||||
If weights already have this particular set of networks applied, does nothing.
|
If weights already have this particular set of networks applied, does nothing.
|
||||||
If not, restores orginal weights from backup and alters weights according to networks.
|
If not, restores original weights from backup and alters weights according to networks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
network_layer_name = getattr(self, 'network_layer_name', None)
|
network_layer_name = getattr(self, 'network_layer_name', None)
|
||||||
@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
|
# Send "real" orig_weight into MHA's lora module
|
||||||
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
|
qw, kw, vw = self.in_proj_weight.chunk(3, 0)
|
||||||
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
|
updown_q, _ = module_q.calc_updown(qw)
|
||||||
|
updown_k, _ = module_k.calc_updown(kw)
|
||||||
|
updown_v, _ = module_v.calc_updown(vw)
|
||||||
|
del qw, kw, vw
|
||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
||||||
|
|
||||||
|
@ -149,6 +149,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
|
|||||||
|
|
||||||
v = random.random() * max_count
|
v = random.random() * max_count
|
||||||
if count > v:
|
if count > v:
|
||||||
|
for x in "({[]})":
|
||||||
|
tag = tag.replace(x, '\\' + x)
|
||||||
res.append(tag)
|
res.append(tag)
|
||||||
|
|
||||||
return ", ".join(sorted(res))
|
return ", ".join(sorted(res))
|
||||||
|
@ -31,7 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
"name": name,
|
"name": name,
|
||||||
"filename": lora_on_disk.filename,
|
"filename": lora_on_disk.filename,
|
||||||
"shorthash": lora_on_disk.shorthash,
|
"shorthash": lora_on_disk.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_terms": search_terms,
|
"search_terms": search_terms,
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
|
@ -43,6 +43,7 @@ onUiLoaded(async() => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
function getActiveTab(elements, all = false) {
|
function getActiveTab(elements, all = false) {
|
||||||
|
if (!elements.img2imgTabs) return null;
|
||||||
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
const tabs = elements.img2imgTabs.querySelectorAll("button");
|
||||||
|
|
||||||
if (all) return tabs;
|
if (all) return tabs;
|
||||||
@ -57,6 +58,7 @@ onUiLoaded(async() => {
|
|||||||
// Get tab ID
|
// Get tab ID
|
||||||
function getTabId(elements) {
|
function getTabId(elements) {
|
||||||
const activeTab = getActiveTab(elements);
|
const activeTab = getActiveTab(elements);
|
||||||
|
if (!activeTab) return null;
|
||||||
return tabNameToElementId[activeTab.innerText];
|
return tabNameToElementId[activeTab.innerText];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,6 +249,7 @@ onUiLoaded(async() => {
|
|||||||
let isMoving = false;
|
let isMoving = false;
|
||||||
let mouseX, mouseY;
|
let mouseX, mouseY;
|
||||||
let activeElement;
|
let activeElement;
|
||||||
|
let interactedWithAltKey = false;
|
||||||
|
|
||||||
const elements = Object.fromEntries(
|
const elements = Object.fromEntries(
|
||||||
Object.keys(elementIDs).map(id => [
|
Object.keys(elementIDs).map(id => [
|
||||||
@ -260,7 +263,7 @@ onUiLoaded(async() => {
|
|||||||
const targetElement = gradioApp().querySelector(elemId);
|
const targetElement = gradioApp().querySelector(elemId);
|
||||||
|
|
||||||
if (!targetElement) {
|
if (!targetElement) {
|
||||||
console.log("Element not found");
|
console.log("Element not found", elemId);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,7 +383,8 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
// Create tooltip
|
// Create tooltip
|
||||||
function createTooltip() {
|
function createTooltip() {
|
||||||
const toolTipElemnt = targetElement.querySelector(".image-container");
|
const toolTipElement =
|
||||||
|
targetElement.querySelector(".image-container");
|
||||||
const tooltip = document.createElement("div");
|
const tooltip = document.createElement("div");
|
||||||
tooltip.className = "canvas-tooltip";
|
tooltip.className = "canvas-tooltip";
|
||||||
|
|
||||||
@ -442,16 +446,26 @@ onUiLoaded(async() => {
|
|||||||
tooltip.appendChild(tooltipContent);
|
tooltip.appendChild(tooltipContent);
|
||||||
|
|
||||||
// Add a hint element to the target element
|
// Add a hint element to the target element
|
||||||
toolTipElemnt.appendChild(tooltip);
|
toolTipElement.appendChild(tooltip);
|
||||||
|
|
||||||
return tooltip;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//Show tool tip if setting enable
|
//Show tool tip if setting enable
|
||||||
const canvasTooltip = createTooltip();
|
if (hotkeysConfig.canvas_show_tooltip) {
|
||||||
|
createTooltip();
|
||||||
|
}
|
||||||
|
|
||||||
if (!hotkeysConfig.canvas_show_tooltip) {
|
// In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui.
|
||||||
canvasTooltip.style.display = "none";
|
function fixCanvas() {
|
||||||
|
const activeTab = getActiveTab(elements)?.textContent.trim();
|
||||||
|
|
||||||
|
if (activeTab && activeTab !== "img2img") {
|
||||||
|
const img = targetElement.querySelector(`${elemId} img`);
|
||||||
|
|
||||||
|
if (img && img.style.display !== "none") {
|
||||||
|
img.style.display = "none";
|
||||||
|
img.style.visibility = "hidden";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset the zoom level and pan position of the target element to their initial values
|
// Reset the zoom level and pan position of the target element to their initial values
|
||||||
@ -570,6 +584,10 @@ onUiLoaded(async() => {
|
|||||||
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
|
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|
||||||
|
if (hotkeysConfig.canvas_hotkey_zoom === "Alt") {
|
||||||
|
interactedWithAltKey = true;
|
||||||
|
}
|
||||||
|
|
||||||
let zoomPosX, zoomPosY;
|
let zoomPosX, zoomPosY;
|
||||||
let delta = 0.2;
|
let delta = 0.2;
|
||||||
if (elemData[elemId].zoomLevel > 7) {
|
if (elemData[elemId].zoomLevel > 7) {
|
||||||
@ -767,17 +785,29 @@ onUiLoaded(async() => {
|
|||||||
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
targetElement.addEventListener("mouseleave", handleMouseLeave);
|
||||||
|
|
||||||
// Reset zoom when click on another tab
|
// Reset zoom when click on another tab
|
||||||
elements.img2imgTabs.addEventListener("click", resetZoom);
|
if (elements.img2imgTabs) {
|
||||||
|
elements.img2imgTabs.addEventListener("click", resetZoom);
|
||||||
|
elements.img2imgTabs.addEventListener("click", () => {
|
||||||
|
// targetElement.style.width = "";
|
||||||
|
if (parseInt(targetElement.style.width) > 865) {
|
||||||
|
setTimeout(fitToElement, 0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
targetElement.addEventListener("wheel", e => {
|
targetElement.addEventListener("wheel", e => {
|
||||||
// change zoom level
|
// change zoom level
|
||||||
const operation = e.deltaY > 0 ? "-" : "+";
|
const operation = (e.deltaY || -e.wheelDelta) > 0 ? "-" : "+";
|
||||||
changeZoomLevel(operation, e);
|
changeZoomLevel(operation, e);
|
||||||
|
|
||||||
// Handle brush size adjustment with ctrl key pressed
|
// Handle brush size adjustment with ctrl key pressed
|
||||||
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
|
if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|
||||||
|
if (hotkeysConfig.canvas_hotkey_adjust === "Alt") {
|
||||||
|
interactedWithAltKey = true;
|
||||||
|
}
|
||||||
|
|
||||||
// Increase or decrease brush size based on scroll direction
|
// Increase or decrease brush size based on scroll direction
|
||||||
adjustBrushSize(elemId, e.deltaY);
|
adjustBrushSize(elemId, e.deltaY);
|
||||||
}
|
}
|
||||||
@ -817,6 +847,20 @@ onUiLoaded(async() => {
|
|||||||
document.addEventListener("keydown", handleMoveKeyDown);
|
document.addEventListener("keydown", handleMoveKeyDown);
|
||||||
document.addEventListener("keyup", handleMoveKeyUp);
|
document.addEventListener("keyup", handleMoveKeyUp);
|
||||||
|
|
||||||
|
|
||||||
|
// Prevent firefox from opening main menu when alt is used as a hotkey for zoom or brush size
|
||||||
|
function handleAltKeyUp(e) {
|
||||||
|
if (e.key !== "Alt" || !interactedWithAltKey) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
e.preventDefault();
|
||||||
|
interactedWithAltKey = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener("keyup", handleAltKeyUp);
|
||||||
|
|
||||||
|
|
||||||
// Detect zoom level and update the pan speed.
|
// Detect zoom level and update the pan speed.
|
||||||
function updatePanPosition(movementX, movementY) {
|
function updatePanPosition(movementX, movementY) {
|
||||||
let panSpeed = 2;
|
let panSpeed = 2;
|
||||||
|
@ -8,8 +8,8 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
|
|||||||
"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
|
"canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
|
||||||
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
"canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
|
||||||
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
"canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
|
||||||
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
|
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"),
|
||||||
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
|
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"),
|
||||||
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
|
||||||
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
|
||||||
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules import scripts, shared, ui_components, ui_settings, infotext_utils
|
from modules import scripts, shared, ui_components, ui_settings, infotext_utils, errors
|
||||||
from modules.ui_components import FormColumn
|
from modules.ui_components import FormColumn
|
||||||
|
|
||||||
|
|
||||||
@ -42,7 +42,11 @@ class ExtraOptionsSection(scripts.Script):
|
|||||||
setting_name = extra_options[index]
|
setting_name = extra_options[index]
|
||||||
|
|
||||||
with FormColumn():
|
with FormColumn():
|
||||||
comp = ui_settings.create_setting_component(setting_name)
|
try:
|
||||||
|
comp = ui_settings.create_setting_component(setting_name)
|
||||||
|
except KeyError:
|
||||||
|
errors.report(f"Can't add extra options for {setting_name} in ui")
|
||||||
|
continue
|
||||||
|
|
||||||
self.comps.append(comp)
|
self.comps.append(comp)
|
||||||
self.setting_names.append(setting_name)
|
self.setting_names.append(setting_name)
|
||||||
|
@ -57,10 +57,14 @@ def latent_blend(settings, a, b, t):
|
|||||||
|
|
||||||
# NOTE: We use inplace operations wherever possible.
|
# NOTE: We use inplace operations wherever possible.
|
||||||
|
|
||||||
# [4][w][h] to [1][4][w][h]
|
if len(t.shape) == 3:
|
||||||
t2 = t.unsqueeze(0)
|
# [4][w][h] to [1][4][w][h]
|
||||||
# [4][w][h] to [1][1][w][h] - the [4] seem redundant.
|
t2 = t.unsqueeze(0)
|
||||||
t3 = t[0].unsqueeze(0).unsqueeze(0)
|
# [4][w][h] to [1][1][w][h] - the [4] seem redundant.
|
||||||
|
t3 = t[0].unsqueeze(0).unsqueeze(0)
|
||||||
|
else:
|
||||||
|
t2 = t
|
||||||
|
t3 = t[:, 0][:, None]
|
||||||
|
|
||||||
one_minus_t2 = 1 - t2
|
one_minus_t2 = 1 - t2
|
||||||
one_minus_t3 = 1 - t3
|
one_minus_t3 = 1 - t3
|
||||||
@ -104,7 +108,7 @@ def latent_blend(settings, a, b, t):
|
|||||||
|
|
||||||
def get_modified_nmask(settings, nmask, sigma):
|
def get_modified_nmask(settings, nmask, sigma):
|
||||||
"""
|
"""
|
||||||
Converts a negative mask representing the transparency of the original latent vectors being overlayed
|
Converts a negative mask representing the transparency of the original latent vectors being overlaid
|
||||||
to a mask that is scaled according to the denoising strength for this step.
|
to a mask that is scaled according to the denoising strength for this step.
|
||||||
|
|
||||||
Where:
|
Where:
|
||||||
@ -135,7 +139,10 @@ def apply_adaptive_masks(
|
|||||||
from PIL import Image, ImageOps, ImageFilter
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
|
||||||
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
|
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
|
||||||
latent_mask = nmask[0].float()
|
if len(nmask.shape) == 3:
|
||||||
|
latent_mask = nmask[0].float()
|
||||||
|
else:
|
||||||
|
latent_mask = nmask[:, 0].float()
|
||||||
# convert the original mask into a form we use to scale distances for thresholding
|
# convert the original mask into a form we use to scale distances for thresholding
|
||||||
mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
|
mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
|
||||||
mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
|
mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
|
||||||
@ -157,7 +164,14 @@ def apply_adaptive_masks(
|
|||||||
percentile_min=0.25, percentile_max=0.75, min_width=1)
|
percentile_min=0.25, percentile_max=0.75, min_width=1)
|
||||||
|
|
||||||
# The distance at which opacity of original decreases to 50%
|
# The distance at which opacity of original decreases to 50%
|
||||||
half_weighted_distance = settings.composite_difference_threshold * mask_scalar
|
if len(mask_scalar.shape) == 3:
|
||||||
|
if mask_scalar.shape[0] > i:
|
||||||
|
half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i]
|
||||||
|
else:
|
||||||
|
half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0]
|
||||||
|
else:
|
||||||
|
half_weighted_distance = settings.composite_difference_threshold * mask_scalar
|
||||||
|
|
||||||
converted_mask = converted_mask / half_weighted_distance
|
converted_mask = converted_mask / half_weighted_distance
|
||||||
|
|
||||||
converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
|
converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
<div class="copy-path-button card-button"
|
<div class="copy-path-button card-button"
|
||||||
title="Copy path to clipboard"
|
title="Copy path to clipboard"
|
||||||
onclick="extraNetworksCopyCardPath(event, '{filename}')"
|
onclick="extraNetworksCopyCardPath(event)"
|
||||||
data-clipboard-text="{filename}">
|
data-clipboard-text="{filename}">
|
||||||
</div>
|
</div>
|
@ -1,4 +1,4 @@
|
|||||||
<div class="edit-button card-button"
|
<div class="edit-button card-button"
|
||||||
title="Edit metadata"
|
title="Edit metadata"
|
||||||
onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
|
onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}')">
|
||||||
</div>
|
</div>
|
@ -1,4 +1,4 @@
|
|||||||
<div class="metadata-button card-button"
|
<div class="metadata-button card-button"
|
||||||
title="Show internal metadata"
|
title="Show internal metadata"
|
||||||
onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}', '{name}')">
|
onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}')">
|
||||||
</div>
|
</div>
|
8
html/extra-networks-pane-dirs.html
Normal file
8
html/extra-networks-pane-dirs.html
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<div class="extra-network-pane-content-dirs">
|
||||||
|
<div id='{tabname}_{extra_networks_tabname}_dirs' class='extra-network-dirs'>
|
||||||
|
{dirs_html}
|
||||||
|
</div>
|
||||||
|
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards'>
|
||||||
|
{items_html}
|
||||||
|
</div>
|
||||||
|
</div>
|
8
html/extra-networks-pane-tree.html
Normal file
8
html/extra-networks-pane-tree.html
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<div class="extra-network-pane-content-tree resize-handle-row">
|
||||||
|
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree' style='flex-basis: {extra_networks_tree_view_default_width}px'>
|
||||||
|
{tree_html}
|
||||||
|
</div>
|
||||||
|
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>
|
||||||
|
{items_html}
|
||||||
|
</div>
|
||||||
|
</div>
|
@ -1,23 +1,53 @@
|
|||||||
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane'>
|
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane {tree_view_div_default_display_class}'>
|
||||||
<div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
|
<div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
|
||||||
<div class="extra-network-control--search">
|
<div class="extra-network-control--search">
|
||||||
<input
|
<input
|
||||||
id="{tabname}_{extra_networks_tabname}_extra_search"
|
id="{tabname}_{extra_networks_tabname}_extra_search"
|
||||||
class="extra-network-control--search-text"
|
class="extra-network-control--search-text"
|
||||||
type="search"
|
type="search"
|
||||||
placeholder="Filter files"
|
placeholder="Search"
|
||||||
>
|
>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<small>Sort: </small>
|
||||||
<div
|
<div
|
||||||
id="{tabname}_{extra_networks_tabname}_extra_sort"
|
id="{tabname}_{extra_networks_tabname}_extra_sort_path"
|
||||||
class="extra-network-control--sort"
|
class="extra-network-control--sort{sort_path_active}"
|
||||||
data-sortmode="{data_sortmode}"
|
data-sortkey="default"
|
||||||
data-sortkey="{data_sortkey}"
|
|
||||||
title="Sort by path"
|
title="Sort by path"
|
||||||
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
>
|
>
|
||||||
<i class="extra-network-control--sort-icon"></i>
|
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
id="{tabname}_{extra_networks_tabname}_extra_sort_name"
|
||||||
|
class="extra-network-control--sort{sort_name_active}"
|
||||||
|
data-sortkey="name"
|
||||||
|
title="Sort by name"
|
||||||
|
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
|
>
|
||||||
|
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
id="{tabname}_{extra_networks_tabname}_extra_sort_date_created"
|
||||||
|
class="extra-network-control--sort{sort_date_created_active}"
|
||||||
|
data-sortkey="date_created"
|
||||||
|
title="Sort by date created"
|
||||||
|
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
|
>
|
||||||
|
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
id="{tabname}_{extra_networks_tabname}_extra_sort_date_modified"
|
||||||
|
class="extra-network-control--sort{sort_date_modified_active}"
|
||||||
|
data-sortkey="date_modified"
|
||||||
|
title="Sort by date modified"
|
||||||
|
onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
|
>
|
||||||
|
<i class="extra-network-control--icon extra-network-control--sort-icon"></i>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<small> </small>
|
||||||
<div
|
<div
|
||||||
id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
|
id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
|
||||||
class="extra-network-control--sort-dir"
|
class="extra-network-control--sort-dir"
|
||||||
@ -25,15 +55,18 @@
|
|||||||
title="Sort ascending"
|
title="Sort ascending"
|
||||||
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
>
|
>
|
||||||
<i class="extra-network-control--sort-dir-icon"></i>
|
<i class="extra-network-control--icon extra-network-control--sort-dir-icon"></i>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
<small> </small>
|
||||||
<div
|
<div
|
||||||
id="{tabname}_{extra_networks_tabname}_extra_tree_view"
|
id="{tabname}_{extra_networks_tabname}_extra_tree_view"
|
||||||
class="extra-network-control--tree-view {tree_view_btn_extra_class}"
|
class="extra-network-control--tree-view {tree_view_btn_extra_class}"
|
||||||
title="Enable Tree View"
|
title="Enable Tree View"
|
||||||
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
>
|
>
|
||||||
<i class="extra-network-control--tree-view-icon"></i>
|
<i class="extra-network-control--icon extra-network-control--tree-view-icon"></i>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
id="{tabname}_{extra_networks_tabname}_extra_refresh"
|
id="{tabname}_{extra_networks_tabname}_extra_refresh"
|
||||||
@ -41,15 +74,8 @@
|
|||||||
title="Refresh page"
|
title="Refresh page"
|
||||||
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
|
||||||
>
|
>
|
||||||
<i class="extra-network-control--refresh-icon"></i>
|
<i class="extra-network-control--icon extra-network-control--refresh-icon"></i>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="extra-network-pane-content resize-handle-row" style="display: {extra_network_pane_content_default_display};">
|
{pane_content}
|
||||||
<div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree {tree_view_div_extra_class}' style='flex-basis: {extra_networks_tree_view_default_width}px; display: {tree_view_div_default_display};'>
|
</div>
|
||||||
{tree_html}
|
|
||||||
</div>
|
|
||||||
<div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>
|
|
||||||
{items_html}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
@ -45,17 +45,17 @@ function dimensionChange(e, is_width, is_height) {
|
|||||||
var scaledx = targetElement.width * viewportscale;
|
var scaledx = targetElement.width * viewportscale;
|
||||||
var scaledy = targetElement.height * viewportscale;
|
var scaledy = targetElement.height * viewportscale;
|
||||||
|
|
||||||
var cleintRectTop = (viewportOffset.top + window.scrollY);
|
var clientRectTop = (viewportOffset.top + window.scrollY);
|
||||||
var cleintRectLeft = (viewportOffset.left + window.scrollX);
|
var clientRectLeft = (viewportOffset.left + window.scrollX);
|
||||||
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
|
var clientRectCentreY = clientRectTop + (targetElement.clientHeight / 2);
|
||||||
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
|
var clientRectCentreX = clientRectLeft + (targetElement.clientWidth / 2);
|
||||||
|
|
||||||
var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
|
var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
|
||||||
var arscaledx = currentWidth * arscale;
|
var arscaledx = currentWidth * arscale;
|
||||||
var arscaledy = currentHeight * arscale;
|
var arscaledy = currentHeight * arscale;
|
||||||
|
|
||||||
var arRectTop = cleintRectCentreY - (arscaledy / 2);
|
var arRectTop = clientRectCentreY - (arscaledy / 2);
|
||||||
var arRectLeft = cleintRectCentreX - (arscaledx / 2);
|
var arRectLeft = clientRectCentreX - (arscaledx / 2);
|
||||||
var arRectWidth = arscaledx;
|
var arRectWidth = arscaledx;
|
||||||
var arRectHeight = arscaledy;
|
var arRectHeight = arscaledy;
|
||||||
|
|
||||||
|
27
javascript/dragdrop.js
vendored
27
javascript/dragdrop.js
vendored
@ -74,22 +74,39 @@ window.document.addEventListener('dragover', e => {
|
|||||||
e.dataTransfer.dropEffect = 'copy';
|
e.dataTransfer.dropEffect = 'copy';
|
||||||
});
|
});
|
||||||
|
|
||||||
window.document.addEventListener('drop', e => {
|
window.document.addEventListener('drop', async e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
if (!eventHasFiles(e)) return;
|
const url = e.dataTransfer.getData('text/uri-list') || e.dataTransfer.getData('text/plain');
|
||||||
|
if (!eventHasFiles(e) && !url) return;
|
||||||
|
|
||||||
if (dragDropTargetIsPrompt(target)) {
|
if (dragDropTargetIsPrompt(target)) {
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
const isImg2img = get_tab_index('tabs') == 1;
|
||||||
|
let prompt_image_target = isImg2img ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
|
||||||
const imgParent = gradioApp().getElementById(prompt_target);
|
const imgParent = gradioApp().getElementById(prompt_image_target);
|
||||||
const files = e.dataTransfer.files;
|
const files = e.dataTransfer.files;
|
||||||
const fileInput = imgParent.querySelector('input[type="file"]');
|
const fileInput = imgParent.querySelector('input[type="file"]');
|
||||||
if (fileInput) {
|
if (eventHasFiles(e) && fileInput) {
|
||||||
fileInput.files = files;
|
fileInput.files = files;
|
||||||
fileInput.dispatchEvent(new Event('change'));
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
|
} else if (url) {
|
||||||
|
try {
|
||||||
|
const request = await fetch(url);
|
||||||
|
if (!request.ok) {
|
||||||
|
console.error('Error fetching URL:', url, request.status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const data = new DataTransfer();
|
||||||
|
data.items.add(new File([await request.blob()], 'image.png'));
|
||||||
|
fileInput.files = data.files;
|
||||||
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching URL:', url, error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,6 +64,14 @@ function keyupEditAttention(event) {
|
|||||||
selectionEnd++;
|
selectionEnd++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deselect surrounding whitespace
|
||||||
|
while (text[selectionStart] == " " && selectionStart < selectionEnd) {
|
||||||
|
selectionStart++;
|
||||||
|
}
|
||||||
|
while (text[selectionEnd - 1] == " " && selectionEnd > selectionStart) {
|
||||||
|
selectionEnd--;
|
||||||
|
}
|
||||||
|
|
||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -39,12 +39,12 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
// tabname_full = {tabname}_{extra_networks_tabname}
|
// tabname_full = {tabname}_{extra_networks_tabname}
|
||||||
var tabname_full = elem.id;
|
var tabname_full = elem.id;
|
||||||
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
|
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
|
||||||
var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort");
|
|
||||||
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
|
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
|
||||||
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
|
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
|
||||||
|
var currentSort = '';
|
||||||
|
|
||||||
// If any of the buttons above don't exist, we want to skip this iteration of the loop.
|
// If any of the buttons above don't exist, we want to skip this iteration of the loop.
|
||||||
if (!search || !sort_mode || !sort_dir || !refresh) {
|
if (!search || !sort_dir || !refresh) {
|
||||||
return; // `return` is equivalent of `continue` but for forEach loops.
|
return; // `return` is equivalent of `continue` but for forEach loops.
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
var searchTerm = search.value.toLowerCase();
|
var searchTerm = search.value.toLowerCase();
|
||||||
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
||||||
var searchOnly = elem.querySelector('.search_only');
|
var searchOnly = elem.querySelector('.search_only');
|
||||||
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) {
|
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {
|
||||||
return t.textContent.toLowerCase();
|
return t.textContent.toLowerCase();
|
||||||
}).join(" ");
|
}).join(" ");
|
||||||
|
|
||||||
@ -71,42 +71,46 @@ function setupExtraNetworksForTab(tabname) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
var applySort = function(force) {
|
var applySort = function(force) {
|
||||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');
|
||||||
|
var parent = gradioApp().querySelector('#' + tabname_full + "_cards");
|
||||||
var reverse = sort_dir.dataset.sortdir == "Descending";
|
var reverse = sort_dir.dataset.sortdir == "Descending";
|
||||||
var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
var activeSearchElem = gradioApp().querySelector('#' + tabname_full + "_controls .extra-network-control--sort.extra-network-control--enabled");
|
||||||
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
var sortKey = activeSearchElem ? activeSearchElem.dataset.sortkey : "default";
|
||||||
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
|
var sortKeyDataField = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
||||||
|
var sortKeyStore = sortKey + "-" + sort_dir.dataset.sortdir + "-" + cards.length;
|
||||||
|
|
||||||
if (sortKeyStore == sort_mode.dataset.sortkey && !force) {
|
if (sortKeyStore == currentSort && !force) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
sort_mode.dataset.sortkey = sortKeyStore;
|
currentSort = sortKeyStore;
|
||||||
|
|
||||||
cards.forEach(function(card) {
|
|
||||||
card.originalParentElement = card.parentElement;
|
|
||||||
});
|
|
||||||
var sortedCards = Array.from(cards);
|
var sortedCards = Array.from(cards);
|
||||||
sortedCards.sort(function(cardA, cardB) {
|
sortedCards.sort(function(cardA, cardB) {
|
||||||
var a = cardA.dataset[sortKey];
|
var a = cardA.dataset[sortKeyDataField];
|
||||||
var b = cardB.dataset[sortKey];
|
var b = cardB.dataset[sortKeyDataField];
|
||||||
if (!isNaN(a) && !isNaN(b)) {
|
if (!isNaN(a) && !isNaN(b)) {
|
||||||
return parseInt(a) - parseInt(b);
|
return parseInt(a) - parseInt(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
return (a < b ? -1 : (a > b ? 1 : 0));
|
return (a < b ? -1 : (a > b ? 1 : 0));
|
||||||
});
|
});
|
||||||
|
|
||||||
if (reverse) {
|
if (reverse) {
|
||||||
sortedCards.reverse();
|
sortedCards.reverse();
|
||||||
}
|
}
|
||||||
cards.forEach(function(card) {
|
|
||||||
card.remove();
|
parent.innerHTML = '';
|
||||||
});
|
|
||||||
|
var frag = document.createDocumentFragment();
|
||||||
sortedCards.forEach(function(card) {
|
sortedCards.forEach(function(card) {
|
||||||
card.originalParentElement.appendChild(card);
|
frag.appendChild(card);
|
||||||
});
|
});
|
||||||
|
parent.appendChild(frag);
|
||||||
};
|
};
|
||||||
|
|
||||||
search.addEventListener("input", applyFilter);
|
search.addEventListener("input", function() {
|
||||||
|
applyFilter();
|
||||||
|
});
|
||||||
applySort();
|
applySort();
|
||||||
applyFilter();
|
applyFilter();
|
||||||
extraNetworksApplySort[tabname_full] = applySort;
|
extraNetworksApplySort[tabname_full] = applySort;
|
||||||
@ -272,6 +276,15 @@ function saveCardPreview(event, tabname, filename) {
|
|||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function extraNetworksSearchButton(tabname, extra_networks_tabname, event) {
|
||||||
|
var searchTextarea = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search");
|
||||||
|
var button = event.target;
|
||||||
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||||
|
|
||||||
|
searchTextarea.value = text;
|
||||||
|
updateInput(searchTextarea);
|
||||||
|
}
|
||||||
|
|
||||||
function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {
|
function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {
|
||||||
/**
|
/**
|
||||||
* Processes `onclick` events when user clicks on files in tree.
|
* Processes `onclick` events when user clicks on files in tree.
|
||||||
@ -290,7 +303,7 @@ function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_netwo
|
|||||||
* Processes `onclick` events when user clicks on directories in tree.
|
* Processes `onclick` events when user clicks on directories in tree.
|
||||||
*
|
*
|
||||||
* Here is how the tree reacts to clicks for various states:
|
* Here is how the tree reacts to clicks for various states:
|
||||||
* unselected unopened directory: Diretory is selected and expanded.
|
* unselected unopened directory: Directory is selected and expanded.
|
||||||
* unselected opened directory: Directory is selected.
|
* unselected opened directory: Directory is selected.
|
||||||
* selected opened directory: Directory is collapsed and deselected.
|
* selected opened directory: Directory is collapsed and deselected.
|
||||||
* chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.
|
* chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.
|
||||||
@ -383,36 +396,17 @@ function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
|
function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
|
||||||
/**
|
/** Handles `onclick` events for Sort Mode buttons. */
|
||||||
* Handles `onclick` events for the Sort Mode button.
|
|
||||||
*
|
var self = event.currentTarget;
|
||||||
* Modifies the data attributes of the Sort Mode button to cycle between
|
var parent = event.currentTarget.parentElement;
|
||||||
* various sorting modes.
|
|
||||||
*
|
parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) {
|
||||||
* @param event The generated event.
|
x.classList.remove('extra-network-control--enabled');
|
||||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
});
|
||||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
|
||||||
*/
|
self.classList.add('extra-network-control--enabled');
|
||||||
var curr_mode = event.currentTarget.dataset.sortmode;
|
|
||||||
var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir");
|
|
||||||
var sort_dir = el_sort_dir.dataset.sortdir;
|
|
||||||
if (curr_mode == "path") {
|
|
||||||
event.currentTarget.dataset.sortmode = "name";
|
|
||||||
event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640";
|
|
||||||
event.currentTarget.setAttribute("title", "Sort by filename");
|
|
||||||
} else if (curr_mode == "name") {
|
|
||||||
event.currentTarget.dataset.sortmode = "date_created";
|
|
||||||
event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640";
|
|
||||||
event.currentTarget.setAttribute("title", "Sort by date created");
|
|
||||||
} else if (curr_mode == "date_created") {
|
|
||||||
event.currentTarget.dataset.sortmode = "date_modified";
|
|
||||||
event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640";
|
|
||||||
event.currentTarget.setAttribute("title", "Sort by date modified");
|
|
||||||
} else {
|
|
||||||
event.currentTarget.dataset.sortmode = "path";
|
|
||||||
event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640";
|
|
||||||
event.currentTarget.setAttribute("title", "Sort by path");
|
|
||||||
}
|
|
||||||
applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
|
applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -447,27 +441,12 @@ function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabn
|
|||||||
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
* @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
|
||||||
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
* @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
|
||||||
*/
|
*/
|
||||||
const tree = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree");
|
var button = event.currentTarget;
|
||||||
const parent = tree.parentElement;
|
button.classList.toggle("extra-network-control--enabled");
|
||||||
let resizeHandle = parent.querySelector('.resize-handle');
|
var show = !button.classList.contains("extra-network-control--enabled");
|
||||||
tree.classList.toggle("hidden");
|
|
||||||
|
|
||||||
if (tree.classList.contains("hidden")) {
|
var pane = gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_pane");
|
||||||
tree.style.display = 'none';
|
pane.classList.toggle("extra-network-dirs-hidden", show);
|
||||||
parent.style.display = 'flex';
|
|
||||||
if (resizeHandle) {
|
|
||||||
resizeHandle.style.display = 'none';
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tree.style.display = 'block';
|
|
||||||
parent.style.display = 'grid';
|
|
||||||
if (!resizeHandle) {
|
|
||||||
setupResizeHandle(parent);
|
|
||||||
resizeHandle = parent.querySelector('.resize-handle');
|
|
||||||
}
|
|
||||||
resizeHandle.style.display = 'block';
|
|
||||||
}
|
|
||||||
event.currentTarget.classList.toggle("extra-network-control--enabled");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
|
function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
|
||||||
@ -528,12 +507,76 @@ function popupId(id) {
|
|||||||
popup(storedPopupIds[id]);
|
popup(storedPopupIds[id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function extraNetworksFlattenMetadata(obj) {
|
||||||
|
const result = {};
|
||||||
|
|
||||||
|
// Convert any stringified JSON objects to actual objects
|
||||||
|
for (const key of Object.keys(obj)) {
|
||||||
|
if (typeof obj[key] === 'string') {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(obj[key]);
|
||||||
|
if (parsed && typeof parsed === 'object') {
|
||||||
|
obj[key] = parsed;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten the object
|
||||||
|
for (const key of Object.keys(obj)) {
|
||||||
|
if (typeof obj[key] === 'object' && obj[key] !== null) {
|
||||||
|
const nested = extraNetworksFlattenMetadata(obj[key]);
|
||||||
|
for (const nestedKey of Object.keys(nested)) {
|
||||||
|
result[`${key}/${nestedKey}`] = nested[nestedKey];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result[key] = obj[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special case for handling modelspec keys
|
||||||
|
for (const key of Object.keys(result)) {
|
||||||
|
if (key.startsWith("modelspec.")) {
|
||||||
|
result[key.replaceAll(".", "/")] = result[key];
|
||||||
|
delete result[key];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add empty keys to designate hierarchy
|
||||||
|
for (const key of Object.keys(result)) {
|
||||||
|
const parts = key.split("/");
|
||||||
|
for (let i = 1; i < parts.length; i++) {
|
||||||
|
const parent = parts.slice(0, i).join("/");
|
||||||
|
if (!result[parent]) {
|
||||||
|
result[parent] = "";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
function extraNetworksShowMetadata(text) {
|
function extraNetworksShowMetadata(text) {
|
||||||
|
try {
|
||||||
|
let parsed = JSON.parse(text);
|
||||||
|
if (parsed && typeof parsed === 'object') {
|
||||||
|
parsed = extraNetworksFlattenMetadata(parsed);
|
||||||
|
const table = createVisualizationTable(parsed, 0);
|
||||||
|
popup(table);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.eror(error);
|
||||||
|
}
|
||||||
|
|
||||||
var elem = document.createElement('pre');
|
var elem = document.createElement('pre');
|
||||||
elem.classList.add('popup-metadata');
|
elem.classList.add('popup-metadata');
|
||||||
elem.textContent = text;
|
elem.textContent = text;
|
||||||
|
|
||||||
popup(elem);
|
popup(elem);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
function requestGet(url, data, handler, errorHandler) {
|
function requestGet(url, data, handler, errorHandler) {
|
||||||
@ -562,16 +605,18 @@ function requestGet(url, data, handler, errorHandler) {
|
|||||||
xhr.send(js);
|
xhr.send(js);
|
||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksCopyCardPath(event, path) {
|
function extraNetworksCopyCardPath(event) {
|
||||||
navigator.clipboard.writeText(path);
|
navigator.clipboard.writeText(event.target.getAttribute("data-clipboard-text"));
|
||||||
event.stopPropagation();
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksRequestMetadata(event, extraPage, cardName) {
|
function extraNetworksRequestMetadata(event, extraPage) {
|
||||||
var showError = function() {
|
var showError = function() {
|
||||||
extraNetworksShowMetadata("there was an error getting metadata");
|
extraNetworksShowMetadata("there was an error getting metadata");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
var cardName = event.target.parentElement.parentElement.getAttribute("data-name");
|
||||||
|
|
||||||
requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) {
|
requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) {
|
||||||
if (data && data.metadata) {
|
if (data && data.metadata) {
|
||||||
extraNetworksShowMetadata(data.metadata);
|
extraNetworksShowMetadata(data.metadata);
|
||||||
@ -585,7 +630,7 @@ function extraNetworksRequestMetadata(event, extraPage, cardName) {
|
|||||||
|
|
||||||
var extraPageUserMetadataEditors = {};
|
var extraPageUserMetadataEditors = {};
|
||||||
|
|
||||||
function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
|
function extraNetworksEditUserMetadata(event, tabname, extraPage) {
|
||||||
var id = tabname + '_' + extraPage + '_edit_user_metadata';
|
var id = tabname + '_' + extraPage + '_edit_user_metadata';
|
||||||
|
|
||||||
var editor = extraPageUserMetadataEditors[id];
|
var editor = extraPageUserMetadataEditors[id];
|
||||||
@ -597,6 +642,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
|
|||||||
extraPageUserMetadataEditors[id] = editor;
|
extraPageUserMetadataEditors[id] = editor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var cardName = event.target.parentElement.parentElement.getAttribute("data-name");
|
||||||
editor.nameTextarea.value = cardName;
|
editor.nameTextarea.value = cardName;
|
||||||
updateInput(editor.nameTextarea);
|
updateInput(editor.nameTextarea);
|
||||||
|
|
||||||
|
@ -131,19 +131,15 @@ function setupImageForLightbox(e) {
|
|||||||
e.style.cursor = 'pointer';
|
e.style.cursor = 'pointer';
|
||||||
e.style.userSelect = 'none';
|
e.style.userSelect = 'none';
|
||||||
|
|
||||||
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1;
|
e.addEventListener('mousedown', function(evt) {
|
||||||
|
|
||||||
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
|
||||||
// If you know how to fix this without switching to mousedown event, please.
|
|
||||||
// For other browsers the event is click to make it possiblr to drag picture.
|
|
||||||
var event = isFirefox ? 'mousedown' : 'click';
|
|
||||||
|
|
||||||
e.addEventListener(event, function(evt) {
|
|
||||||
if (evt.button == 1) {
|
if (evt.button == 1) {
|
||||||
open(evt.target.src);
|
open(evt.target.src);
|
||||||
evt.preventDefault();
|
evt.preventDefault();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
}, true);
|
||||||
|
|
||||||
|
e.addEventListener('click', function(evt) {
|
||||||
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
|
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||||
|
@ -33,120 +33,141 @@ function createRow(table, cellName, items) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function showProfile(path, cutoff = 0.05) {
|
function createVisualizationTable(data, cutoff = 0, sort = "") {
|
||||||
requestGet(path, {}, function(data) {
|
var table = document.createElement('table');
|
||||||
var table = document.createElement('table');
|
table.className = 'popup-table';
|
||||||
table.className = 'popup-table';
|
|
||||||
|
|
||||||
data.records['total'] = data.total;
|
var keys = Object.keys(data);
|
||||||
var keys = Object.keys(data.records).sort(function(a, b) {
|
if (sort === "number") {
|
||||||
return data.records[b] - data.records[a];
|
keys = keys.sort(function(a, b) {
|
||||||
|
return data[b] - data[a];
|
||||||
});
|
});
|
||||||
var items = keys.map(function(x) {
|
} else {
|
||||||
return {key: x, parts: x.split('/'), time: data.records[x]};
|
keys = keys.sort();
|
||||||
|
}
|
||||||
|
var items = keys.map(function(x) {
|
||||||
|
return {key: x, parts: x.split('/'), value: data[x]};
|
||||||
|
});
|
||||||
|
var maxLength = items.reduce(function(a, b) {
|
||||||
|
return Math.max(a, b.parts.length);
|
||||||
|
}, 0);
|
||||||
|
|
||||||
|
var cols = createRow(
|
||||||
|
table,
|
||||||
|
'th',
|
||||||
|
[
|
||||||
|
cutoff === 0 ? 'key' : 'record',
|
||||||
|
cutoff === 0 ? 'value' : 'seconds'
|
||||||
|
]
|
||||||
|
);
|
||||||
|
cols[0].colSpan = maxLength;
|
||||||
|
|
||||||
|
function arraysEqual(a, b) {
|
||||||
|
return !(a < b || b < a);
|
||||||
|
}
|
||||||
|
|
||||||
|
var addLevel = function(level, parent, hide) {
|
||||||
|
var matching = items.filter(function(x) {
|
||||||
|
return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
|
||||||
});
|
});
|
||||||
var maxLength = items.reduce(function(a, b) {
|
if (sort === "number") {
|
||||||
return Math.max(a, b.parts.length);
|
matching = matching.sort(function(a, b) {
|
||||||
}, 0);
|
return b.value - a.value;
|
||||||
|
});
|
||||||
var cols = createRow(table, 'th', ['record', 'seconds']);
|
} else {
|
||||||
cols[0].colSpan = maxLength;
|
matching = matching.sort();
|
||||||
|
|
||||||
function arraysEqual(a, b) {
|
|
||||||
return !(a < b || b < a);
|
|
||||||
}
|
}
|
||||||
|
var othersTime = 0;
|
||||||
|
var othersList = [];
|
||||||
|
var othersRows = [];
|
||||||
|
var childrenRows = [];
|
||||||
|
matching.forEach(function(x) {
|
||||||
|
var visible = (cutoff === 0 && !hide) || (x.value >= cutoff && !hide);
|
||||||
|
|
||||||
var addLevel = function(level, parent, hide) {
|
var cells = [];
|
||||||
var matching = items.filter(function(x) {
|
for (var i = 0; i < maxLength; i++) {
|
||||||
return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
|
cells.push(x.parts[i]);
|
||||||
});
|
}
|
||||||
var sorted = matching.sort(function(a, b) {
|
cells.push(cutoff === 0 ? x.value : x.value.toFixed(3));
|
||||||
return b.time - a.time;
|
var cols = createRow(table, 'td', cells);
|
||||||
});
|
for (i = 0; i < level; i++) {
|
||||||
var othersTime = 0;
|
cols[i].className = 'muted';
|
||||||
var othersList = [];
|
}
|
||||||
var othersRows = [];
|
|
||||||
var childrenRows = [];
|
|
||||||
sorted.forEach(function(x) {
|
|
||||||
var visible = x.time >= cutoff && !hide;
|
|
||||||
|
|
||||||
var cells = [];
|
var tr = cols[0].parentNode;
|
||||||
for (var i = 0; i < maxLength; i++) {
|
if (!visible) {
|
||||||
cells.push(x.parts[i]);
|
tr.classList.add("hidden");
|
||||||
}
|
}
|
||||||
cells.push(x.time.toFixed(3));
|
|
||||||
var cols = createRow(table, 'td', cells);
|
|
||||||
for (i = 0; i < level; i++) {
|
|
||||||
cols[i].className = 'muted';
|
|
||||||
}
|
|
||||||
|
|
||||||
var tr = cols[0].parentNode;
|
if (cutoff === 0 || x.value >= cutoff) {
|
||||||
if (!visible) {
|
childrenRows.push(tr);
|
||||||
tr.classList.add("hidden");
|
} else {
|
||||||
}
|
othersTime += x.value;
|
||||||
|
othersList.push(x.parts[level]);
|
||||||
if (x.time >= cutoff) {
|
othersRows.push(tr);
|
||||||
childrenRows.push(tr);
|
}
|
||||||
} else {
|
|
||||||
othersTime += x.time;
|
|
||||||
othersList.push(x.parts[level]);
|
|
||||||
othersRows.push(tr);
|
|
||||||
}
|
|
||||||
|
|
||||||
var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);
|
|
||||||
if (children.length > 0) {
|
|
||||||
var cell = cols[level];
|
|
||||||
var onclick = function() {
|
|
||||||
cell.classList.remove("link");
|
|
||||||
cell.removeEventListener("click", onclick);
|
|
||||||
children.forEach(function(x) {
|
|
||||||
x.classList.remove("hidden");
|
|
||||||
});
|
|
||||||
};
|
|
||||||
cell.classList.add("link");
|
|
||||||
cell.addEventListener("click", onclick);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
if (othersTime > 0) {
|
|
||||||
var cells = [];
|
|
||||||
for (var i = 0; i < maxLength; i++) {
|
|
||||||
cells.push(parent[i]);
|
|
||||||
}
|
|
||||||
cells.push(othersTime.toFixed(3));
|
|
||||||
cells[level] = 'others';
|
|
||||||
var cols = createRow(table, 'td', cells);
|
|
||||||
for (i = 0; i < level; i++) {
|
|
||||||
cols[i].className = 'muted';
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);
|
||||||
|
if (children.length > 0) {
|
||||||
var cell = cols[level];
|
var cell = cols[level];
|
||||||
var tr = cell.parentNode;
|
|
||||||
var onclick = function() {
|
var onclick = function() {
|
||||||
tr.classList.add("hidden");
|
|
||||||
cell.classList.remove("link");
|
cell.classList.remove("link");
|
||||||
cell.removeEventListener("click", onclick);
|
cell.removeEventListener("click", onclick);
|
||||||
othersRows.forEach(function(x) {
|
children.forEach(function(x) {
|
||||||
x.classList.remove("hidden");
|
x.classList.remove("hidden");
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
cell.title = othersList.join(", ");
|
|
||||||
cell.classList.add("link");
|
cell.classList.add("link");
|
||||||
cell.addEventListener("click", onclick);
|
cell.addEventListener("click", onclick);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
if (hide) {
|
if (othersTime > 0) {
|
||||||
tr.classList.add("hidden");
|
var cells = [];
|
||||||
}
|
for (var i = 0; i < maxLength; i++) {
|
||||||
|
cells.push(parent[i]);
|
||||||
childrenRows.push(tr);
|
}
|
||||||
|
cells.push(othersTime.toFixed(3));
|
||||||
|
cells[level] = 'others';
|
||||||
|
var cols = createRow(table, 'td', cells);
|
||||||
|
for (i = 0; i < level; i++) {
|
||||||
|
cols[i].className = 'muted';
|
||||||
}
|
}
|
||||||
|
|
||||||
return childrenRows;
|
var cell = cols[level];
|
||||||
};
|
var tr = cell.parentNode;
|
||||||
|
var onclick = function() {
|
||||||
|
tr.classList.add("hidden");
|
||||||
|
cell.classList.remove("link");
|
||||||
|
cell.removeEventListener("click", onclick);
|
||||||
|
othersRows.forEach(function(x) {
|
||||||
|
x.classList.remove("hidden");
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
addLevel(0, []);
|
cell.title = othersList.join(", ");
|
||||||
|
cell.classList.add("link");
|
||||||
|
cell.addEventListener("click", onclick);
|
||||||
|
|
||||||
|
if (hide) {
|
||||||
|
tr.classList.add("hidden");
|
||||||
|
}
|
||||||
|
|
||||||
|
childrenRows.push(tr);
|
||||||
|
}
|
||||||
|
|
||||||
|
return childrenRows;
|
||||||
|
};
|
||||||
|
|
||||||
|
addLevel(0, []);
|
||||||
|
|
||||||
|
return table;
|
||||||
|
}
|
||||||
|
|
||||||
|
function showProfile(path, cutoff = 0.05) {
|
||||||
|
requestGet(path, {}, function(data) {
|
||||||
|
data.records['total'] = data.total;
|
||||||
|
const table = createVisualizationTable(data.records, cutoff, "number");
|
||||||
popup(table);
|
popup(table);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -79,6 +79,11 @@
|
|||||||
parent.minRightColWidth = 0;
|
parent.minRightColWidth = 0;
|
||||||
parent.needHideOnMoblie = false;
|
parent.needHideOnMoblie = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!leftColTemplate) {
|
||||||
|
leftColTemplate = '1fr';
|
||||||
|
}
|
||||||
|
|
||||||
const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`;
|
const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`;
|
||||||
parent.style.gridTemplateColumns = gridTemplateColumns;
|
parent.style.gridTemplateColumns = gridTemplateColumns;
|
||||||
parent.style.originalGridTemplateColumns = gridTemplateColumns;
|
parent.style.originalGridTemplateColumns = gridTemplateColumns;
|
||||||
|
@ -125,8 +125,7 @@ function showSubmitInterruptingPlaceholder(tabname) {
|
|||||||
function showRestoreProgressButton(tabname, show) {
|
function showRestoreProgressButton(tabname, show) {
|
||||||
var button = gradioApp().getElementById(tabname + "_restore_progress");
|
var button = gradioApp().getElementById(tabname + "_restore_progress");
|
||||||
if (!button) return;
|
if (!button) return;
|
||||||
|
button.style.setProperty('display', show ? 'flex' : 'none', 'important');
|
||||||
button.style.display = show ? "flex" : "none";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function submit() {
|
function submit() {
|
||||||
@ -196,6 +195,7 @@ function restoreProgressTxt2img() {
|
|||||||
var id = localGet("txt2img_task_id");
|
var id = localGet("txt2img_task_id");
|
||||||
|
|
||||||
if (id) {
|
if (id) {
|
||||||
|
showSubmitInterruptingPlaceholder('txt2img');
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||||
showSubmitButtons('txt2img', true);
|
showSubmitButtons('txt2img', true);
|
||||||
}, null, 0);
|
}, null, 0);
|
||||||
@ -210,6 +210,7 @@ function restoreProgressImg2img() {
|
|||||||
var id = localGet("img2img_task_id");
|
var id = localGet("img2img_task_id");
|
||||||
|
|
||||||
if (id) {
|
if (id) {
|
||||||
|
showSubmitInterruptingPlaceholder('img2img');
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||||
showSubmitButtons('img2img', true);
|
showSubmitButtons('img2img', true);
|
||||||
}, null, 0);
|
}, null, 0);
|
||||||
@ -398,7 +399,7 @@ function switchWidthHeight(tabname) {
|
|||||||
|
|
||||||
var onEditTimers = {};
|
var onEditTimers = {};
|
||||||
|
|
||||||
// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
|
// calls func after afterMs milliseconds has passed since the input elem has been edited by user
|
||||||
function onEdit(editId, elem, afterMs, func) {
|
function onEdit(editId, elem, afterMs, func) {
|
||||||
var edited = function() {
|
var edited = function() {
|
||||||
var existingTimer = onEditTimers[editId];
|
var existingTimer = onEditTimers[editId];
|
||||||
|
@ -23,7 +23,7 @@ from modules.shared import opts
|
|||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||||
from PIL import PngImagePlugin, Image
|
from PIL import PngImagePlugin
|
||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
@ -85,7 +85,7 @@ def decode_base64_to_image(encoding):
|
|||||||
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
|
||||||
response = requests.get(encoding, timeout=30, headers=headers)
|
response = requests.get(encoding, timeout=30, headers=headers)
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(response.content))
|
image = images.read(BytesIO(response.content))
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
raise HTTPException(status_code=500, detail="Invalid image url") from e
|
||||||
@ -93,7 +93,7 @@ def decode_base64_to_image(encoding):
|
|||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
image = images.read(BytesIO(base64.b64decode(encoding)))
|
||||||
return image
|
return image
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||||
@ -360,7 +360,7 @@ class Api:
|
|||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
|
||||||
"""Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
|
"""Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.
|
||||||
|
|
||||||
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
|
||||||
|
|
||||||
@ -409,8 +409,8 @@ class Api:
|
|||||||
if request.override_settings is None:
|
if request.override_settings is None:
|
||||||
request.override_settings = {}
|
request.override_settings = {}
|
||||||
|
|
||||||
overriden_settings = infotext_utils.get_override_settings(params)
|
overridden_settings = infotext_utils.get_override_settings(params)
|
||||||
for _, setting_name, value in overriden_settings:
|
for _, setting_name, value in overridden_settings:
|
||||||
if setting_name not in request.override_settings:
|
if setting_name not in request.override_settings:
|
||||||
request.override_settings[setting_name] = value
|
request.override_settings[setting_name] = value
|
||||||
|
|
||||||
|
@ -2,48 +2,55 @@ import json
|
|||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
|
import diskcache
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from modules.paths import data_path, script_path
|
from modules.paths import data_path, script_path
|
||||||
|
|
||||||
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json"))
|
||||||
cache_data = None
|
cache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, "cache"))
|
||||||
|
caches = {}
|
||||||
cache_lock = threading.Lock()
|
cache_lock = threading.Lock()
|
||||||
|
|
||||||
dump_cache_after = None
|
|
||||||
dump_cache_thread = None
|
|
||||||
|
|
||||||
|
|
||||||
def dump_cache():
|
def dump_cache():
|
||||||
"""
|
"""old function for dumping cache to disk; does nothing since diskcache."""
|
||||||
Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
|
|
||||||
"""
|
|
||||||
|
|
||||||
global dump_cache_after
|
pass
|
||||||
global dump_cache_thread
|
|
||||||
|
|
||||||
def thread_func():
|
|
||||||
global dump_cache_after
|
|
||||||
global dump_cache_thread
|
|
||||||
|
|
||||||
while dump_cache_after is not None and time.time() < dump_cache_after:
|
def make_cache(subsection: str) -> diskcache.Cache:
|
||||||
time.sleep(1)
|
return diskcache.Cache(
|
||||||
|
os.path.join(cache_dir, subsection),
|
||||||
|
size_limit=2**32, # 4 GB, culling oldest first
|
||||||
|
disk_min_file_size=2**18, # keep up to 256KB in Sqlite
|
||||||
|
)
|
||||||
|
|
||||||
with cache_lock:
|
|
||||||
cache_filename_tmp = cache_filename + "-"
|
|
||||||
with open(cache_filename_tmp, "w", encoding="utf8") as file:
|
|
||||||
json.dump(cache_data, file, indent=4, ensure_ascii=False)
|
|
||||||
|
|
||||||
os.replace(cache_filename_tmp, cache_filename)
|
def convert_old_cached_data():
|
||||||
|
try:
|
||||||
|
with open(cache_filename, "r", encoding="utf8") as file:
|
||||||
|
data = json.load(file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
||||||
|
print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')
|
||||||
|
return
|
||||||
|
|
||||||
dump_cache_after = None
|
total_count = sum(len(keyvalues) for keyvalues in data.values())
|
||||||
dump_cache_thread = None
|
|
||||||
|
|
||||||
with cache_lock:
|
with tqdm.tqdm(total=total_count, desc="converting cache") as progress:
|
||||||
dump_cache_after = time.time() + 5
|
for subsection, keyvalues in data.items():
|
||||||
if dump_cache_thread is None:
|
cache_obj = caches.get(subsection)
|
||||||
dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
|
if cache_obj is None:
|
||||||
dump_cache_thread.start()
|
cache_obj = make_cache(subsection)
|
||||||
|
caches[subsection] = cache_obj
|
||||||
|
|
||||||
|
for key, value in keyvalues.items():
|
||||||
|
cache_obj[key] = value
|
||||||
|
progress.update(1)
|
||||||
|
|
||||||
|
|
||||||
def cache(subsection):
|
def cache(subsection):
|
||||||
@ -54,28 +61,21 @@ def cache(subsection):
|
|||||||
subsection (str): The subsection identifier for the cache.
|
subsection (str): The subsection identifier for the cache.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The cache data for the specified subsection.
|
diskcache.Cache: The cache data for the specified subsection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
global cache_data
|
cache_obj = caches.get(subsection)
|
||||||
|
if not cache_obj:
|
||||||
if cache_data is None:
|
|
||||||
with cache_lock:
|
with cache_lock:
|
||||||
if cache_data is None:
|
if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):
|
||||||
try:
|
convert_old_cached_data()
|
||||||
with open(cache_filename, "r", encoding="utf8") as file:
|
|
||||||
cache_data = json.load(file)
|
|
||||||
except FileNotFoundError:
|
|
||||||
cache_data = {}
|
|
||||||
except Exception:
|
|
||||||
os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
|
|
||||||
print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
|
|
||||||
cache_data = {}
|
|
||||||
|
|
||||||
s = cache_data.get(subsection, {})
|
cache_obj = caches.get(subsection)
|
||||||
cache_data[subsection] = s
|
if not cache_obj:
|
||||||
|
cache_obj = make_cache(subsection)
|
||||||
|
caches[subsection] = cache_obj
|
||||||
|
|
||||||
return s
|
return cache_obj
|
||||||
|
|
||||||
|
|
||||||
def cached_data_for_file(subsection, title, filename, func):
|
def cached_data_for_file(subsection, title, filename, func):
|
||||||
|
@ -100,8 +100,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
sys_pct = sys_peak/max(sys_total, 1) * 100
|
sys_pct = sys_peak/max(sys_total, 1) * 100
|
||||||
|
|
||||||
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
toltip_a = "Active: peak amount of video memory used during generation (excluding cached data)"
|
||||||
toltip_r = "Reserved: total amout of video memory allocated by the Torch library "
|
toltip_r = "Reserved: total amount of video memory allocated by the Torch library "
|
||||||
toltip_sys = "System: peak amout of video memory allocated by all running programs, out of total capacity"
|
toltip_sys = "System: peak amount of video memory allocated by all running programs, out of total capacity"
|
||||||
|
|
||||||
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
text_a = f"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>"
|
||||||
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
text_r = f"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>"
|
||||||
|
@ -124,3 +124,4 @@ parser.add_argument("--disable-extra-extensions", action='store_true', help="pre
|
|||||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui")
|
||||||
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
parser.add_argument("--unix-filenames-sanitization", action='store_true', help="allow any symbols except '/' in filenames. May conflict with your browser and file system")
|
||||||
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
parser.add_argument("--filenames-max-length", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')
|
||||||
|
parser.add_argument("--no-prompt-history", action='store_true', help="disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file")
|
||||||
|
@ -259,7 +259,7 @@ def test_for_nans(x, where):
|
|||||||
def first_time_calculation():
|
def first_time_calculation():
|
||||||
"""
|
"""
|
||||||
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
|
||||||
spends about 2.7 seconds doing that, at least wih NVidia.
|
spends about 2.7 seconds doing that, at least with NVidia.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
x = torch.zeros((1, 1)).to(device, dtype)
|
x = torch.zeros((1, 1)).to(device, dtype)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import configparser
|
import configparser
|
||||||
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import re
|
import re
|
||||||
@ -9,6 +10,10 @@ from modules import shared, errors, cache, scripts
|
|||||||
from modules.gitpython_hack import Repo
|
from modules.gitpython_hack import Repo
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
|
extensions: list[Extension] = []
|
||||||
|
extension_paths: dict[str, Extension] = {}
|
||||||
|
loaded_extensions: dict[str, Exception] = {}
|
||||||
|
|
||||||
|
|
||||||
os.makedirs(extensions_dir, exist_ok=True)
|
os.makedirs(extensions_dir, exist_ok=True)
|
||||||
|
|
||||||
@ -22,6 +27,13 @@ def active():
|
|||||||
return [x for x in extensions if x.enabled]
|
return [x for x in extensions if x.enabled]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CallbackOrderInfo:
|
||||||
|
name: str
|
||||||
|
before: list
|
||||||
|
after: list
|
||||||
|
|
||||||
|
|
||||||
class ExtensionMetadata:
|
class ExtensionMetadata:
|
||||||
filename = "metadata.ini"
|
filename = "metadata.ini"
|
||||||
config: configparser.ConfigParser
|
config: configparser.ConfigParser
|
||||||
@ -42,7 +54,7 @@ class ExtensionMetadata:
|
|||||||
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
|
||||||
self.canonical_name = canonical_name.lower().strip()
|
self.canonical_name = canonical_name.lower().strip()
|
||||||
|
|
||||||
self.requires = self.get_script_requirements("Requires", "Extension")
|
self.requires = None
|
||||||
|
|
||||||
def get_script_requirements(self, field, section, extra_section=None):
|
def get_script_requirements(self, field, section, extra_section=None):
|
||||||
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
"""reads a list of requirements from the config; field is the name of the field in the ini file,
|
||||||
@ -54,7 +66,15 @@ class ExtensionMetadata:
|
|||||||
if extra_section:
|
if extra_section:
|
||||||
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
x = x + ', ' + self.config.get(extra_section, field, fallback='')
|
||||||
|
|
||||||
return self.parse_list(x.lower())
|
listed_requirements = self.parse_list(x.lower())
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for requirement in listed_requirements:
|
||||||
|
loaded_requirements = (x for x in requirement.split("|") if x in loaded_extensions)
|
||||||
|
relevant_requirement = next(loaded_requirements, requirement)
|
||||||
|
res.append(relevant_requirement)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def parse_list(self, text):
|
def parse_list(self, text):
|
||||||
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
"""converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
|
||||||
@ -65,6 +85,22 @@ class ExtensionMetadata:
|
|||||||
# both "," and " " are accepted as separator
|
# both "," and " " are accepted as separator
|
||||||
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
|
||||||
|
|
||||||
|
def list_callback_order_instructions(self):
|
||||||
|
for section in self.config.sections():
|
||||||
|
if not section.startswith("callbacks/"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
callback_name = section[10:]
|
||||||
|
|
||||||
|
if not callback_name.startswith(self.canonical_name):
|
||||||
|
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
|
||||||
|
after = self.parse_list(self.config.get(section, 'After', fallback=''))
|
||||||
|
|
||||||
|
yield CallbackOrderInfo(callback_name, before, after)
|
||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
@ -156,6 +192,8 @@ class Extension:
|
|||||||
def check_updates(self):
|
def check_updates(self):
|
||||||
repo = Repo(self.path)
|
repo = Repo(self.path)
|
||||||
for fetch in repo.remote().fetch(dry_run=True):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
|
if self.branch and fetch.name != f'{repo.remote().name}/{self.branch}':
|
||||||
|
continue
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
self.status = "new commits"
|
self.status = "new commits"
|
||||||
@ -186,6 +224,8 @@ class Extension:
|
|||||||
|
|
||||||
def list_extensions():
|
def list_extensions():
|
||||||
extensions.clear()
|
extensions.clear()
|
||||||
|
extension_paths.clear()
|
||||||
|
loaded_extensions.clear()
|
||||||
|
|
||||||
if shared.cmd_opts.disable_all_extensions:
|
if shared.cmd_opts.disable_all_extensions:
|
||||||
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
|
||||||
@ -196,7 +236,6 @@ def list_extensions():
|
|||||||
elif shared.opts.disable_all_extensions == "extra":
|
elif shared.opts.disable_all_extensions == "extra":
|
||||||
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
|
||||||
|
|
||||||
loaded_extensions = {}
|
|
||||||
|
|
||||||
# scan through extensions directory and load metadata
|
# scan through extensions directory and load metadata
|
||||||
for dirname in [extensions_builtin_dir, extensions_dir]:
|
for dirname in [extensions_builtin_dir, extensions_dir]:
|
||||||
@ -220,8 +259,12 @@ def list_extensions():
|
|||||||
is_builtin = dirname == extensions_builtin_dir
|
is_builtin = dirname == extensions_builtin_dir
|
||||||
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
|
||||||
extensions.append(extension)
|
extensions.append(extension)
|
||||||
|
extension_paths[extension.path] = extension
|
||||||
loaded_extensions[canonical_name] = extension
|
loaded_extensions[canonical_name] = extension
|
||||||
|
|
||||||
|
for extension in extensions:
|
||||||
|
extension.metadata.requires = extension.metadata.get_script_requirements("Requires", "Extension")
|
||||||
|
|
||||||
# check for requirements
|
# check for requirements
|
||||||
for extension in extensions:
|
for extension in extensions:
|
||||||
if not extension.enabled:
|
if not extension.enabled:
|
||||||
@ -238,4 +281,16 @@ def list_extensions():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
extensions: list[Extension] = []
|
def find_extension(filename):
|
||||||
|
parentdir = os.path.dirname(os.path.realpath(filename))
|
||||||
|
|
||||||
|
while parentdir != filename:
|
||||||
|
extension = extension_paths.get(parentdir)
|
||||||
|
if extension is not None:
|
||||||
|
return extension
|
||||||
|
|
||||||
|
filename = parentdir
|
||||||
|
parentdir = os.path.dirname(filename)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ class ExtraNetwork:
|
|||||||
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
|
||||||
separated by colon.
|
separated by colon.
|
||||||
|
|
||||||
Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list -
|
Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
|
||||||
in this case, all effects of this extra networks should be disabled.
|
in this case, all effects of this extra networks should be disabled.
|
||||||
|
|
||||||
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
Can be called multiple times before deactivate() - each new call should override the previous call completely.
|
||||||
|
@ -95,6 +95,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
zeros_(b)
|
zeros_(b)
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||||
|
devices.torch_npu_set_device()
|
||||||
self.to(devices.device)
|
self.to(devices.device)
|
||||||
|
|
||||||
def fix_old_state_dict(self, state_dict):
|
def fix_old_state_dict(self, state_dict):
|
||||||
|
@ -12,7 +12,7 @@ import re
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -773,7 +773,7 @@ def image_data(data):
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(io.BytesIO(data))
|
image = read(io.BytesIO(data))
|
||||||
textinfo, _ = read_info_from_image(image)
|
textinfo, _ = read_info_from_image(image)
|
||||||
return textinfo, None
|
return textinfo, None
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -800,3 +800,30 @@ def flatten(img, bgcolor):
|
|||||||
|
|
||||||
return img.convert('RGB')
|
return img.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
|
def read(fp, **kwargs):
|
||||||
|
image = Image.open(fp, **kwargs)
|
||||||
|
image = fix_image(image)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def fix_image(image: Image.Image):
|
||||||
|
if image is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = ImageOps.exif_transpose(image)
|
||||||
|
image = fix_png_transparency(image)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def fix_png_transparency(image: Image.Image):
|
||||||
|
if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
|
||||||
|
return image
|
||||||
|
|
||||||
|
image = image.convert("RGBA")
|
||||||
|
return image
|
||||||
|
@ -5,7 +5,7 @@ from pathlib import Path
|
|||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import images as imgutil
|
from modules import images
|
||||||
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
@ -20,7 +20,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
output_dir = output_dir.strip()
|
output_dir = output_dir.strip()
|
||||||
processing.fix_seed(p)
|
processing.fix_seed(p)
|
||||||
|
|
||||||
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
batch_images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
|
||||||
|
|
||||||
is_inpaint_batch = False
|
is_inpaint_batch = False
|
||||||
if inpaint_mask_dir:
|
if inpaint_mask_dir:
|
||||||
@ -30,9 +30,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
if is_inpaint_batch:
|
if is_inpaint_batch:
|
||||||
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
|
||||||
|
|
||||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||||
|
|
||||||
state.job_count = len(images) * p.n_iter
|
state.job_count = len(batch_images) * p.n_iter
|
||||||
|
|
||||||
# extract "default" params to use in case getting png info fails
|
# extract "default" params to use in case getting png info fails
|
||||||
prompt = p.prompt
|
prompt = p.prompt
|
||||||
@ -45,8 +45,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||||
batch_results = None
|
batch_results = None
|
||||||
discard_further_results = False
|
discard_further_results = False
|
||||||
for i, image in enumerate(images):
|
for i, image in enumerate(batch_images):
|
||||||
state.job = f"{i+1} out of {len(images)}"
|
state.job = f"{i+1} out of {len(batch_images)}"
|
||||||
if state.skipped:
|
if state.skipped:
|
||||||
state.skipped = False
|
state.skipped = False
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = Image.open(image)
|
img = images.read(image)
|
||||||
except UnidentifiedImageError as e:
|
except UnidentifiedImageError as e:
|
||||||
print(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
@ -85,7 +85,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
# otherwise user has many masks with the same name but different extensions
|
# otherwise user has many masks with the same name but different extensions
|
||||||
mask_image_path = masks_found[0]
|
mask_image_path = masks_found[0]
|
||||||
|
|
||||||
mask_image = Image.open(mask_image_path)
|
mask_image = images.read(mask_image_path)
|
||||||
p.image_mask = mask_image
|
p.image_mask = mask_image
|
||||||
|
|
||||||
if use_png_info:
|
if use_png_info:
|
||||||
@ -93,8 +93,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
info_img = img
|
info_img = img
|
||||||
if png_info_dir:
|
if png_info_dir:
|
||||||
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
|
||||||
info_img = Image.open(info_img_path)
|
info_img = images.read(info_img_path)
|
||||||
geninfo, _ = imgutil.read_info_from_image(info_img)
|
geninfo, _ = images.read_info_from_image(info_img)
|
||||||
parsed_parameters = parse_generation_parameters(geninfo)
|
parsed_parameters = parse_generation_parameters(geninfo)
|
||||||
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -145,7 +145,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
return batch_results
|
return batch_results
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -173,9 +173,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
image = None
|
image = None
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
# Use the EXIF orientation of photos taken by smartphones.
|
image = images.fix_image(image)
|
||||||
if image is not None:
|
mask = images.fix_image(mask)
|
||||||
image = ImageOps.exif_transpose(image)
|
|
||||||
|
|
||||||
if selected_scale_tab == 1 and not is_batch:
|
if selected_scale_tab == 1 and not is_batch:
|
||||||
assert image, "Can't scale by because no image is selected"
|
assert image, "Can't scale by because no image is selected"
|
||||||
@ -192,10 +191,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
sampler_name=sampler_name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
@ -8,7 +8,7 @@ import sys
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
|
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
|
||||||
@ -95,15 +95,15 @@ def image_from_url_text(filedata):
|
|||||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||||
|
|
||||||
filename = filename.rsplit('?', 1)[0]
|
filename = filename.rsplit('?', 1)[0]
|
||||||
return Image.open(filename)
|
return images.read(filename)
|
||||||
|
|
||||||
if isinstance(filedata, str):
|
if isinstance(filedata, str):
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
|
||||||
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
filedata = base64.decodebytes(filedata.encode('utf-8'))
|
||||||
image = Image.open(io.BytesIO(filedata))
|
image = images.read(io.BytesIO(filedata))
|
||||||
return image
|
return image
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -274,17 +274,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
else:
|
else:
|
||||||
prompt += ("" if prompt == "" else "\n") + line
|
prompt += ("" if prompt == "" else "\n") + line
|
||||||
|
|
||||||
if shared.opts.infotext_styles != "Ignore":
|
|
||||||
found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
|
|
||||||
|
|
||||||
if shared.opts.infotext_styles == "Apply":
|
|
||||||
res["Styles array"] = found_styles
|
|
||||||
elif shared.opts.infotext_styles == "Apply if any" and found_styles:
|
|
||||||
res["Styles array"] = found_styles
|
|
||||||
|
|
||||||
res["Prompt"] = prompt
|
|
||||||
res["Negative prompt"] = negative_prompt
|
|
||||||
|
|
||||||
for k, v in re_param.findall(lastline):
|
for k, v in re_param.findall(lastline):
|
||||||
try:
|
try:
|
||||||
if v[0] == '"' and v[-1] == '"':
|
if v[0] == '"' and v[-1] == '"':
|
||||||
@ -299,6 +288,26 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error parsing \"{k}: {v}\"")
|
print(f"Error parsing \"{k}: {v}\"")
|
||||||
|
|
||||||
|
# Extract styles from prompt
|
||||||
|
if shared.opts.infotext_styles != "Ignore":
|
||||||
|
found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
|
||||||
|
|
||||||
|
same_hr_styles = True
|
||||||
|
if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
|
||||||
|
hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
|
||||||
|
hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
|
||||||
|
if same_hr_styles := found_styles == hr_found_styles:
|
||||||
|
res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
|
||||||
|
res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles
|
||||||
|
|
||||||
|
if same_hr_styles:
|
||||||
|
prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
|
||||||
|
if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
|
||||||
|
res['Styles array'] = found_styles
|
||||||
|
|
||||||
|
res["Prompt"] = prompt
|
||||||
|
res["Negative prompt"] = negative_prompt
|
||||||
|
|
||||||
# Missing CLIP skip means it was set to 1 (the default)
|
# Missing CLIP skip means it was set to 1 (the default)
|
||||||
if "Clip skip" not in res:
|
if "Clip skip" not in res:
|
||||||
res["Clip skip"] = "1"
|
res["Clip skip"] = "1"
|
||||||
@ -314,6 +323,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Hires sampler" not in res:
|
if "Hires sampler" not in res:
|
||||||
res["Hires sampler"] = "Use same sampler"
|
res["Hires sampler"] = "Use same sampler"
|
||||||
|
|
||||||
|
if "Hires schedule type" not in res:
|
||||||
|
res["Hires schedule type"] = "Use same scheduler"
|
||||||
|
|
||||||
if "Hires checkpoint" not in res:
|
if "Hires checkpoint" not in res:
|
||||||
res["Hires checkpoint"] = "Use same checkpoint"
|
res["Hires checkpoint"] = "Use same checkpoint"
|
||||||
|
|
||||||
@ -365,7 +377,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
|
||||||
res["Cache FP16 weight for LoRA"] = False
|
res["Cache FP16 weight for LoRA"] = False
|
||||||
|
|
||||||
if "Emphasis" not in res:
|
prompt_attention = prompt_parser.parse_prompt_attention(prompt)
|
||||||
|
prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
|
||||||
|
prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
|
||||||
|
if "Emphasis" not in res and prompt_uses_emphasis:
|
||||||
res["Emphasis"] = "Original"
|
res["Emphasis"] = "Original"
|
||||||
|
|
||||||
if "Refiner switch by sampling steps" not in res:
|
if "Refiner switch by sampling steps" not in res:
|
||||||
@ -468,7 +483,7 @@ def get_override_settings(params, *, skip_fields=None):
|
|||||||
|
|
||||||
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
|
||||||
def paste_func(prompt):
|
def paste_func(prompt):
|
||||||
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
|
if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
|
||||||
filename = os.path.join(data_path, "params.txt")
|
filename = os.path.join(data_path, "params.txt")
|
||||||
try:
|
try:
|
||||||
with open(filename, "r", encoding="utf8") as file:
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
|
@ -6,6 +6,7 @@ import re
|
|||||||
v160 = version.parse("1.6.0")
|
v160 = version.parse("1.6.0")
|
||||||
v170_tsnr = version.parse("v1.7.0-225")
|
v170_tsnr = version.parse("v1.7.0-225")
|
||||||
v180 = version.parse("1.8.0")
|
v180 = version.parse("1.8.0")
|
||||||
|
v180_hr_styles = version.parse("1.8.0-139")
|
||||||
|
|
||||||
|
|
||||||
def parse_version(text):
|
def parse_version(text):
|
||||||
|
@ -51,6 +51,7 @@ def check_versions():
|
|||||||
def initialize():
|
def initialize():
|
||||||
from modules import initialize_util
|
from modules import initialize_util
|
||||||
initialize_util.fix_torch_version()
|
initialize_util.fix_torch_version()
|
||||||
|
initialize_util.fix_pytorch_lightning()
|
||||||
initialize_util.fix_asyncio_event_loop_policy()
|
initialize_util.fix_asyncio_event_loop_policy()
|
||||||
initialize_util.validate_tls_options()
|
initialize_util.validate_tls_options()
|
||||||
initialize_util.configure_sigint_handler()
|
initialize_util.configure_sigint_handler()
|
||||||
@ -109,7 +110,7 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
with startup_timer.subcategory("load scripts"):
|
with startup_timer.subcategory("load scripts"):
|
||||||
scripts.load_scripts()
|
scripts.load_scripts()
|
||||||
|
|
||||||
if reload_script_modules:
|
if reload_script_modules and shared.opts.enable_reloading_ui_scripts:
|
||||||
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
|
||||||
importlib.reload(module)
|
importlib.reload(module)
|
||||||
startup_timer.record("reload script modules")
|
startup_timer.record("reload script modules")
|
||||||
@ -139,7 +140,7 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
"""
|
"""
|
||||||
Accesses shared.sd_model property to load model.
|
Accesses shared.sd_model property to load model.
|
||||||
After it's available, if it has been loaded before this access by some extension,
|
After it's available, if it has been loaded before this access by some extension,
|
||||||
its optimization may be None because the list of optimizaers has neet been filled
|
its optimization may be None because the list of optimizers has not been filled
|
||||||
by that time, so we apply optimization again.
|
by that time, so we apply optimization again.
|
||||||
"""
|
"""
|
||||||
from modules import devices
|
from modules import devices
|
||||||
|
@ -26,6 +26,13 @@ def fix_torch_version():
|
|||||||
torch.__long_version__ = torch.__version__
|
torch.__long_version__ = torch.__version__
|
||||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
|
|
||||||
|
def fix_pytorch_lightning():
|
||||||
|
# Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
|
||||||
|
if 'pytorch_lightning.utilities.distributed' not in sys.modules:
|
||||||
|
import pytorch_lightning
|
||||||
|
# Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
|
||||||
|
print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
|
||||||
|
sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
|
||||||
|
|
||||||
def fix_asyncio_event_loop_policy():
|
def fix_asyncio_event_loop_policy():
|
||||||
"""
|
"""
|
||||||
|
@ -12,7 +12,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
|
||||||
# use check `getattr` and try it for compatibility.
|
# use check `getattr` and try it for compatibility.
|
||||||
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
|
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,
|
||||||
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
|
||||||
def check_for_mps() -> bool:
|
def check_for_mps() -> bool:
|
||||||
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
if version.parse(torch.__version__) <= version.parse("2.0.1"):
|
||||||
|
@ -110,7 +110,7 @@ def load_upscalers():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
data = []
|
||||||
commandline_options = vars(shared.cmd_opts)
|
commandline_options = vars(shared.cmd_opts)
|
||||||
|
|
||||||
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||||
@ -129,10 +129,10 @@ def load_upscalers():
|
|||||||
scaler = cls(commandline_model_path)
|
scaler = cls(commandline_model_path)
|
||||||
scaler.user_path = commandline_model_path
|
scaler.user_path = commandline_model_path
|
||||||
scaler.model_download_path = commandline_model_path or scaler.model_path
|
scaler.model_download_path = commandline_model_path or scaler.model_path
|
||||||
datas += scaler.scalers
|
data += scaler.scalers
|
||||||
|
|
||||||
shared.sd_upscalers = sorted(
|
shared.sd_upscalers = sorted(
|
||||||
datas,
|
data,
|
||||||
# Special case for UpscalerNone keeps it at the beginning of the list.
|
# Special case for UpscalerNone keeps it at the beginning of the list.
|
||||||
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
|
||||||
)
|
)
|
||||||
|
@ -341,7 +341,7 @@ class DDPM(pl.LightningModule):
|
|||||||
elif self.parameterization == "x0":
|
elif self.parameterization == "x0":
|
||||||
target = x_start
|
target = x_start
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
|
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
|
||||||
|
|
||||||
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
|
||||||
|
|
||||||
@ -901,7 +901,7 @@ class LatentDiffusion(DDPM):
|
|||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
# hybrid case, cond is exptected to be a dict
|
# hybrid case, cond is expected to be a dict
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if not isinstance(cond, list):
|
if not isinstance(cond, list):
|
||||||
@ -937,7 +937,7 @@ class LatentDiffusion(DDPM):
|
|||||||
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
|
||||||
|
|
||||||
elif self.cond_stage_key == 'coordinates_bbox':
|
elif self.cond_stage_key == 'coordinates_bbox':
|
||||||
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
|
assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
|
||||||
|
|
||||||
# assuming padding of unfold is always 0 and its dilation is always 1
|
# assuming padding of unfold is always 0 and its dilation is always 1
|
||||||
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
|
||||||
@ -947,7 +947,7 @@ class LatentDiffusion(DDPM):
|
|||||||
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
num_downs = self.first_stage_model.encoder.num_resolutions - 1
|
||||||
rescale_latent = 2 ** (num_downs)
|
rescale_latent = 2 ** (num_downs)
|
||||||
|
|
||||||
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
|
# get top left positions of patches as conforming for the bbbox tokenizer, therefore we
|
||||||
# need to rescale the tl patch coordinates to be in between (0,1)
|
# need to rescale the tl patch coordinates to be in between (0,1)
|
||||||
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
|
||||||
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
|
||||||
|
@ -240,6 +240,9 @@ class Options:
|
|||||||
|
|
||||||
item_categories = {}
|
item_categories = {}
|
||||||
for item in self.data_labels.values():
|
for item in self.data_labels.values():
|
||||||
|
if item.section[0] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
category = categories.mapping.get(item.category_id)
|
category = categories.mapping.get(item.category_id)
|
||||||
category = "Uncategorized" if category is None else category.label
|
category = "Uncategorized" if category is None else category.label
|
||||||
if category not in item_categories:
|
if category not in item_categories:
|
||||||
|
@ -32,6 +32,6 @@ models_path = os.path.join(data_path, "models")
|
|||||||
extensions_dir = os.path.join(data_path, "extensions")
|
extensions_dir = os.path.join(data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||||
config_states_dir = os.path.join(script_path, "config_states")
|
config_states_dir = os.path.join(script_path, "config_states")
|
||||||
default_output_dir = os.path.join(data_path, "output")
|
default_output_dir = os.path.join(data_path, "outputs")
|
||||||
|
|
||||||
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||||
|
@ -20,10 +20,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = images.fix_image(img)
|
||||||
fn = ''
|
fn = ''
|
||||||
else:
|
else:
|
||||||
image = Image.open(os.path.abspath(img.name))
|
image = images.read(os.path.abspath(img.name))
|
||||||
fn = os.path.splitext(img.orig_name)[0]
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
yield image, fn
|
yield image, fn
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
@ -59,7 +59,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
if isinstance(image_placeholder, str):
|
if isinstance(image_placeholder, str):
|
||||||
try:
|
try:
|
||||||
image_data = Image.open(image_placeholder)
|
image_data = images.read(image_placeholder)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@ -69,7 +69,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
initial_pp = scripts_postprocessing.PostprocessedImage(image_data if image_data.mode in ("RGBA", "RGB") else image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(initial_pp, args)
|
scripts.scripts_postproc.run(initial_pp, args)
|
||||||
|
|
||||||
@ -125,8 +125,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(pp.image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
image_data.close()
|
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
@ -152,6 +152,7 @@ class StableDiffusionProcessing:
|
|||||||
seed_resize_from_w: int = -1
|
seed_resize_from_w: int = -1
|
||||||
seed_enable_extras: bool = True
|
seed_enable_extras: bool = True
|
||||||
sampler_name: str = None
|
sampler_name: str = None
|
||||||
|
scheduler: str = None
|
||||||
batch_size: int = 1
|
batch_size: int = 1
|
||||||
n_iter: int = 1
|
n_iter: int = 1
|
||||||
steps: int = 50
|
steps: int = 50
|
||||||
@ -702,7 +703,7 @@ def program_version():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None, all_hr_prompts=None, all_hr_negative_prompts=None):
|
||||||
if index is None:
|
if index is None:
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
@ -721,6 +722,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": p.sampler_name,
|
"Sampler": p.sampler_name,
|
||||||
|
"Schedule type": p.scheduler,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||||
@ -745,11 +747,18 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
"Tiling": "True" if p.tiling else None,
|
"Tiling": "True" if p.tiling else None,
|
||||||
|
"Hires prompt": None, # This is set later, insert here to keep order
|
||||||
|
"Hires negative prompt": None, # This is set later, insert here to keep order
|
||||||
**p.extra_generation_params,
|
**p.extra_generation_params,
|
||||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
"User": p.user if opts.add_user_name_to_info else None,
|
"User": p.user if opts.add_user_name_to_info else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if all_hr_prompts := all_hr_prompts or getattr(p, 'all_hr_prompts', None):
|
||||||
|
generation_params['Hires prompt'] = all_hr_prompts[index] if all_hr_prompts[index] != all_prompts[index] else None
|
||||||
|
if all_hr_negative_prompts := all_hr_negative_prompts or getattr(p, 'all_hr_negative_prompts', None):
|
||||||
|
generation_params['Hires negative prompt'] = all_hr_negative_prompts[index] if all_hr_negative_prompts[index] != all_negative_prompts[index] else None
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
|
prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
|
||||||
@ -896,22 +905,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||||
|
|
||||||
|
p.setup_conds()
|
||||||
|
|
||||||
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
# params.txt should be saved after scripts.process_batch, since the
|
# params.txt should be saved after scripts.process_batch, since the
|
||||||
# infotext could be modified by that callback
|
# infotext could be modified by that callback
|
||||||
# Example: a wildcard processed by process_batch sets an extra model
|
# Example: a wildcard processed by process_batch sets an extra model
|
||||||
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
# strength, which is saved as "Model Strength: 1.0" in the infotext
|
||||||
if n == 0:
|
if n == 0 and not cmd_opts.no_prompt_history:
|
||||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||||
processed = Processed(p, [])
|
processed = Processed(p, [])
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
p.setup_conds()
|
|
||||||
|
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
p.comment(comment)
|
p.comment(comment)
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
|
||||||
|
|
||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
@ -1106,6 +1115,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
hr_resize_y: int = 0
|
hr_resize_y: int = 0
|
||||||
hr_checkpoint_name: str = None
|
hr_checkpoint_name: str = None
|
||||||
hr_sampler_name: str = None
|
hr_sampler_name: str = None
|
||||||
|
hr_scheduler: str = None
|
||||||
hr_prompt: str = ''
|
hr_prompt: str = ''
|
||||||
hr_negative_prompt: str = ''
|
hr_negative_prompt: str = ''
|
||||||
force_task_id: str = None
|
force_task_id: str = None
|
||||||
@ -1194,11 +1204,10 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
||||||
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
||||||
|
|
||||||
if tuple(self.hr_prompt) != tuple(self.prompt):
|
self.extra_generation_params["Hires schedule type"] = None # to be set in sd_samplers_kdiffusion.py
|
||||||
self.extra_generation_params["Hires prompt"] = self.hr_prompt
|
|
||||||
|
|
||||||
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
if self.hr_scheduler is None:
|
||||||
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
self.hr_scheduler = self.scheduler
|
||||||
|
|
||||||
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||||
if self.enable_hr and self.latent_scale_mode is None:
|
if self.enable_hr and self.latent_scale_mode is None:
|
||||||
|
@ -26,6 +26,13 @@ class ScriptStripComments(scripts.Script):
|
|||||||
p.main_prompt = strip_comments(p.main_prompt)
|
p.main_prompt = strip_comments(p.main_prompt)
|
||||||
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
|
p.main_negative_prompt = strip_comments(p.main_negative_prompt)
|
||||||
|
|
||||||
|
if getattr(p, 'enable_hr', False):
|
||||||
|
p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts]
|
||||||
|
p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts]
|
||||||
|
|
||||||
|
p.hr_prompt = strip_comments(p.hr_prompt)
|
||||||
|
p.hr_negative_prompt = strip_comments(p.hr_negative_prompt)
|
||||||
|
|
||||||
|
|
||||||
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
|
def before_token_counter(params: script_callbacks.BeforeTokenCounterParams):
|
||||||
if not shared.opts.enable_prompt_comments:
|
if not shared.opts.enable_prompt_comments:
|
||||||
|
45
modules/processing_scripts/sampler.py
Normal file
45
modules/processing_scripts/sampler.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, sd_samplers, sd_schedulers, shared
|
||||||
|
from modules.infotext_utils import PasteField
|
||||||
|
from modules.ui_components import FormRow, FormGroup
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptSampler(scripts.ScriptBuiltinUI):
|
||||||
|
section = "sampler"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.steps = None
|
||||||
|
self.sampler_name = None
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Sampler"
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
sampler_names = [x.name for x in sd_samplers.visible_samplers()]
|
||||||
|
scheduler_names = [x.label for x in sd_schedulers.schedulers]
|
||||||
|
|
||||||
|
if shared.opts.samplers_in_dropdown:
|
||||||
|
with FormRow(elem_id=f"sampler_selection_{self.tabname}"):
|
||||||
|
self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||||
|
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||||
|
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||||
|
else:
|
||||||
|
with FormGroup(elem_id=f"sampler_selection_{self.tabname}"):
|
||||||
|
self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
|
||||||
|
self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
|
||||||
|
self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
PasteField(self.steps, "Steps", api="steps"),
|
||||||
|
PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"),
|
||||||
|
PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return self.steps, self.sampler_name, self.scheduler
|
||||||
|
|
||||||
|
def setup(self, p, steps, sampler_name, scheduler):
|
||||||
|
p.steps = steps
|
||||||
|
p.sampler_name = sampler_name
|
||||||
|
p.scheduler = scheduler
|
@ -34,7 +34,7 @@ def randn_local(seed, shape):
|
|||||||
|
|
||||||
|
|
||||||
def randn_like(x):
|
def randn_like(x):
|
||||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
|
||||||
|
|
||||||
Use either randn() or manual_seed() to initialize the generator."""
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ def randn_like(x):
|
|||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape, generator=None):
|
def randn_without_seed(shape, generator=None):
|
||||||
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
|
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
|
||||||
|
|
||||||
Use either randn() or manual_seed() to initialize the generator."""
|
Use either randn() or manual_seed() to initialize the generator."""
|
||||||
|
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from gradio import Blocks
|
from gradio import Blocks
|
||||||
|
|
||||||
from modules import errors, timer
|
from modules import errors, timer, extensions, shared, util
|
||||||
|
|
||||||
|
|
||||||
def report_exception(c, job):
|
def report_exception(c, job):
|
||||||
@ -116,7 +117,105 @@ class BeforeTokenCounterParams:
|
|||||||
is_positive: bool = True
|
is_positive: bool = True
|
||||||
|
|
||||||
|
|
||||||
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
|
@dataclasses.dataclass
|
||||||
|
class ScriptCallback:
|
||||||
|
script: str
|
||||||
|
callback: any
|
||||||
|
name: str = "unnamed"
|
||||||
|
|
||||||
|
|
||||||
|
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
|
||||||
|
if filename is None:
|
||||||
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
|
filename = stack[0].filename if stack else 'unknown file'
|
||||||
|
|
||||||
|
extension = extensions.find_extension(filename)
|
||||||
|
extension_name = extension.canonical_name if extension else 'base'
|
||||||
|
|
||||||
|
callback_name = f"{extension_name}/{os.path.basename(filename)}/{category}"
|
||||||
|
if name is not None:
|
||||||
|
callback_name += f'/{name}'
|
||||||
|
|
||||||
|
unique_callback_name = callback_name
|
||||||
|
for index in range(1000):
|
||||||
|
existing = any(x.name == unique_callback_name for x in callbacks)
|
||||||
|
if not existing:
|
||||||
|
break
|
||||||
|
|
||||||
|
unique_callback_name = f'{callback_name}-{index+1}'
|
||||||
|
|
||||||
|
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
|
||||||
|
|
||||||
|
|
||||||
|
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
|
||||||
|
callbacks = unordered_callbacks.copy()
|
||||||
|
callback_lookup = {x.name: x for x in callbacks}
|
||||||
|
dependencies = {}
|
||||||
|
|
||||||
|
order_instructions = {}
|
||||||
|
for extension in extensions.extensions:
|
||||||
|
for order_instruction in extension.metadata.list_callback_order_instructions():
|
||||||
|
if order_instruction.name in callback_lookup:
|
||||||
|
if order_instruction.name not in order_instructions:
|
||||||
|
order_instructions[order_instruction.name] = []
|
||||||
|
|
||||||
|
order_instructions[order_instruction.name].append(order_instruction)
|
||||||
|
|
||||||
|
if order_instructions:
|
||||||
|
for callback in callbacks:
|
||||||
|
dependencies[callback.name] = []
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
for order_instruction in order_instructions.get(callback.name, []):
|
||||||
|
for after in order_instruction.after:
|
||||||
|
if after not in callback_lookup:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependencies[callback.name].append(after)
|
||||||
|
|
||||||
|
for before in order_instruction.before:
|
||||||
|
if before not in callback_lookup:
|
||||||
|
continue
|
||||||
|
|
||||||
|
dependencies[before].append(callback.name)
|
||||||
|
|
||||||
|
sorted_names = util.topological_sort(dependencies)
|
||||||
|
callbacks = [callback_lookup[x] for x in sorted_names]
|
||||||
|
|
||||||
|
if enable_user_sort:
|
||||||
|
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
|
||||||
|
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
|
||||||
|
if index is not None:
|
||||||
|
callbacks.insert(0, callbacks.pop(index))
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
|
||||||
|
if unordered_callbacks is None:
|
||||||
|
unordered_callbacks = callback_map.get('callbacks_' + category, [])
|
||||||
|
|
||||||
|
if not enable_user_sort:
|
||||||
|
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
|
||||||
|
|
||||||
|
callbacks = ordered_callbacks_map.get(category)
|
||||||
|
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
callbacks = sort_callbacks(category, unordered_callbacks)
|
||||||
|
|
||||||
|
ordered_callbacks_map[category] = callbacks
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def enumerate_callbacks():
|
||||||
|
for category, callbacks in callback_map.items():
|
||||||
|
if category.startswith('callbacks_'):
|
||||||
|
category = category[10:]
|
||||||
|
|
||||||
|
yield category, callbacks
|
||||||
|
|
||||||
|
|
||||||
callback_map = dict(
|
callback_map = dict(
|
||||||
callbacks_app_started=[],
|
callbacks_app_started=[],
|
||||||
callbacks_model_loaded=[],
|
callbacks_model_loaded=[],
|
||||||
@ -141,14 +240,18 @@ callback_map = dict(
|
|||||||
callbacks_before_token_counter=[],
|
callbacks_before_token_counter=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ordered_callbacks_map = {}
|
||||||
|
|
||||||
|
|
||||||
def clear_callbacks():
|
def clear_callbacks():
|
||||||
for callback_list in callback_map.values():
|
for callback_list in callback_map.values():
|
||||||
callback_list.clear()
|
callback_list.clear()
|
||||||
|
|
||||||
|
ordered_callbacks_map.clear()
|
||||||
|
|
||||||
|
|
||||||
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
||||||
for c in callback_map['callbacks_app_started']:
|
for c in ordered_callbacks('app_started'):
|
||||||
try:
|
try:
|
||||||
c.callback(demo, app)
|
c.callback(demo, app)
|
||||||
timer.startup_timer.record(os.path.basename(c.script))
|
timer.startup_timer.record(os.path.basename(c.script))
|
||||||
@ -157,7 +260,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
|||||||
|
|
||||||
|
|
||||||
def app_reload_callback():
|
def app_reload_callback():
|
||||||
for c in callback_map['callbacks_on_reload']:
|
for c in ordered_callbacks('on_reload'):
|
||||||
try:
|
try:
|
||||||
c.callback()
|
c.callback()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -165,7 +268,7 @@ def app_reload_callback():
|
|||||||
|
|
||||||
|
|
||||||
def model_loaded_callback(sd_model):
|
def model_loaded_callback(sd_model):
|
||||||
for c in callback_map['callbacks_model_loaded']:
|
for c in ordered_callbacks('model_loaded'):
|
||||||
try:
|
try:
|
||||||
c.callback(sd_model)
|
c.callback(sd_model)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -175,7 +278,7 @@ def model_loaded_callback(sd_model):
|
|||||||
def ui_tabs_callback():
|
def ui_tabs_callback():
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for c in callback_map['callbacks_ui_tabs']:
|
for c in ordered_callbacks('ui_tabs'):
|
||||||
try:
|
try:
|
||||||
res += c.callback() or []
|
res += c.callback() or []
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -185,7 +288,7 @@ def ui_tabs_callback():
|
|||||||
|
|
||||||
|
|
||||||
def ui_train_tabs_callback(params: UiTrainTabParams):
|
def ui_train_tabs_callback(params: UiTrainTabParams):
|
||||||
for c in callback_map['callbacks_ui_train_tabs']:
|
for c in ordered_callbacks('ui_train_tabs'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -193,7 +296,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
|
|||||||
|
|
||||||
|
|
||||||
def ui_settings_callback():
|
def ui_settings_callback():
|
||||||
for c in callback_map['callbacks_ui_settings']:
|
for c in ordered_callbacks('ui_settings'):
|
||||||
try:
|
try:
|
||||||
c.callback()
|
c.callback()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -201,7 +304,7 @@ def ui_settings_callback():
|
|||||||
|
|
||||||
|
|
||||||
def before_image_saved_callback(params: ImageSaveParams):
|
def before_image_saved_callback(params: ImageSaveParams):
|
||||||
for c in callback_map['callbacks_before_image_saved']:
|
for c in ordered_callbacks('before_image_saved'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -209,7 +312,7 @@ def before_image_saved_callback(params: ImageSaveParams):
|
|||||||
|
|
||||||
|
|
||||||
def image_saved_callback(params: ImageSaveParams):
|
def image_saved_callback(params: ImageSaveParams):
|
||||||
for c in callback_map['callbacks_image_saved']:
|
for c in ordered_callbacks('image_saved'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -217,7 +320,7 @@ def image_saved_callback(params: ImageSaveParams):
|
|||||||
|
|
||||||
|
|
||||||
def extra_noise_callback(params: ExtraNoiseParams):
|
def extra_noise_callback(params: ExtraNoiseParams):
|
||||||
for c in callback_map['callbacks_extra_noise']:
|
for c in ordered_callbacks('extra_noise'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -225,7 +328,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
|
|||||||
|
|
||||||
|
|
||||||
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
def cfg_denoiser_callback(params: CFGDenoiserParams):
|
||||||
for c in callback_map['callbacks_cfg_denoiser']:
|
for c in ordered_callbacks('cfg_denoiser'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -233,7 +336,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
|
|||||||
|
|
||||||
|
|
||||||
def cfg_denoised_callback(params: CFGDenoisedParams):
|
def cfg_denoised_callback(params: CFGDenoisedParams):
|
||||||
for c in callback_map['callbacks_cfg_denoised']:
|
for c in ordered_callbacks('cfg_denoised'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -241,7 +344,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
|
|||||||
|
|
||||||
|
|
||||||
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||||
for c in callback_map['callbacks_cfg_after_cfg']:
|
for c in ordered_callbacks('cfg_after_cfg'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -249,7 +352,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
|||||||
|
|
||||||
|
|
||||||
def before_component_callback(component, **kwargs):
|
def before_component_callback(component, **kwargs):
|
||||||
for c in callback_map['callbacks_before_component']:
|
for c in ordered_callbacks('before_component'):
|
||||||
try:
|
try:
|
||||||
c.callback(component, **kwargs)
|
c.callback(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -257,7 +360,7 @@ def before_component_callback(component, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def after_component_callback(component, **kwargs):
|
def after_component_callback(component, **kwargs):
|
||||||
for c in callback_map['callbacks_after_component']:
|
for c in ordered_callbacks('after_component'):
|
||||||
try:
|
try:
|
||||||
c.callback(component, **kwargs)
|
c.callback(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -265,7 +368,7 @@ def after_component_callback(component, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def image_grid_callback(params: ImageGridLoopParams):
|
def image_grid_callback(params: ImageGridLoopParams):
|
||||||
for c in callback_map['callbacks_image_grid']:
|
for c in ordered_callbacks('image_grid'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -273,7 +376,7 @@ def image_grid_callback(params: ImageGridLoopParams):
|
|||||||
|
|
||||||
|
|
||||||
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
||||||
for c in callback_map['callbacks_infotext_pasted']:
|
for c in ordered_callbacks('infotext_pasted'):
|
||||||
try:
|
try:
|
||||||
c.callback(infotext, params)
|
c.callback(infotext, params)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -281,7 +384,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
|||||||
|
|
||||||
|
|
||||||
def script_unloaded_callback():
|
def script_unloaded_callback():
|
||||||
for c in reversed(callback_map['callbacks_script_unloaded']):
|
for c in reversed(ordered_callbacks('script_unloaded')):
|
||||||
try:
|
try:
|
||||||
c.callback()
|
c.callback()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -289,7 +392,7 @@ def script_unloaded_callback():
|
|||||||
|
|
||||||
|
|
||||||
def before_ui_callback():
|
def before_ui_callback():
|
||||||
for c in reversed(callback_map['callbacks_before_ui']):
|
for c in reversed(ordered_callbacks('before_ui')):
|
||||||
try:
|
try:
|
||||||
c.callback()
|
c.callback()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -299,7 +402,7 @@ def before_ui_callback():
|
|||||||
def list_optimizers_callback():
|
def list_optimizers_callback():
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for c in callback_map['callbacks_list_optimizers']:
|
for c in ordered_callbacks('list_optimizers'):
|
||||||
try:
|
try:
|
||||||
c.callback(res)
|
c.callback(res)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -311,7 +414,7 @@ def list_optimizers_callback():
|
|||||||
def list_unets_callback():
|
def list_unets_callback():
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for c in callback_map['callbacks_list_unets']:
|
for c in ordered_callbacks('list_unets'):
|
||||||
try:
|
try:
|
||||||
c.callback(res)
|
c.callback(res)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -321,20 +424,13 @@ def list_unets_callback():
|
|||||||
|
|
||||||
|
|
||||||
def before_token_counter_callback(params: BeforeTokenCounterParams):
|
def before_token_counter_callback(params: BeforeTokenCounterParams):
|
||||||
for c in callback_map['callbacks_before_token_counter']:
|
for c in ordered_callbacks('before_token_counter'):
|
||||||
try:
|
try:
|
||||||
c.callback(params)
|
c.callback(params)
|
||||||
except Exception:
|
except Exception:
|
||||||
report_exception(c, 'before_token_counter')
|
report_exception(c, 'before_token_counter')
|
||||||
|
|
||||||
|
|
||||||
def add_callback(callbacks, fun):
|
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
|
||||||
filename = stack[0].filename if stack else 'unknown file'
|
|
||||||
|
|
||||||
callbacks.append(ScriptCallback(filename, fun))
|
|
||||||
|
|
||||||
|
|
||||||
def remove_current_script_callbacks():
|
def remove_current_script_callbacks():
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if stack else 'unknown file'
|
filename = stack[0].filename if stack else 'unknown file'
|
||||||
@ -351,24 +447,24 @@ def remove_callbacks_for_function(callback_func):
|
|||||||
callback_list.remove(callback_to_remove)
|
callback_list.remove(callback_to_remove)
|
||||||
|
|
||||||
|
|
||||||
def on_app_started(callback):
|
def on_app_started(callback, *, name=None):
|
||||||
"""register a function to be called when the webui started, the gradio `Block` component and
|
"""register a function to be called when the webui started, the gradio `Block` component and
|
||||||
fastapi `FastAPI` object are passed as the arguments"""
|
fastapi `FastAPI` object are passed as the arguments"""
|
||||||
add_callback(callback_map['callbacks_app_started'], callback)
|
add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')
|
||||||
|
|
||||||
|
|
||||||
def on_before_reload(callback):
|
def on_before_reload(callback, *, name=None):
|
||||||
"""register a function to be called just before the server reloads."""
|
"""register a function to be called just before the server reloads."""
|
||||||
add_callback(callback_map['callbacks_on_reload'], callback)
|
add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')
|
||||||
|
|
||||||
|
|
||||||
def on_model_loaded(callback):
|
def on_model_loaded(callback, *, name=None):
|
||||||
"""register a function to be called when the stable diffusion model is created; the model is
|
"""register a function to be called when the stable diffusion model is created; the model is
|
||||||
passed as an argument; this function is also called when the script is reloaded. """
|
passed as an argument; this function is also called when the script is reloaded. """
|
||||||
add_callback(callback_map['callbacks_model_loaded'], callback)
|
add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')
|
||||||
|
|
||||||
|
|
||||||
def on_ui_tabs(callback):
|
def on_ui_tabs(callback, *, name=None):
|
||||||
"""register a function to be called when the UI is creating new tabs.
|
"""register a function to be called when the UI is creating new tabs.
|
||||||
The function must either return a None, which means no new tabs to be added, or a list, where
|
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||||
each element is a tuple:
|
each element is a tuple:
|
||||||
@ -378,71 +474,71 @@ def on_ui_tabs(callback):
|
|||||||
title is tab text displayed to user in the UI
|
title is tab text displayed to user in the UI
|
||||||
elem_id is HTML id for the tab
|
elem_id is HTML id for the tab
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_ui_tabs'], callback)
|
add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')
|
||||||
|
|
||||||
|
|
||||||
def on_ui_train_tabs(callback):
|
def on_ui_train_tabs(callback, *, name=None):
|
||||||
"""register a function to be called when the UI is creating new tabs for the train tab.
|
"""register a function to be called when the UI is creating new tabs for the train tab.
|
||||||
Create your new tabs with gr.Tab.
|
Create your new tabs with gr.Tab.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_ui_train_tabs'], callback)
|
add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')
|
||||||
|
|
||||||
|
|
||||||
def on_ui_settings(callback):
|
def on_ui_settings(callback, *, name=None):
|
||||||
"""register a function to be called before UI settings are populated; add your settings
|
"""register a function to be called before UI settings are populated; add your settings
|
||||||
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
by using shared.opts.add_option(shared.OptionInfo(...)) """
|
||||||
add_callback(callback_map['callbacks_ui_settings'], callback)
|
add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')
|
||||||
|
|
||||||
|
|
||||||
def on_before_image_saved(callback):
|
def on_before_image_saved(callback, *, name=None):
|
||||||
"""register a function to be called before an image is saved to a file.
|
"""register a function to be called before an image is saved to a file.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_before_image_saved'], callback)
|
add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')
|
||||||
|
|
||||||
|
|
||||||
def on_image_saved(callback):
|
def on_image_saved(callback, *, name=None):
|
||||||
"""register a function to be called after an image is saved to a file.
|
"""register a function to be called after an image is saved to a file.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_image_saved'], callback)
|
add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')
|
||||||
|
|
||||||
|
|
||||||
def on_extra_noise(callback):
|
def on_extra_noise(callback, *, name=None):
|
||||||
"""register a function to be called before adding extra noise in img2img or hires fix;
|
"""register a function to be called before adding extra noise in img2img or hires fix;
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
- params: ExtraNoiseParams - contains noise determined by seed and latent representation of image
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_extra_noise'], callback)
|
add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')
|
||||||
|
|
||||||
|
|
||||||
def on_cfg_denoiser(callback):
|
def on_cfg_denoiser(callback, *, name=None):
|
||||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
|
add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')
|
||||||
|
|
||||||
|
|
||||||
def on_cfg_denoised(callback):
|
def on_cfg_denoised(callback, *, name=None):
|
||||||
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
|
- params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')
|
||||||
|
|
||||||
|
|
||||||
def on_cfg_after_cfg(callback):
|
def on_cfg_after_cfg(callback, *, name=None):
|
||||||
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
|
add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')
|
||||||
|
|
||||||
|
|
||||||
def on_before_component(callback):
|
def on_before_component(callback, *, name=None):
|
||||||
"""register a function to be called before a component is created.
|
"""register a function to be called before a component is created.
|
||||||
The callback is called with arguments:
|
The callback is called with arguments:
|
||||||
- component - gradio component that is about to be created.
|
- component - gradio component that is about to be created.
|
||||||
@ -451,61 +547,61 @@ def on_before_component(callback):
|
|||||||
Use elem_id/label fields of kwargs to figure out which component it is.
|
Use elem_id/label fields of kwargs to figure out which component it is.
|
||||||
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
This can be useful to inject your own components somewhere in the middle of vanilla UI.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_before_component'], callback)
|
add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')
|
||||||
|
|
||||||
|
|
||||||
def on_after_component(callback):
|
def on_after_component(callback, *, name=None):
|
||||||
"""register a function to be called after a component is created. See on_before_component for more."""
|
"""register a function to be called after a component is created. See on_before_component for more."""
|
||||||
add_callback(callback_map['callbacks_after_component'], callback)
|
add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')
|
||||||
|
|
||||||
|
|
||||||
def on_image_grid(callback):
|
def on_image_grid(callback, *, name=None):
|
||||||
"""register a function to be called before making an image grid.
|
"""register a function to be called before making an image grid.
|
||||||
The callback is called with one argument:
|
The callback is called with one argument:
|
||||||
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
- params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_image_grid'], callback)
|
add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')
|
||||||
|
|
||||||
|
|
||||||
def on_infotext_pasted(callback):
|
def on_infotext_pasted(callback, *, name=None):
|
||||||
"""register a function to be called before applying an infotext.
|
"""register a function to be called before applying an infotext.
|
||||||
The callback is called with two arguments:
|
The callback is called with two arguments:
|
||||||
- infotext: str - raw infotext.
|
- infotext: str - raw infotext.
|
||||||
- result: dict[str, any] - parsed infotext parameters.
|
- result: dict[str, any] - parsed infotext parameters.
|
||||||
"""
|
"""
|
||||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')
|
||||||
|
|
||||||
|
|
||||||
def on_script_unloaded(callback):
|
def on_script_unloaded(callback, *, name=None):
|
||||||
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
"""register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
|
||||||
the script did should be reverted here"""
|
the script did should be reverted here"""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_script_unloaded'], callback)
|
add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')
|
||||||
|
|
||||||
|
|
||||||
def on_before_ui(callback):
|
def on_before_ui(callback, *, name=None):
|
||||||
"""register a function to be called before the UI is created."""
|
"""register a function to be called before the UI is created."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_before_ui'], callback)
|
add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')
|
||||||
|
|
||||||
|
|
||||||
def on_list_optimizers(callback):
|
def on_list_optimizers(callback, *, name=None):
|
||||||
"""register a function to be called when UI is making a list of cross attention optimization options.
|
"""register a function to be called when UI is making a list of cross attention optimization options.
|
||||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
|
||||||
to it."""
|
to it."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')
|
||||||
|
|
||||||
|
|
||||||
def on_list_unets(callback):
|
def on_list_unets(callback, *, name=None):
|
||||||
"""register a function to be called when UI is making a list of alternative options for unet.
|
"""register a function to be called when UI is making a list of alternative options for unet.
|
||||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_list_unets'], callback)
|
add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')
|
||||||
|
|
||||||
|
|
||||||
def on_before_token_counter(callback):
|
def on_before_token_counter(callback, *, name=None):
|
||||||
"""register a function to be called when UI is counting tokens for a prompt.
|
"""register a function to be called when UI is counting tokens for a prompt.
|
||||||
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
|
The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_before_token_counter'], callback)
|
add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')
|
||||||
|
@ -7,7 +7,9 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
|
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util
|
||||||
|
|
||||||
|
topological_sort = util.topological_sort
|
||||||
|
|
||||||
AlwaysVisible = object()
|
AlwaysVisible = object()
|
||||||
|
|
||||||
@ -92,7 +94,7 @@ class Script:
|
|||||||
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
"""If true, the script setup will only be run in Gradio UI, not in API"""
|
||||||
|
|
||||||
controls = None
|
controls = None
|
||||||
"""A list of controls retured by the ui()."""
|
"""A list of controls returned by the ui()."""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
@ -109,7 +111,7 @@ class Script:
|
|||||||
|
|
||||||
def show(self, is_img2img):
|
def show(self, is_img2img):
|
||||||
"""
|
"""
|
||||||
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
|
is_img2img is True if this function is called for the img2img interface, and False otherwise
|
||||||
|
|
||||||
This function should return:
|
This function should return:
|
||||||
- False if the script should not be shown in UI at all
|
- False if the script should not be shown in UI at all
|
||||||
@ -138,7 +140,6 @@ class Script:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def before_process(self, p, *args):
|
def before_process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called very early during processing begins for AlwaysVisible scripts.
|
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||||
@ -351,6 +352,9 @@ class ScriptBuiltinUI(Script):
|
|||||||
|
|
||||||
return f'{tabname}{item_id}'
|
return f'{tabname}{item_id}'
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return AlwaysVisible
|
||||||
|
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
@ -369,29 +373,6 @@ scripts_data = []
|
|||||||
postprocessing_scripts_data = []
|
postprocessing_scripts_data = []
|
||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
def topological_sort(dependencies):
|
|
||||||
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
|
||||||
Ignores errors relating to missing dependeencies or circular dependencies
|
|
||||||
"""
|
|
||||||
|
|
||||||
visited = {}
|
|
||||||
result = []
|
|
||||||
|
|
||||||
def inner(name):
|
|
||||||
visited[name] = True
|
|
||||||
|
|
||||||
for dep in dependencies.get(name, []):
|
|
||||||
if dep in dependencies and dep not in visited:
|
|
||||||
inner(dep)
|
|
||||||
|
|
||||||
result.append(name)
|
|
||||||
|
|
||||||
for depname in dependencies:
|
|
||||||
if depname not in visited:
|
|
||||||
inner(depname)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScriptWithDependencies:
|
class ScriptWithDependencies:
|
||||||
@ -562,6 +543,25 @@ class ScriptRunner:
|
|||||||
self.paste_field_names = []
|
self.paste_field_names = []
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
|
self.callback_map = {}
|
||||||
|
self.callback_names = [
|
||||||
|
'before_process',
|
||||||
|
'process',
|
||||||
|
'before_process_batch',
|
||||||
|
'after_extra_networks_activate',
|
||||||
|
'process_batch',
|
||||||
|
'postprocess',
|
||||||
|
'postprocess_batch',
|
||||||
|
'postprocess_batch_list',
|
||||||
|
'post_sample',
|
||||||
|
'on_mask_blend',
|
||||||
|
'postprocess_image',
|
||||||
|
'postprocess_maskoverlay',
|
||||||
|
'postprocess_image_after_composite',
|
||||||
|
'before_component',
|
||||||
|
'after_component',
|
||||||
|
]
|
||||||
|
|
||||||
self.on_before_component_elem_id = {}
|
self.on_before_component_elem_id = {}
|
||||||
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
@ -600,6 +600,8 @@ class ScriptRunner:
|
|||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
|
self.callback_map.clear()
|
||||||
|
|
||||||
self.apply_on_before_component_callbacks()
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
def apply_on_before_component_callbacks(self):
|
def apply_on_before_component_callbacks(self):
|
||||||
@ -769,8 +771,42 @@ class ScriptRunner:
|
|||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
def list_scripts_for_method(self, method_name):
|
||||||
|
if method_name in ('before_component', 'after_component'):
|
||||||
|
return self.scripts
|
||||||
|
else:
|
||||||
|
return self.alwayson_scripts
|
||||||
|
|
||||||
|
def create_ordered_callbacks_list(self, method_name, *, enable_user_sort=True):
|
||||||
|
script_list = self.list_scripts_for_method(method_name)
|
||||||
|
category = f'script_{method_name}'
|
||||||
|
callbacks = []
|
||||||
|
|
||||||
|
for script in script_list:
|
||||||
|
if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):
|
||||||
|
continue
|
||||||
|
|
||||||
|
script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)
|
||||||
|
|
||||||
|
return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)
|
||||||
|
|
||||||
|
def ordered_callbacks(self, method_name, *, enable_user_sort=True):
|
||||||
|
script_list = self.list_scripts_for_method(method_name)
|
||||||
|
category = f'script_{method_name}'
|
||||||
|
|
||||||
|
scrpts_len, callbacks = self.callback_map.get(category, (-1, None))
|
||||||
|
|
||||||
|
if callbacks is None or scrpts_len != len(script_list):
|
||||||
|
callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)
|
||||||
|
self.callback_map[category] = len(script_list), callbacks
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
def ordered_scripts(self, method_name):
|
||||||
|
return [x.callback for x in self.ordered_callbacks(method_name)]
|
||||||
|
|
||||||
def before_process(self, p):
|
def before_process(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('before_process'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.before_process(p, *script_args)
|
script.before_process(p, *script_args)
|
||||||
@ -778,7 +814,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_process: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_process: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def process(self, p):
|
def process(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('process'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process(p, *script_args)
|
script.process(p, *script_args)
|
||||||
@ -786,7 +822,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running process: {script.filename}", exc_info=True)
|
errors.report(f"Error running process: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def before_process_batch(self, p, **kwargs):
|
def before_process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('before_process_batch'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.before_process_batch(p, *script_args, **kwargs)
|
script.before_process_batch(p, *script_args, **kwargs)
|
||||||
@ -794,7 +830,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def after_extra_networks_activate(self, p, **kwargs):
|
def after_extra_networks_activate(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('after_extra_networks_activate'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.after_extra_networks_activate(p, *script_args, **kwargs)
|
script.after_extra_networks_activate(p, *script_args, **kwargs)
|
||||||
@ -802,7 +838,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
|
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def process_batch(self, p, **kwargs):
|
def process_batch(self, p, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('process_batch'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.process_batch(p, *script_args, **kwargs)
|
script.process_batch(p, *script_args, **kwargs)
|
||||||
@ -810,7 +846,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess(self, p, processed):
|
def postprocess(self, p, processed):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess(p, processed, *script_args)
|
script.postprocess(p, processed, *script_args)
|
||||||
@ -818,7 +854,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_batch(self, p, images, **kwargs):
|
def postprocess_batch(self, p, images, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess_batch'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
||||||
@ -826,7 +862,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess_batch_list'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
|
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
|
||||||
@ -834,7 +870,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def post_sample(self, p, ps: PostSampleArgs):
|
def post_sample(self, p, ps: PostSampleArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('post_sample'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.post_sample(p, ps, *script_args)
|
script.post_sample(p, ps, *script_args)
|
||||||
@ -842,7 +878,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def on_mask_blend(self, p, mba: MaskBlendArgs):
|
def on_mask_blend(self, p, mba: MaskBlendArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('on_mask_blend'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.on_mask_blend(p, mba, *script_args)
|
script.on_mask_blend(p, mba, *script_args)
|
||||||
@ -850,7 +886,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess_image'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_image(p, pp, *script_args)
|
script.postprocess_image(p, pp, *script_args)
|
||||||
@ -858,7 +894,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
|
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess_maskoverlay'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_maskoverlay(p, ppmo, *script_args)
|
script.postprocess_maskoverlay(p, ppmo, *script_args)
|
||||||
@ -866,7 +902,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
|
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('postprocess_image_after_composite'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.postprocess_image_after_composite(p, pp, *script_args)
|
script.postprocess_image_after_composite(p, pp, *script_args)
|
||||||
@ -880,7 +916,7 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.ordered_scripts('before_component'):
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -893,7 +929,7 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.ordered_scripts('after_component'):
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -921,7 +957,7 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
def before_hr(self, p):
|
def before_hr(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('before_hr'):
|
||||||
try:
|
try:
|
||||||
script_args = p.script_args[script.args_from:script.args_to]
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
script.before_hr(p, *script_args)
|
script.before_hr(p, *script_args)
|
||||||
@ -929,7 +965,7 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def setup_scrips(self, p, *, is_ui=True):
|
def setup_scrips(self, p, *, is_ui=True):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.ordered_scripts('setup'):
|
||||||
if not is_ui and script.setup_for_ui_only:
|
if not is_ui and script.setup_for_ui_only:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ class EmphasisIgnore(Emphasis):
|
|||||||
|
|
||||||
class EmphasisOriginal(Emphasis):
|
class EmphasisOriginal(Emphasis):
|
||||||
name = "Original"
|
name = "Original"
|
||||||
description = "the orginal emphasis implementation"
|
description = "the original emphasis implementation"
|
||||||
|
|
||||||
def after_transformers(self):
|
def after_transformers(self):
|
||||||
original_mean = self.z.mean()
|
original_mean = self.z.mean()
|
||||||
@ -48,7 +48,7 @@ class EmphasisOriginal(Emphasis):
|
|||||||
|
|
||||||
class EmphasisOriginalNoNorm(EmphasisOriginal):
|
class EmphasisOriginalNoNorm(EmphasisOriginal):
|
||||||
name = "No norm"
|
name = "No norm"
|
||||||
description = "same as orginal, but without normalization (seems to work better for SDXL)"
|
description = "same as original, but without normalization (seems to work better for SDXL)"
|
||||||
|
|
||||||
def after_transformers(self):
|
def after_transformers(self):
|
||||||
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)
|
||||||
|
@ -23,7 +23,7 @@ class PromptChunk:
|
|||||||
|
|
||||||
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
|
||||||
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
|
||||||
chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
|
||||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||||
|
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
|
|
||||||
def encode_with_transformers(self, tokens):
|
def encode_with_transformers(self, tokens):
|
||||||
"""
|
"""
|
||||||
converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
|
converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;
|
||||||
All python lists with tokens are assumed to have same length, usually 77.
|
All python lists with tokens are assumed to have same length, usually 77.
|
||||||
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
|
||||||
model - can be 768 and 1024.
|
model - can be 768 and 1024.
|
||||||
@ -136,7 +136,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
if token == self.comma_token:
|
if token == self.comma_token:
|
||||||
last_comma = len(chunk.tokens)
|
last_comma = len(chunk.tokens)
|
||||||
|
|
||||||
# this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
# this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
|
||||||
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
# is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
|
||||||
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||||
break_location = last_comma + 1
|
break_location = last_comma + 1
|
||||||
@ -206,7 +206,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||||
An example shape returned by this function can be: (2, 77, 768).
|
An example shape returned by this function can be: (2, 77, 768).
|
||||||
For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
|
For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
|
||||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
|
||||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
for fixes in self.hijack.fixes:
|
for fixes in self.hijack.fixes:
|
||||||
for _position, embedding in fixes:
|
for _position, embedding in fixes:
|
||||||
used_embeddings[embedding.name] = embedding
|
used_embeddings[embedding.name] = embedding
|
||||||
|
devices.torch_npu_set_device()
|
||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
zs.append(z)
|
zs.append(z)
|
||||||
|
|
||||||
|
@ -784,7 +784,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||||||
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||||
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||||
If no such model exists, returns None.
|
If no such model exists, returns None.
|
||||||
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
already_loaded = None
|
already_loaded = None
|
||||||
|
@ -13,8 +13,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
|
|||||||
for embedder in self.conditioner.embedders:
|
for embedder in self.conditioner.embedders:
|
||||||
embedder.ucg_rate = 0.0
|
embedder.ucg_rate = 0.0
|
||||||
|
|
||||||
width = getattr(batch, 'width', 1024)
|
width = getattr(batch, 'width', 1024) or 1024
|
||||||
height = getattr(batch, 'height', 1024)
|
height = getattr(batch, 'height', 1024) or 1024
|
||||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||||
|
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
|
||||||
|
sample_to_image = sd_samplers_common.sample_to_image
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
@ -10,8 +15,8 @@ all_samplers = [
|
|||||||
]
|
]
|
||||||
all_samplers_map = {x.name: x for x in all_samplers}
|
all_samplers_map = {x.name: x for x in all_samplers}
|
||||||
|
|
||||||
samplers = []
|
samplers: list[sd_samplers_common.SamplerData] = []
|
||||||
samplers_for_img2img = []
|
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
|
||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
samplers_hidden = {}
|
samplers_hidden = {}
|
||||||
|
|
||||||
@ -57,4 +62,64 @@ def visible_sampler_names():
|
|||||||
return [x.name for x in samplers if x.name not in samplers_hidden]
|
return [x.name for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def visible_samplers():
|
||||||
|
return [x for x in samplers if x.name not in samplers_hidden]
|
||||||
|
|
||||||
|
|
||||||
|
def get_sampler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler_from_infotext(d: dict):
|
||||||
|
return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_sampler_and_scheduler(d: dict):
|
||||||
|
hr_sampler = d.get("Hires sampler", "Use same sampler")
|
||||||
|
sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
|
||||||
|
|
||||||
|
hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
|
||||||
|
scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
|
||||||
|
|
||||||
|
sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
|
||||||
|
|
||||||
|
sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
|
||||||
|
scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
|
||||||
|
|
||||||
|
return sampler, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_sampler_from_infotext(d: dict):
|
||||||
|
return get_hr_sampler_and_scheduler(d)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hr_scheduler_from_infotext(d: dict):
|
||||||
|
return get_hr_sampler_and_scheduler(d)[1]
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_sampler_and_scheduler(sampler_name, scheduler_name):
|
||||||
|
default_sampler = samplers[0]
|
||||||
|
found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
|
||||||
|
|
||||||
|
name = sampler_name or default_sampler.name
|
||||||
|
|
||||||
|
for scheduler in sd_schedulers.schedulers:
|
||||||
|
name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
|
||||||
|
|
||||||
|
for name_option in name_options:
|
||||||
|
if name.endswith(" " + name_option):
|
||||||
|
found_scheduler = scheduler
|
||||||
|
name = name[0:-(len(name_option) + 1)]
|
||||||
|
break
|
||||||
|
|
||||||
|
sampler = all_samplers_map.get(name, default_sampler)
|
||||||
|
|
||||||
|
# revert back to Automatic if it's the default scheduler for the selected sampler
|
||||||
|
if sampler.options.get('scheduler', None) == found_scheduler.name:
|
||||||
|
found_scheduler = sd_schedulers.schedulers[0]
|
||||||
|
|
||||||
|
return sampler.name, found_scheduler.label
|
||||||
|
|
||||||
|
|
||||||
set_samplers()
|
set_samplers()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers
|
||||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||||
|
|
||||||
@ -9,32 +9,20 @@ from modules.shared import opts
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
|
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
|
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
|
||||||
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
|
||||||
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
|
||||||
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -58,12 +46,7 @@ sampler_extra_params = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
k_diffusion_scheduler = {
|
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
|
||||||
'Automatic': None,
|
|
||||||
'karras': k_diffusion.sampling.get_sigmas_karras,
|
|
||||||
'exponential': k_diffusion.sampling.get_sigmas_exponential,
|
|
||||||
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
||||||
@ -96,42 +79,43 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
|
|
||||||
steps += 1 if discard_next_to_last_sigma else 0
|
steps += 1 if discard_next_to_last_sigma else 0
|
||||||
|
|
||||||
|
scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'
|
||||||
|
if scheduler_name == 'Automatic':
|
||||||
|
scheduler_name = self.config.options.get('scheduler', None)
|
||||||
|
|
||||||
|
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
|
||||||
|
|
||||||
|
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
|
||||||
|
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
elif opts.k_sched_type != "Automatic":
|
elif scheduler is None or scheduler.function is None:
|
||||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
|
else:
|
||||||
sigmas_kwargs = {
|
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
|
||||||
'sigma_min': sigma_min,
|
|
||||||
'sigma_max': sigma_max,
|
|
||||||
}
|
|
||||||
|
|
||||||
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
|
if scheduler.label != 'Automatic' and not p.is_hr_pass:
|
||||||
p.extra_generation_params["Schedule type"] = opts.k_sched_type
|
p.extra_generation_params["Schedule type"] = scheduler.label
|
||||||
|
elif scheduler.label != p.extra_generation_params.get("Schedule type"):
|
||||||
|
p.extra_generation_params["Hires schedule type"] = scheduler.label
|
||||||
|
|
||||||
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
|
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
|
||||||
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
sigmas_kwargs['sigma_min'] = opts.sigma_min
|
||||||
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
|
||||||
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
|
|
||||||
|
if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
|
||||||
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
sigmas_kwargs['sigma_max'] = opts.sigma_max
|
||||||
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
|
||||||
|
|
||||||
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
|
if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
|
||||||
|
|
||||||
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
|
|
||||||
sigmas_kwargs['rho'] = opts.rho
|
sigmas_kwargs['rho'] = opts.rho
|
||||||
p.extra_generation_params["Schedule rho"] = opts.rho
|
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||||
|
|
||||||
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
|
if scheduler.need_inner_model:
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||||
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
|
||||||
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
|
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=shared.device)
|
||||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
|
|
||||||
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
|
|
||||||
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
|
|
||||||
else:
|
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
|
||||||
|
|
||||||
if discard_next_to_last_sigma:
|
if discard_next_to_last_sigma:
|
||||||
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
||||||
|
43
modules/sd_schedulers.py
Normal file
43
modules/sd_schedulers.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import dataclasses
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import k_diffusion
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Scheduler:
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
function: any
|
||||||
|
|
||||||
|
default_rho: float = -1
|
||||||
|
need_inner_model: bool = False
|
||||||
|
aliases: list = None
|
||||||
|
|
||||||
|
|
||||||
|
def uniform(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
return inner_model.get_sigmas(n)
|
||||||
|
|
||||||
|
|
||||||
|
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
||||||
|
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
|
||||||
|
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
|
||||||
|
sigs = [
|
||||||
|
inner_model.t_to_sigma(ts)
|
||||||
|
for ts in torch.linspace(start, end, n + 1)[:-1]
|
||||||
|
]
|
||||||
|
sigs += [0.0]
|
||||||
|
return torch.FloatTensor(sigs).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
schedulers = [
|
||||||
|
Scheduler('automatic', 'Automatic', None),
|
||||||
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
||||||
|
Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
|
||||||
|
Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
|
||||||
|
Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
|
||||||
|
Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
@ -6,6 +6,10 @@ import gradio as gr
|
|||||||
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from modules import util
|
from modules import util
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon
|
||||||
|
|
||||||
cmd_opts = shared_cmd_options.cmd_opts
|
cmd_opts = shared_cmd_options.cmd_opts
|
||||||
parser = shared_cmd_options.parser
|
parser = shared_cmd_options.parser
|
||||||
@ -16,11 +20,11 @@ styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.st
|
|||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
|
||||||
demo = None
|
demo: gr.Blocks = None
|
||||||
|
|
||||||
device = None
|
device: str = None
|
||||||
|
|
||||||
weight_load_location = None
|
weight_load_location: str = None
|
||||||
|
|
||||||
xformers_available = False
|
xformers_available = False
|
||||||
|
|
||||||
@ -28,22 +32,22 @@ hypernetworks = {}
|
|||||||
|
|
||||||
loaded_hypernetworks = []
|
loaded_hypernetworks = []
|
||||||
|
|
||||||
state = None
|
state: 'shared_state.State' = None
|
||||||
|
|
||||||
prompt_styles = None
|
prompt_styles: 'styles.StyleDatabase' = None
|
||||||
|
|
||||||
interrogator = None
|
interrogator: 'interrogate.InterrogateModels' = None
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
options_templates = None
|
options_templates: dict = None
|
||||||
opts = None
|
opts: options.Options = None
|
||||||
restricted_opts = None
|
restricted_opts: set[str] = None
|
||||||
|
|
||||||
sd_model: sd_models_types.WebuiSdModel = None
|
sd_model: sd_models_types.WebuiSdModel = None
|
||||||
|
|
||||||
settings_components = None
|
settings_components: dict = None
|
||||||
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
"""assigned from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||||
|
|
||||||
tab_names = []
|
tab_names = []
|
||||||
|
|
||||||
@ -65,9 +69,9 @@ progress_print_out = sys.stdout
|
|||||||
|
|
||||||
gradio_theme = gr.themes.Base()
|
gradio_theme = gr.themes.Base()
|
||||||
|
|
||||||
total_tqdm = None
|
total_tqdm: 'shared_total_tqdm.TotalTQDM' = None
|
||||||
|
|
||||||
mem_mon = None
|
mem_mon: 'memmon.MemUsageMonitor' = None
|
||||||
|
|
||||||
options_section = options.options_section
|
options_section = options.options_section
|
||||||
OptionInfo = options.OptionInfo
|
OptionInfo = options.OptionInfo
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
import html
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from modules import script_callbacks, scripts, ui_components
|
||||||
|
from modules.options import OptionHTML, OptionInfo
|
||||||
from modules.shared_cmd_options import cmd_opts
|
from modules.shared_cmd_options import cmd_opts
|
||||||
|
|
||||||
|
|
||||||
@ -118,6 +121,45 @@ def ui_reorder_categories():
|
|||||||
yield "scripts"
|
yield "scripts"
|
||||||
|
|
||||||
|
|
||||||
|
def callbacks_order_settings():
|
||||||
|
options = {
|
||||||
|
"sd_vae_explanation": OptionHTML("""
|
||||||
|
For categories below, callbacks added to dropdowns happen before others, in order listed.
|
||||||
|
"""),
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
callback_options = {}
|
||||||
|
|
||||||
|
for category, _ in script_callbacks.enumerate_callbacks():
|
||||||
|
callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)
|
||||||
|
|
||||||
|
for method_name in scripts.scripts_txt2img.callback_names:
|
||||||
|
callback_options["script_" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)
|
||||||
|
|
||||||
|
for method_name in scripts.scripts_img2img.callback_names:
|
||||||
|
callbacks = callback_options.get("script_" + method_name, [])
|
||||||
|
|
||||||
|
for addition in scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):
|
||||||
|
if any(x.name == addition.name for x in callbacks):
|
||||||
|
continue
|
||||||
|
|
||||||
|
callbacks.append(addition)
|
||||||
|
|
||||||
|
callback_options["script_" + method_name] = callbacks
|
||||||
|
|
||||||
|
for category, callbacks in callback_options.items():
|
||||||
|
if not callbacks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
option_info = OptionInfo([], f"{category} callback priority", ui_components.DropdownMulti, {"choices": [x.name for x in callbacks]})
|
||||||
|
option_info.needs_restart()
|
||||||
|
option_info.html("<div class='info'>Default order: <ol>" + "".join(f"<li>{html.escape(x.name)}</li>\n" for x in callbacks) + "</ol></div>")
|
||||||
|
options['prioritized_callbacks_' + category] = option_info
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
class Shared(sys.modules[__name__].__class__):
|
class Shared(sys.modules[__name__].__class__):
|
||||||
"""
|
"""
|
||||||
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||||
|
@ -101,6 +101,7 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
|
|||||||
"DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
"DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
"DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
"DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
|
||||||
|
"set_scale_by_when_changing_upscaler": OptionInfo(False, "Automatically set the Scale by factor based on the name of the selected Upscaler."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration", "postprocessing"), {
|
||||||
@ -212,7 +213,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
|
|||||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||||
"pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; overrides the above if set; WARNING: truncates negative prompt if it's too long; changes seeds"),
|
"pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; overrides the above if set; WARNING: truncates negative prompt if it's too long; changes seeds"),
|
||||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
|
||||||
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
|
||||||
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
"cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
|
||||||
}))
|
}))
|
||||||
@ -257,7 +258,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
|
|||||||
"extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"),
|
"extra_networks_card_description_is_html": OptionInfo(False, "Treat card description as HTML"),
|
||||||
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
"extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||||
"extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(),
|
"extra_networks_tree_view_style": OptionInfo("Dirs", "Extra Networks directory view style", gr.Radio, {"choices": ["Tree", "Dirs"]}).needs_reload_ui(),
|
||||||
|
"extra_networks_tree_view_default_enabled": OptionInfo(True, "Show the Extra Networks directory view by default").needs_reload_ui(),
|
||||||
"extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(),
|
"extra_networks_tree_view_default_width": OptionInfo(180, "Default width for the Extra Networks directory tree view", gr.Number).needs_reload_ui(),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||||
@ -312,6 +314,8 @@ options_templates.update(options_section(('ui', "User interface", "ui"), {
|
|||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
|
"enable_reloading_ui_scripts": OptionInfo(False, "Reload UI scripts when using Reload UI option").info("useful for developing: if you make changes to UI scripts code, it is applied when the UI is reloded."),
|
||||||
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
@ -363,13 +367,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"),
|
'sgm_noise_multiplier': OptionInfo(False, "SGM noise multiplier", infotext='SGM noise multiplier').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818").info("Match initial noise to official SDXL implementation - only useful for reproducing images"),
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'),
|
||||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
|
||||||
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
|
||||||
|
@ -157,10 +157,12 @@ class State:
|
|||||||
self.current_image_sampling_step = self.sampling_step
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# when switching models during genration, VAE would be on CPU, so creating an image will fail.
|
# when switching models during generation, VAE would be on CPU, so creating an image will fail.
|
||||||
# we silently ignore this error
|
# we silently ignore this error
|
||||||
errors.record_exception()
|
errors.record_exception()
|
||||||
|
|
||||||
def assign_current_image(self, image):
|
def assign_current_image(self, image):
|
||||||
|
if shared.opts.live_previews_image_format == 'jpeg' and image.mode == 'RGBA':
|
||||||
|
image = image.convert('RGB')
|
||||||
self.current_image = image
|
self.current_image = image
|
||||||
self.id_live_preview += 1
|
self.id_live_preview += 1
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from __future__ import annotations
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from modules import errors
|
from modules import errors
|
||||||
import csv
|
import csv
|
||||||
@ -42,7 +43,7 @@ def extract_style_text_from_prompt(style_text, prompt):
|
|||||||
stripped_style_text = style_text.strip()
|
stripped_style_text = style_text.strip()
|
||||||
|
|
||||||
if "{prompt}" in stripped_style_text:
|
if "{prompt}" in stripped_style_text:
|
||||||
left, right = stripped_style_text.split("{prompt}", 2)
|
left, _, right = stripped_style_text.partition("{prompt}")
|
||||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||||
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
||||||
return True, prompt
|
return True, prompt
|
||||||
|
@ -65,7 +65,7 @@ def crop_image(im, settings):
|
|||||||
rect[3] -= 1
|
rect[3] -= 1
|
||||||
d.rectangle(rect, outline=GREEN)
|
d.rectangle(rect, outline=GREEN)
|
||||||
results.append(im_debug)
|
results.append(im_debug)
|
||||||
if settings.destop_view_image:
|
if settings.desktop_view_image:
|
||||||
im_debug.show()
|
im_debug.show()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -341,5 +341,5 @@ class Settings:
|
|||||||
self.entropy_points_weight = entropy_points_weight
|
self.entropy_points_weight = entropy_points_weight
|
||||||
self.face_points_weight = face_points_weight
|
self.face_points_weight = face_points_weight
|
||||||
self.annotate_image = annotate_image
|
self.annotate_image = annotate_image
|
||||||
self.destop_view_image = False
|
self.desktop_view_image = False
|
||||||
self.dnn_model_path = dnn_model_path
|
self.dnn_model_path = dnn_model_path
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -10,7 +9,7 @@ from random import shuffle, choices
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import tqdm
|
import tqdm
|
||||||
from modules import devices, shared
|
from modules import devices, shared, images
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
@ -61,7 +60,7 @@ class PersonalizedBase(Dataset):
|
|||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
raise Exception("interrupted")
|
raise Exception("interrupted")
|
||||||
try:
|
try:
|
||||||
image = Image.open(path)
|
image = images.read(path)
|
||||||
#Currently does not work for single color transparency
|
#Currently does not work for single color transparency
|
||||||
#We would need to read image.info['transparency'] for that
|
#We would need to read image.info['transparency'] for that
|
||||||
if use_weight and 'A' in image.getbands():
|
if use_weight and 'A' in image.getbands():
|
||||||
|
@ -193,11 +193,11 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
embedded_image = insert_image_data_embed(cap_image, test_embed)
|
embedded_image = insert_image_data_embed(cap_image, test_embed)
|
||||||
|
|
||||||
retrived_embed = extract_image_data_embed(embedded_image)
|
retrieved_embed = extract_image_data_embed(embedded_image)
|
||||||
|
|
||||||
assert str(retrived_embed) == str(test_embed)
|
assert str(retrieved_embed) == str(test_embed)
|
||||||
|
|
||||||
embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
|
embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
|
||||||
|
|
||||||
assert embedded_image == embedded_image2
|
assert embedded_image == embedded_image2
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ class EmbeddingDatabase:
|
|||||||
if data:
|
if data:
|
||||||
name = data.get('name', name)
|
name = data.get('name', name)
|
||||||
else:
|
else:
|
||||||
# if data is None, means this is not an embeding, just a preview image
|
# if data is None, means this is not an embedding, just a preview image
|
||||||
return
|
return
|
||||||
elif ext in ['.BIN', '.PT']:
|
elif ext in ['.BIN', '.PT']:
|
||||||
data = torch.load(path, map_location="cpu")
|
data = torch.load(path, map_location="cpu")
|
||||||
|
@ -11,7 +11,7 @@ from PIL import Image
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
if force_enable_hr:
|
if force_enable_hr:
|
||||||
@ -24,10 +24,8 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
sampler_name=sampler_name,
|
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
steps=steps,
|
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@ -40,6 +38,7 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
|
|||||||
hr_resize_y=hr_resize_y,
|
hr_resize_y=hr_resize_y,
|
||||||
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
|
||||||
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
|
||||||
|
hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,
|
||||||
hr_prompt=hr_prompt,
|
hr_prompt=hr_prompt,
|
||||||
hr_negative_prompt=hr_negative_prompt,
|
hr_negative_prompt=hr_negative_prompt,
|
||||||
override_settings=override_settings,
|
override_settings=override_settings,
|
||||||
|
@ -12,7 +12,7 @@ from gradio.components.image_editor import Brush
|
|||||||
from PIL import Image, PngImagePlugin # noqa: F401
|
from PIL import Image, PngImagePlugin # noqa: F401
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensions # noqa: F401
|
from modules import gradio_extensons, sd_schedulers # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
@ -230,19 +230,6 @@ def create_output_panel(tabname, outdir, toprow=None):
|
|||||||
return ui_common.create_output_panel(tabname, outdir, toprow)
|
return ui_common.create_output_panel(tabname, outdir, toprow)
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_and_steps_selection(choices, tabname):
|
|
||||||
if opts.samplers_in_dropdown:
|
|
||||||
with FormRow(elem_id=f"sampler_selection_{tabname}"):
|
|
||||||
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
|
||||||
else:
|
|
||||||
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
|
|
||||||
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
|
|
||||||
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
|
|
||||||
|
|
||||||
return steps, sampler_name
|
|
||||||
|
|
||||||
|
|
||||||
def ordered_ui_categories():
|
def ordered_ui_categories():
|
||||||
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
|
||||||
|
|
||||||
@ -270,6 +257,9 @@ def create_ui():
|
|||||||
|
|
||||||
parameters_copypaste.reset()
|
parameters_copypaste.reset()
|
||||||
|
|
||||||
|
settings = ui_settings.UiSettings()
|
||||||
|
settings.register_settings()
|
||||||
|
|
||||||
scripts.scripts_current = scripts.scripts_txt2img
|
scripts.scripts_current = scripts.scripts_txt2img
|
||||||
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
@ -294,9 +284,6 @@ def create_ui():
|
|||||||
if category == "prompt":
|
if category == "prompt":
|
||||||
toprow.create_inline_toprow_prompts()
|
toprow.create_inline_toprow_prompts()
|
||||||
|
|
||||||
if category == "sampler":
|
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
with gr.Column(elem_id="txt2img_column_size", scale=4):
|
||||||
@ -337,10 +324,11 @@ def create_ui():
|
|||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||||
|
|
||||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
hr_sampler_name = gr.Dropdown(label='Sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||||
|
hr_scheduler = gr.Dropdown(label='Schedule type', elem_id="hr_scheduler", choices=["Use same scheduler"] + [x.label for x in sd_schedulers.schedulers], value="Use same scheduler")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
@ -395,8 +383,6 @@ def create_ui():
|
|||||||
toprow.prompt,
|
toprow.prompt,
|
||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
toprow.ui_styles.dropdown,
|
toprow.ui_styles.dropdown,
|
||||||
steps,
|
|
||||||
sampler_name,
|
|
||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
@ -411,6 +397,7 @@ def create_ui():
|
|||||||
hr_resize_y,
|
hr_resize_y,
|
||||||
hr_checkpoint_name,
|
hr_checkpoint_name,
|
||||||
hr_sampler_name,
|
hr_sampler_name,
|
||||||
|
hr_scheduler,
|
||||||
hr_prompt,
|
hr_prompt,
|
||||||
hr_negative_prompt,
|
hr_negative_prompt,
|
||||||
override_settings,
|
override_settings,
|
||||||
@ -460,8 +447,6 @@ def create_ui():
|
|||||||
txt2img_paste_fields = [
|
txt2img_paste_fields = [
|
||||||
PasteField(toprow.prompt, "Prompt", api="prompt"),
|
PasteField(toprow.prompt, "Prompt", api="prompt"),
|
||||||
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
|
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
|
||||||
PasteField(steps, "Steps", api="steps"),
|
|
||||||
PasteField(sampler_name, "Sampler", api="sampler_name"),
|
|
||||||
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
|
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
|
||||||
PasteField(width, "Size-1", api="width"),
|
PasteField(width, "Size-1", api="width"),
|
||||||
PasteField(height, "Size-2", api="height"),
|
PasteField(height, "Size-2", api="height"),
|
||||||
@ -475,8 +460,9 @@ def create_ui():
|
|||||||
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
|
||||||
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
|
||||||
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
|
||||||
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
|
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api="hr_sampler_name"),
|
||||||
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
|
PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api="hr_scheduler"),
|
||||||
|
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" or d.get("Hires schedule type", "Use same scheduler") != "Use same scheduler" else gr.update()),
|
||||||
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
|
||||||
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
|
||||||
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
|
||||||
@ -487,11 +473,13 @@ def create_ui():
|
|||||||
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
steps = scripts.scripts_txt2img.script('Sampler').steps
|
||||||
|
|
||||||
txt2img_preview_params = [
|
txt2img_preview_params = [
|
||||||
toprow.prompt,
|
toprow.prompt,
|
||||||
toprow.negative_prompt,
|
toprow.negative_prompt,
|
||||||
steps,
|
steps,
|
||||||
sampler_name,
|
scripts.scripts_txt2img.script('Sampler').sampler_name,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
scripts.scripts_txt2img.script('Seed').seed,
|
scripts.scripts_txt2img.script('Seed').seed,
|
||||||
width,
|
width,
|
||||||
@ -615,9 +603,6 @@ def create_ui():
|
|||||||
with FormRow():
|
with FormRow():
|
||||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||||
|
|
||||||
if category == "sampler":
|
|
||||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
|
||||||
|
|
||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
@ -739,8 +724,6 @@ def create_ui():
|
|||||||
inpaint_color_sketch,
|
inpaint_color_sketch,
|
||||||
init_img_inpaint,
|
init_img_inpaint,
|
||||||
init_mask_inpaint,
|
init_mask_inpaint,
|
||||||
steps,
|
|
||||||
sampler_name,
|
|
||||||
mask_blur,
|
mask_blur,
|
||||||
mask_alpha,
|
mask_alpha,
|
||||||
inpainting_fill,
|
inpainting_fill,
|
||||||
@ -825,6 +808,8 @@ def create_ui():
|
|||||||
**interrogate_args,
|
**interrogate_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
steps = scripts.scripts_img2img.script('Sampler').steps
|
||||||
|
|
||||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||||
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
|
||||||
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
|
||||||
@ -833,8 +818,6 @@ def create_ui():
|
|||||||
img2img_paste_fields = [
|
img2img_paste_fields = [
|
||||||
(toprow.prompt, "Prompt"),
|
(toprow.prompt, "Prompt"),
|
||||||
(toprow.negative_prompt, "Negative prompt"),
|
(toprow.negative_prompt, "Negative prompt"),
|
||||||
(steps, "Steps"),
|
|
||||||
(sampler_name, "Sampler"),
|
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(image_cfg_scale, "Image CFG scale"),
|
(image_cfg_scale, "Image CFG scale"),
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
@ -1104,7 +1087,6 @@ def create_ui():
|
|||||||
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
|
||||||
ui_settings_from_file = loadsave.ui_settings.copy()
|
ui_settings_from_file = loadsave.ui_settings.copy()
|
||||||
|
|
||||||
settings = ui_settings.UiSettings()
|
|
||||||
settings.create_ui(loadsave, dummy_component)
|
settings.create_ui(loadsave, dummy_component)
|
||||||
|
|
||||||
interfaces = [
|
interfaces = [
|
||||||
|
@ -108,7 +108,7 @@ def save_files(js_data, images, do_make_zip, index):
|
|||||||
logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
|
logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
|
||||||
|
|
||||||
# NOTE: ensure csv integrity when fields are added by
|
# NOTE: ensure csv integrity when fields are added by
|
||||||
# updating headers and padding with delimeters where needed
|
# updating headers and padding with delimiters where needed
|
||||||
if os.path.exists(logfile_path):
|
if os.path.exists(logfile_path):
|
||||||
update_logfile(logfile_path, fields)
|
update_logfile(logfile_path, fields)
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ class DropdownEditable(gr.Dropdown, FormComponent):
|
|||||||
class InputAccordionImpl(gr.Checkbox):
|
class InputAccordionImpl(gr.Checkbox):
|
||||||
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
|
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
|
||||||
|
|
||||||
Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
|
Actually just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
webui_do_not_create_gradio_pyi_thank_you = True
|
webui_do_not_create_gradio_pyi_thank_you = True
|
||||||
|
@ -380,7 +380,7 @@ def install_extension_from_url(dirname, url, branch_name=None):
|
|||||||
except OSError as err:
|
except OSError as err:
|
||||||
if err.errno == errno.EXDEV:
|
if err.errno == errno.EXDEV:
|
||||||
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
|
||||||
# Since we can't use a rename, do the slower but more versitile shutil.move()
|
# Since we can't use a rename, do the slower but more versatile shutil.move()
|
||||||
shutil.move(tmpdir, target_dir)
|
shutil.move(tmpdir, target_dir)
|
||||||
else:
|
else:
|
||||||
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import functools
|
import functools
|
||||||
import os.path
|
import os.path
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from base64 import b64decode
|
||||||
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -11,6 +13,7 @@ import gradio as gr
|
|||||||
import json
|
import json
|
||||||
import html
|
import html
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from modules.infotext_utils import image_from_url_text
|
from modules.infotext_utils import image_from_url_text
|
||||||
|
|
||||||
@ -108,6 +111,31 @@ def fetch_file(filename: str = ""):
|
|||||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_cover_images(page: str = "", item: str = "", index: int = 0):
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
page = next(iter([x for x in extra_pages if x.name == page]), None)
|
||||||
|
if page is None:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
metadata = page.metadata.get(item)
|
||||||
|
if metadata is None:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
|
||||||
|
image = cover_images[index] if index < len(cover_images) else None
|
||||||
|
if not image:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = Image.open(BytesIO(b64decode(image)))
|
||||||
|
buffer = BytesIO()
|
||||||
|
image.save(buffer, format=image.format)
|
||||||
|
return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
|
||||||
|
except Exception as err:
|
||||||
|
raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
|
||||||
|
|
||||||
|
|
||||||
def get_metadata(page: str = "", item: str = ""):
|
def get_metadata(page: str = "", item: str = ""):
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
@ -119,6 +147,8 @@ def get_metadata(page: str = "", item: str = ""):
|
|||||||
if metadata is None:
|
if metadata is None:
|
||||||
return JSONResponse({})
|
return JSONResponse({})
|
||||||
|
|
||||||
|
metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'} # those are cover images, and they are too big to display in UI as text
|
||||||
|
|
||||||
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
|
return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
|
||||||
|
|
||||||
|
|
||||||
@ -142,6 +172,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
|||||||
|
|
||||||
def add_pages_to_demo(app):
|
def add_pages_to_demo(app):
|
||||||
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
|
||||||
|
app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
|
||||||
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
|
||||||
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
|
app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
|
||||||
|
|
||||||
@ -151,6 +182,7 @@ def quote_js(s):
|
|||||||
s = s.replace('"', '\\"')
|
s = s.replace('"', '\\"')
|
||||||
return f'"{s}"'
|
return f'"{s}"'
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPage:
|
class ExtraNetworksPage:
|
||||||
def __init__(self, title):
|
def __init__(self, title):
|
||||||
self.title = title
|
self.title = title
|
||||||
@ -164,6 +196,8 @@ class ExtraNetworksPage:
|
|||||||
self.lister = util.MassFileLister()
|
self.lister = util.MassFileLister()
|
||||||
# HTML Templates
|
# HTML Templates
|
||||||
self.pane_tpl = shared.html("extra-networks-pane.html")
|
self.pane_tpl = shared.html("extra-networks-pane.html")
|
||||||
|
self.pane_content_tree_tpl = shared.html("extra-networks-pane-tree.html")
|
||||||
|
self.pane_content_dirs_tpl = shared.html("extra-networks-pane-dirs.html")
|
||||||
self.card_tpl = shared.html("extra-networks-card.html")
|
self.card_tpl = shared.html("extra-networks-card.html")
|
||||||
self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
|
self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
|
||||||
self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
|
self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
|
||||||
@ -243,14 +277,12 @@ class ExtraNetworksPage:
|
|||||||
btn_metadata = self.btn_metadata_tpl.format(
|
btn_metadata = self.btn_metadata_tpl.format(
|
||||||
**{
|
**{
|
||||||
"extra_networks_tabname": self.extra_networks_tabname,
|
"extra_networks_tabname": self.extra_networks_tabname,
|
||||||
"name": html.escape(item["name"]),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
btn_edit_item = self.btn_edit_item_tpl.format(
|
btn_edit_item = self.btn_edit_item_tpl.format(
|
||||||
**{
|
**{
|
||||||
"tabname": tabname,
|
"tabname": tabname,
|
||||||
"extra_networks_tabname": self.extra_networks_tabname,
|
"extra_networks_tabname": self.extra_networks_tabname,
|
||||||
"name": html.escape(item["name"]),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -476,6 +508,47 @@ class ExtraNetworksPage:
|
|||||||
|
|
||||||
return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
|
return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
|
||||||
|
|
||||||
|
def create_dirs_view_html(self, tabname: str) -> str:
|
||||||
|
"""Generates HTML for displaying folders."""
|
||||||
|
|
||||||
|
subdirs = {}
|
||||||
|
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
||||||
|
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
||||||
|
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
||||||
|
x = os.path.join(root, dirname)
|
||||||
|
|
||||||
|
if not os.path.isdir(x):
|
||||||
|
continue
|
||||||
|
|
||||||
|
subdir = os.path.abspath(x)[len(parentdir):]
|
||||||
|
|
||||||
|
if shared.opts.extra_networks_dir_button_function:
|
||||||
|
if not subdir.startswith(os.path.sep):
|
||||||
|
subdir = os.path.sep + subdir
|
||||||
|
else:
|
||||||
|
while subdir.startswith(os.path.sep):
|
||||||
|
subdir = subdir[1:]
|
||||||
|
|
||||||
|
is_empty = len(os.listdir(x)) == 0
|
||||||
|
if not is_empty and not subdir.endswith(os.path.sep):
|
||||||
|
subdir = subdir + os.path.sep
|
||||||
|
|
||||||
|
if (os.path.sep + "." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
||||||
|
continue
|
||||||
|
|
||||||
|
subdirs[subdir] = 1
|
||||||
|
|
||||||
|
if subdirs:
|
||||||
|
subdirs = {"": 1, **subdirs}
|
||||||
|
|
||||||
|
subdirs_html = "".join([f"""
|
||||||
|
<button class='lg secondary gradio-button custom-button{" search-all" if subdir == "" else ""}' onclick='extraNetworksSearchButton("{tabname}", "{self.extra_networks_tabname}", event)'>
|
||||||
|
{html.escape(subdir if subdir != "" else "all")}
|
||||||
|
</button>
|
||||||
|
""" for subdir in subdirs])
|
||||||
|
|
||||||
|
return subdirs_html
|
||||||
|
|
||||||
def create_card_view_html(self, tabname: str, *, none_message) -> str:
|
def create_card_view_html(self, tabname: str, *, none_message) -> str:
|
||||||
"""Generates HTML for the network Card View section for a tab.
|
"""Generates HTML for the network Card View section for a tab.
|
||||||
|
|
||||||
@ -489,15 +562,15 @@ class ExtraNetworksPage:
|
|||||||
Returns:
|
Returns:
|
||||||
HTML formatted string.
|
HTML formatted string.
|
||||||
"""
|
"""
|
||||||
res = ""
|
res = []
|
||||||
for item in self.items.values():
|
for item in self.items.values():
|
||||||
res += self.create_item_html(tabname, item, self.card_tpl)
|
res.append(self.create_item_html(tabname, item, self.card_tpl))
|
||||||
|
|
||||||
if res == "":
|
if not res:
|
||||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
||||||
res = none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
res = [none_message or shared.html("extra-networks-no-cards.html").format(dirs=dirs)]
|
||||||
|
|
||||||
return res
|
return "".join(res)
|
||||||
|
|
||||||
def create_html(self, tabname, *, empty=False):
|
def create_html(self, tabname, *, empty=False):
|
||||||
"""Generates an HTML string for the current pane.
|
"""Generates an HTML string for the current pane.
|
||||||
@ -526,35 +599,28 @@ class ExtraNetworksPage:
|
|||||||
if "user_metadata" not in item:
|
if "user_metadata" not in item:
|
||||||
self.read_user_metadata(item)
|
self.read_user_metadata(item)
|
||||||
|
|
||||||
data_sortdir = shared.opts.extra_networks_card_order
|
show_tree = shared.opts.extra_networks_tree_view_default_enabled
|
||||||
data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
|
|
||||||
data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
|
|
||||||
tree_view_btn_extra_class = ""
|
|
||||||
tree_view_div_extra_class = "hidden"
|
|
||||||
tree_view_div_default_display = "none"
|
|
||||||
extra_network_pane_content_default_display = "flex"
|
|
||||||
if shared.opts.extra_networks_tree_view_default_enabled:
|
|
||||||
tree_view_btn_extra_class = "extra-network-control--enabled"
|
|
||||||
tree_view_div_extra_class = ""
|
|
||||||
tree_view_div_default_display = "block"
|
|
||||||
extra_network_pane_content_default_display = "grid"
|
|
||||||
|
|
||||||
return self.pane_tpl.format(
|
page_params = {
|
||||||
**{
|
"tabname": tabname,
|
||||||
"tabname": tabname,
|
"extra_networks_tabname": self.extra_networks_tabname,
|
||||||
"extra_networks_tabname": self.extra_networks_tabname,
|
"data_sortdir": shared.opts.extra_networks_card_order,
|
||||||
"data_sortmode": data_sortmode,
|
"sort_path_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '',
|
||||||
"data_sortkey": data_sortkey,
|
"sort_name_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '',
|
||||||
"data_sortdir": data_sortdir,
|
"sort_date_created_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '',
|
||||||
"tree_view_btn_extra_class": tree_view_btn_extra_class,
|
"sort_date_modified_active": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '',
|
||||||
"tree_view_div_extra_class": tree_view_div_extra_class,
|
"tree_view_btn_extra_class": "extra-network-control--enabled" if show_tree else "",
|
||||||
"tree_html": self.create_tree_view_html(tabname),
|
"items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None),
|
||||||
"items_html": self.create_card_view_html(tabname, none_message="Loading..." if empty else None),
|
"extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width,
|
||||||
"extra_networks_tree_view_default_width": shared.opts.extra_networks_tree_view_default_width,
|
"tree_view_div_default_display_class": "" if show_tree else "extra-network-dirs-hidden",
|
||||||
"tree_view_div_default_display": tree_view_div_default_display,
|
}
|
||||||
"extra_network_pane_content_default_display": extra_network_pane_content_default_display,
|
|
||||||
}
|
if shared.opts.extra_networks_tree_view_style == "Tree":
|
||||||
)
|
pane_content = self.pane_content_tree_tpl.format(**page_params, tree_html=self.create_tree_view_html(tabname))
|
||||||
|
else:
|
||||||
|
pane_content = self.pane_content_dirs_tpl.format(**page_params, dirs_html=self.create_dirs_view_html(tabname))
|
||||||
|
|
||||||
|
return self.pane_tpl.format(**page_params, pane_content=pane_content)
|
||||||
|
|
||||||
def create_item(self, name, index=None):
|
def create_item(self, name, index=None):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -591,6 +657,17 @@ class ExtraNetworksPage:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def find_embedded_preview(self, path, name, metadata):
|
||||||
|
"""
|
||||||
|
Find if embedded preview exists in safetensors metadata and return endpoint for it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file = f"{path}.safetensors"
|
||||||
|
if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
|
||||||
|
return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def find_description(self, path):
|
def find_description(self, path):
|
||||||
"""
|
"""
|
||||||
Find and read a description file for a given path (without extension).
|
Find and read a description file for a given path (without extension).
|
||||||
|
@ -133,8 +133,10 @@ class UserMetadataEditor:
|
|||||||
filename = item.get("filename", None)
|
filename = item.get("filename", None)
|
||||||
basename, ext = os.path.splitext(filename)
|
basename, ext = os.path.splitext(filename)
|
||||||
|
|
||||||
with open(basename + '.json', "w", encoding="utf8") as file:
|
metadata_path = basename + '.json'
|
||||||
|
with open(metadata_path, "w", encoding="utf8") as file:
|
||||||
json.dump(metadata, file, indent=4, ensure_ascii=False)
|
json.dump(metadata, file, indent=4, ensure_ascii=False)
|
||||||
|
self.page.lister.update_file_entry(metadata_path)
|
||||||
|
|
||||||
def save_user_metadata(self, name, desc, notes):
|
def save_user_metadata(self, name, desc, notes):
|
||||||
user_metadata = self.get_user_metadata(name)
|
user_metadata = self.get_user_metadata(name)
|
||||||
@ -185,7 +187,8 @@ class UserMetadataEditor:
|
|||||||
geninfo, items = images.read_info_from_image(image)
|
geninfo, items = images.read_info_from_image(image)
|
||||||
|
|
||||||
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
images.save_image_with_geninfo(image, geninfo, item["local_preview"])
|
||||||
|
self.page.lister.update_file_entry(item["local_preview"])
|
||||||
|
item['preview'] = self.page.find_preview(item["local_preview"])
|
||||||
return self.get_card_html(name), ''
|
return self.get_card_html(name), ''
|
||||||
|
|
||||||
def setup_ui(self, gallery):
|
def setup_ui(self, gallery):
|
||||||
@ -200,6 +203,3 @@ class UserMetadataEditor:
|
|||||||
inputs=[self.edit_name_input],
|
inputs=[self.edit_name_input],
|
||||||
outputs=[]
|
outputs=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,6 +104,8 @@ class UiLoadsave:
|
|||||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||||
|
|
||||||
if type(x) == InputAccordion:
|
if type(x) == InputAccordion:
|
||||||
|
if hasattr(x, 'custom_script_source'):
|
||||||
|
x.accordion.custom_script_source = x.custom_script_source
|
||||||
if x.accordion.visible:
|
if x.accordion.visible:
|
||||||
apply_field(x.accordion, 'visible')
|
apply_field(x.accordion, 'visible')
|
||||||
apply_field(x, 'value')
|
apply_field(x, 'value')
|
||||||
|
@ -12,7 +12,7 @@ def create_ui():
|
|||||||
with gr.Column(variant='compact'):
|
with gr.Column(variant='compact'):
|
||||||
with gr.Tabs(elem_id="mode_extras"):
|
with gr.Tabs(elem_id="mode_extras"):
|
||||||
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||||
extras_image = gr.ImageEditor(label="Source", interactive=True, type="pil", elem_id="extras_image")
|
extras_image = gr.ImageEditor(label="Source", interactive=True, type="pil", elem_id="extras_image", image_mode="RGBA")
|
||||||
|
|
||||||
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
||||||
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||||
|
@ -67,7 +67,7 @@ class UiPromptStyles:
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
||||||
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
||||||
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selection dropdown in main UI to the prompt.")
|
||||||
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
|
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items
|
||||||
from modules.call_queue import wrap_gradio_call
|
from modules.call_queue import wrap_gradio_call
|
||||||
|
from modules.options import options_section
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.ui_components import FormRow
|
from modules.ui_components import FormRow
|
||||||
from modules.ui_gradio_extensions import reload_javascript
|
from modules.ui_gradio_extensions import reload_javascript
|
||||||
@ -98,6 +99,9 @@ class UiSettings:
|
|||||||
|
|
||||||
return get_value_for_setting(key), opts.dumpjson()
|
return get_value_for_setting(key), opts.dumpjson()
|
||||||
|
|
||||||
|
def register_settings(self):
|
||||||
|
script_callbacks.ui_settings_callback()
|
||||||
|
|
||||||
def create_ui(self, loadsave, dummy_component):
|
def create_ui(self, loadsave, dummy_component):
|
||||||
self.components = []
|
self.components = []
|
||||||
self.component_dict = {}
|
self.component_dict = {}
|
||||||
@ -105,7 +109,11 @@ class UiSettings:
|
|||||||
|
|
||||||
shared.settings_components = self.component_dict
|
shared.settings_components = self.component_dict
|
||||||
|
|
||||||
script_callbacks.ui_settings_callback()
|
# we add this as late as possible so that scripts have already registered their callbacks
|
||||||
|
opts.data_labels.update(options_section(('callbacks', "Callbacks", "system"), {
|
||||||
|
**shared_items.callbacks_order_settings(),
|
||||||
|
}))
|
||||||
|
|
||||||
opts.reorder()
|
opts.reorder()
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
with gr.Blocks(analytics_enabled=False) as settings_interface:
|
||||||
|
@ -20,7 +20,7 @@ class Upscaler:
|
|||||||
filter = None
|
filter = None
|
||||||
model = None
|
model = None
|
||||||
user_path = None
|
user_path = None
|
||||||
scalers: []
|
scalers: list
|
||||||
tile = True
|
tile = True
|
||||||
|
|
||||||
def __init__(self, create_dirs=False):
|
def __init__(self, create_dirs=False):
|
||||||
|
@ -69,10 +69,8 @@ def upscale_with_model(
|
|||||||
for y, h, row in grid.tiles:
|
for y, h, row in grid.tiles:
|
||||||
newrow = []
|
newrow = []
|
||||||
for x, w, tile in row:
|
for x, w, tile in row:
|
||||||
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
|
||||||
output = upscale_pil_patch(model, tile)
|
output = upscale_pil_patch(model, tile)
|
||||||
scale_factor = output.width // tile.width
|
scale_factor = output.width // tile.width
|
||||||
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
|
||||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||||
p.update(1)
|
p.update(1)
|
||||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||||
|
@ -81,6 +81,17 @@ class MassFileListerCachedDir:
|
|||||||
self.files = {x[0].lower(): x for x in files}
|
self.files = {x[0].lower(): x for x in files}
|
||||||
self.files_cased = {x[0]: x for x in files}
|
self.files_cased = {x[0]: x for x in files}
|
||||||
|
|
||||||
|
def update_entry(self, filename):
|
||||||
|
"""Add a file to the cache"""
|
||||||
|
file_path = os.path.join(self.dirname, filename)
|
||||||
|
try:
|
||||||
|
stat = os.stat(file_path)
|
||||||
|
entry = (filename, stat.st_mtime, stat.st_ctime)
|
||||||
|
self.files[filename.lower()] = entry
|
||||||
|
self.files_cased[filename] = entry
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(f'MassFileListerCachedDir.add_entry: "{file_path}" {e}')
|
||||||
|
|
||||||
|
|
||||||
class MassFileLister:
|
class MassFileLister:
|
||||||
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
|
"""A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file."""
|
||||||
@ -136,3 +147,27 @@ class MassFileLister:
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear the cache of all directories."""
|
"""Clear the cache of all directories."""
|
||||||
self.cached_dirs.clear()
|
self.cached_dirs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def topological_sort(dependencies):
|
||||||
|
"""Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
|
||||||
|
Ignores errors relating to missing dependeencies or circular dependencies
|
||||||
|
"""
|
||||||
|
|
||||||
|
visited = {}
|
||||||
|
result = []
|
||||||
|
|
||||||
|
def inner(name):
|
||||||
|
visited[name] = True
|
||||||
|
|
||||||
|
for dep in dependencies.get(name, []):
|
||||||
|
if dep in dependencies and dep not in visited:
|
||||||
|
inner(dep)
|
||||||
|
|
||||||
|
result.append(name)
|
||||||
|
|
||||||
|
for depname in dependencies:
|
||||||
|
if depname not in visited:
|
||||||
|
inner(depname)
|
||||||
|
|
||||||
|
return result
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
|
||||||
extend-select = [
|
extend-select = [
|
||||||
"B",
|
"B",
|
||||||
"C",
|
"C",
|
||||||
@ -25,10 +27,10 @@ ignore = [
|
|||||||
"W605", # invalid escape sequence, messes with some docstrings
|
"W605", # invalid escape sequence, messes with some docstrings
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.lint.per-file-ignores]
|
||||||
"webui.py" = ["E402"] # Module level import not at top of file
|
"webui.py" = ["E402"] # Module level import not at top of file
|
||||||
|
|
||||||
[tool.ruff.flake8-bugbear]
|
[tool.ruff.lint.flake8-bugbear]
|
||||||
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
|
||||||
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
|
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ accelerate
|
|||||||
|
|
||||||
blendmodes
|
blendmodes
|
||||||
clean-fid
|
clean-fid
|
||||||
|
diskcache
|
||||||
einops
|
einops
|
||||||
facexlib
|
facexlib
|
||||||
fastapi>=0.90.1
|
fastapi>=0.90.1
|
||||||
|
@ -3,6 +3,7 @@ Pillow==9.5.0
|
|||||||
accelerate==0.21.0
|
accelerate==0.21.0
|
||||||
blendmodes==2022
|
blendmodes==2022
|
||||||
clean-fid==0.1.35
|
clean-fid==0.1.35
|
||||||
|
diskcache==5.6.3
|
||||||
einops==0.4.1
|
einops==0.4.1
|
||||||
facexlib==0.3.0
|
facexlib==0.3.0
|
||||||
fastapi==0.104.1
|
fastapi==0.104.1
|
||||||
|
@ -102,7 +102,7 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
|
|||||||
shaped_noise_fft = _fft2(noise_rgb)
|
shaped_noise_fft = _fft2(noise_rgb)
|
||||||
shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
|
shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
|
||||||
|
|
||||||
brightness_variation = 0. # color_variation # todo: temporarily tieing brightness variation to color variation for now
|
brightness_variation = 0. # color_variation # todo: temporarily tying brightness variation to color variation for now
|
||||||
contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
|
contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
|
||||||
|
|
||||||
# scikit-image is used for histogram matching, very convenient!
|
# scikit-image is used for histogram matching, very convenient!
|
||||||
|
@ -25,7 +25,7 @@ class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing
|
|||||||
if codeformer_visibility == 0 or not enable:
|
if codeformer_visibility == 0 or not enable:
|
||||||
return
|
return
|
||||||
|
|
||||||
restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
|
restored_img = codeformer_model.codeformer.restore(np.array(pp.image.convert("RGB"), dtype=np.uint8), w=codeformer_weight)
|
||||||
res = Image.fromarray(restored_img)
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
if codeformer_visibility < 1.0:
|
if codeformer_visibility < 1.0:
|
||||||
|
@ -22,7 +22,7 @@ class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
if gfpgan_visibility == 0 or not enable:
|
if gfpgan_visibility == 0 or not enable:
|
||||||
return
|
return
|
||||||
|
|
||||||
restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
|
restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image.convert("RGB"), dtype=np.uint8))
|
||||||
res = Image.fromarray(restored_img)
|
res = Image.fromarray(restored_img)
|
||||||
|
|
||||||
if gfpgan_visibility < 1.0:
|
if gfpgan_visibility < 1.0:
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import scripts_postprocessing, shared
|
from modules import scripts_postprocessing, shared
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.ui_components import FormRow, ToolButton
|
from modules.ui_components import FormRow, ToolButton, InputAccordion
|
||||||
from modules.ui import switch_values_symbol
|
from modules.ui import switch_values_symbol
|
||||||
|
|
||||||
upscale_cache = {}
|
upscale_cache = {}
|
||||||
@ -17,7 +19,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
def ui(self):
|
def ui(self):
|
||||||
selected_tab = gr.Number(value=0, visible=False)
|
selected_tab = gr.Number(value=0, visible=False)
|
||||||
|
|
||||||
with gr.Column():
|
with InputAccordion(True, label="Upscale", elem_id="extras_upscale") as upscale_enabled:
|
||||||
|
with FormRow():
|
||||||
|
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
|
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Tabs(elem_id="extras_resize_mode"):
|
with gr.Tabs(elem_id="extras_resize_mode"):
|
||||||
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
|
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
|
||||||
@ -32,18 +41,24 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
|
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
|
||||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||||
|
|
||||||
with FormRow():
|
def on_selected_upscale_method(upscale_method):
|
||||||
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
if not shared.opts.set_scale_by_when_changing_upscaler:
|
||||||
|
return gr.update()
|
||||||
|
|
||||||
with FormRow():
|
match = re.search(r'(\d)[xX]|[xX](\d)', upscale_method)
|
||||||
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
if not match:
|
||||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
return gr.update()
|
||||||
|
|
||||||
|
return gr.update(value=int(match.group(1) or match.group(2)))
|
||||||
|
|
||||||
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
|
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
|
||||||
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
||||||
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
||||||
|
|
||||||
|
extras_upscaler_1.change(on_selected_upscale_method, inputs=[extras_upscaler_1], outputs=[upscaling_resize], show_progress="hidden")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"upscale_enabled": upscale_enabled,
|
||||||
"upscale_mode": selected_tab,
|
"upscale_mode": selected_tab,
|
||||||
"upscale_by": upscaling_resize,
|
"upscale_by": upscaling_resize,
|
||||||
"upscale_to_width": upscaling_resize_w,
|
"upscale_to_width": upscaling_resize_w,
|
||||||
@ -81,7 +96,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
|
def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
|
||||||
if upscale_mode == 1:
|
if upscale_mode == 1:
|
||||||
pp.shared.target_width = upscale_to_width
|
pp.shared.target_width = upscale_to_width
|
||||||
pp.shared.target_height = upscale_to_height
|
pp.shared.target_height = upscale_to_height
|
||||||
@ -89,7 +104,10 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
pp.shared.target_width = int(pp.image.width * upscale_by)
|
pp.shared.target_width = int(pp.image.width * upscale_by)
|
||||||
pp.shared.target_height = int(pp.image.height * upscale_by)
|
pp.shared.target_height = int(pp.image.height * upscale_by)
|
||||||
|
|
||||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
|
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
|
||||||
|
if not upscale_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
if upscaler_1_name == "None":
|
if upscaler_1_name == "None":
|
||||||
upscaler_1_name = None
|
upscaler_1_name = None
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user