mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 03:40:14 +08:00
Merge branch 'dev' into ngrok-py
This commit is contained in:
commit
182330ae40
4
.eslintignore
Normal file
4
.eslintignore
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
extensions
|
||||||
|
extensions-disabled
|
||||||
|
repositories
|
||||||
|
venv
|
89
.eslintrc.js
Normal file
89
.eslintrc.js
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
module.exports = {
|
||||||
|
env: {
|
||||||
|
browser: true,
|
||||||
|
es2021: true,
|
||||||
|
},
|
||||||
|
extends: "eslint:recommended",
|
||||||
|
parserOptions: {
|
||||||
|
ecmaVersion: "latest",
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
"arrow-spacing": "error",
|
||||||
|
"block-spacing": "error",
|
||||||
|
"brace-style": "error",
|
||||||
|
"comma-dangle": ["error", "only-multiline"],
|
||||||
|
"comma-spacing": "error",
|
||||||
|
"comma-style": ["error", "last"],
|
||||||
|
"curly": ["error", "multi-line", "consistent"],
|
||||||
|
"eol-last": "error",
|
||||||
|
"func-call-spacing": "error",
|
||||||
|
"function-call-argument-newline": ["error", "consistent"],
|
||||||
|
"function-paren-newline": ["error", "consistent"],
|
||||||
|
"indent": ["error", 4],
|
||||||
|
"key-spacing": "error",
|
||||||
|
"keyword-spacing": "error",
|
||||||
|
"linebreak-style": ["error", "unix"],
|
||||||
|
"no-extra-semi": "error",
|
||||||
|
"no-mixed-spaces-and-tabs": "error",
|
||||||
|
"no-trailing-spaces": "error",
|
||||||
|
"no-whitespace-before-property": "error",
|
||||||
|
"object-curly-newline": ["error", {consistent: true, multiline: true}],
|
||||||
|
"quote-props": ["error", "consistent-as-needed"],
|
||||||
|
"semi": ["error", "always"],
|
||||||
|
"semi-spacing": "error",
|
||||||
|
"semi-style": ["error", "last"],
|
||||||
|
"space-before-blocks": "error",
|
||||||
|
"space-before-function-paren": ["error", "never"],
|
||||||
|
"space-in-parens": ["error", "never"],
|
||||||
|
"space-infix-ops": "error",
|
||||||
|
"space-unary-ops": "error",
|
||||||
|
"switch-colon-spacing": "error",
|
||||||
|
"template-curly-spacing": ["error", "never"],
|
||||||
|
"unicode-bom": "error",
|
||||||
|
"no-multi-spaces": "error",
|
||||||
|
"object-curly-spacing": ["error", "never"],
|
||||||
|
"operator-linebreak": ["error", "after"],
|
||||||
|
"no-unused-vars": "off",
|
||||||
|
"no-redeclare": "off",
|
||||||
|
},
|
||||||
|
globals: {
|
||||||
|
// this file
|
||||||
|
module: "writable",
|
||||||
|
//script.js
|
||||||
|
gradioApp: "writable",
|
||||||
|
onUiLoaded: "writable",
|
||||||
|
onUiUpdate: "writable",
|
||||||
|
onOptionsChanged: "writable",
|
||||||
|
uiCurrentTab: "writable",
|
||||||
|
uiElementIsVisible: "writable",
|
||||||
|
executeCallbacks: "writable",
|
||||||
|
//ui.js
|
||||||
|
opts: "writable",
|
||||||
|
all_gallery_buttons: "writable",
|
||||||
|
selected_gallery_button: "writable",
|
||||||
|
selected_gallery_index: "writable",
|
||||||
|
args_to_array: "writable",
|
||||||
|
switch_to_txt2img: "writable",
|
||||||
|
switch_to_img2img_tab: "writable",
|
||||||
|
switch_to_img2img: "writable",
|
||||||
|
switch_to_sketch: "writable",
|
||||||
|
switch_to_inpaint: "writable",
|
||||||
|
switch_to_inpaint_sketch: "writable",
|
||||||
|
switch_to_extras: "writable",
|
||||||
|
get_tab_index: "writable",
|
||||||
|
create_submit_args: "writable",
|
||||||
|
restart_reload: "writable",
|
||||||
|
updateInput: "writable",
|
||||||
|
//extraNetworks.js
|
||||||
|
requestGet: "writable",
|
||||||
|
popup: "writable",
|
||||||
|
// from python
|
||||||
|
localization: "writable",
|
||||||
|
// progrssbar.js
|
||||||
|
randomId: "writable",
|
||||||
|
requestProgress: "writable",
|
||||||
|
// imageviewer.js
|
||||||
|
modalPrevImage: "writable",
|
||||||
|
modalNextImage: "writable",
|
||||||
|
}
|
||||||
|
};
|
21
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
21
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -47,6 +47,15 @@ body:
|
|||||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
- type: dropdown
|
||||||
|
id: py-version
|
||||||
|
attributes:
|
||||||
|
label: What Python version are you running on ?
|
||||||
|
multiple: false
|
||||||
|
options:
|
||||||
|
- Python 3.10.x
|
||||||
|
- Python 3.11.x (above, no supported yet)
|
||||||
|
- Python 3.9.x (below, no recommended)
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
id: platforms
|
id: platforms
|
||||||
attributes:
|
attributes:
|
||||||
@ -59,6 +68,18 @@ body:
|
|||||||
- iOS
|
- iOS
|
||||||
- Android
|
- Android
|
||||||
- Other/Cloud
|
- Other/Cloud
|
||||||
|
- type: dropdown
|
||||||
|
id: device
|
||||||
|
attributes:
|
||||||
|
label: What device are you running WebUI on?
|
||||||
|
multiple: true
|
||||||
|
options:
|
||||||
|
- Nvidia GPUs (RTX 20 above)
|
||||||
|
- Nvidia GPUs (GTX 16 below)
|
||||||
|
- AMD GPUs (RX 6000 above)
|
||||||
|
- AMD GPUs (RX 5000 below)
|
||||||
|
- CPU
|
||||||
|
- Other GPUs
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
id: browsers
|
id: browsers
|
||||||
attributes:
|
attributes:
|
||||||
|
49
.github/workflows/on_pull_request.yaml
vendored
49
.github/workflows/on_pull_request.yaml
vendored
@ -1,39 +1,34 @@
|
|||||||
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
|
|
||||||
name: Run Linting/Formatting on Pull Requests
|
name: Run Linting/Formatting on Pull Requests
|
||||||
|
|
||||||
on:
|
on:
|
||||||
- push
|
- push
|
||||||
- pull_request
|
- pull_request
|
||||||
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
|
|
||||||
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
|
|
||||||
# pull_request:
|
|
||||||
# branches:
|
|
||||||
# - master
|
|
||||||
# branches-ignore:
|
|
||||||
# - development
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint-python:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- uses: actions/setup-python@v4
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
with:
|
||||||
python-version: 3.10.6
|
python-version: 3.11
|
||||||
cache: pip
|
# NB: there's no cache: pip here since we're not installing anything
|
||||||
cache-dependency-path: |
|
# from the requirements.txt file(s) in the repository; it's faster
|
||||||
**/requirements*txt
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
- name: Install PyLint
|
# of PyTorch and other dependencies.
|
||||||
run: |
|
- name: Install Ruff
|
||||||
python -m pip install --upgrade pip
|
run: pip install ruff==0.0.265
|
||||||
pip install pylint
|
- name: Run Ruff
|
||||||
# This lets PyLint check to see if it can resolve imports
|
run: ruff .
|
||||||
- name: Install dependencies
|
lint-js:
|
||||||
run: |
|
runs-on: ubuntu-latest
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
steps:
|
||||||
python launch.py
|
- name: Checkout Code
|
||||||
- name: Analysing the code with pylint
|
uses: actions/checkout@v3
|
||||||
run: |
|
- name: Install Node.js
|
||||||
pylint $(git ls-files '*.py')
|
uses: actions/setup-node@v3
|
||||||
|
with:
|
||||||
|
node-version: 18
|
||||||
|
- run: npm i --ci
|
||||||
|
- run: npm run lint
|
||||||
|
6
.github/workflows/run_tests.yaml
vendored
6
.github/workflows/run_tests.yaml
vendored
@ -17,8 +17,14 @@ jobs:
|
|||||||
cache: pip
|
cache: pip
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
**/requirements*txt
|
**/requirements*txt
|
||||||
|
launch.py
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||||
|
env:
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK: "1"
|
||||||
|
PIP_PROGRESS_BAR: "off"
|
||||||
|
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
|
||||||
|
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
|
||||||
- name: Upload main app stdout-stderr
|
- name: Upload main app stdout-stderr
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
if: always()
|
if: always()
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -34,3 +34,5 @@ notification.mp3
|
|||||||
/test/stderr.txt
|
/test/stderr.txt
|
||||||
/cache.json*
|
/cache.json*
|
||||||
/config_states/
|
/config_states/
|
||||||
|
/node_modules
|
||||||
|
/package-lock.json
|
@ -99,6 +99,12 @@ Alternatively, use online services (like Google Colab):
|
|||||||
|
|
||||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||||
|
|
||||||
|
### Installation on Windows 10/11 with NVidia-GPUs using release package
|
||||||
|
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
|
||||||
|
2. Run `update.bat`.
|
||||||
|
3. Run `run.bat`.
|
||||||
|
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
|
||||||
|
|
||||||
### Automatic Installation on Windows
|
### Automatic Installation on Windows
|
||||||
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
|
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
|
||||||
2. Install [git](https://git-scm.com/download/win).
|
2. Install [git](https://git-scm.com/download/win).
|
||||||
@ -158,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
||||||
- Security advice - RyotaK
|
- Security advice - RyotaK
|
||||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||||
|
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
@ -88,7 +88,7 @@ class LDSR:
|
|||||||
|
|
||||||
x_t = None
|
x_t = None
|
||||||
logs = None
|
logs = None
|
||||||
for n in range(n_runs):
|
for _ in range(n_runs):
|
||||||
if custom_shape is not None:
|
if custom_shape is not None:
|
||||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||||
@ -110,7 +110,6 @@ class LDSR:
|
|||||||
diffusion_steps = int(steps)
|
diffusion_steps = int(steps)
|
||||||
eta = 1.0
|
eta = 1.0
|
||||||
|
|
||||||
down_sample_method = 'Lanczos'
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available:
|
if torch.cuda.is_available:
|
||||||
@ -131,11 +130,11 @@ class LDSR:
|
|||||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
else:
|
else:
|
||||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||||
|
|
||||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||||
|
|
||||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||||
|
|
||||||
sample = logs["sample"]
|
sample = logs["sample"]
|
||||||
@ -158,7 +157,7 @@ class LDSR:
|
|||||||
|
|
||||||
|
|
||||||
def get_cond(selected_path):
|
def get_cond(selected_path):
|
||||||
example = dict()
|
example = {}
|
||||||
up_f = 4
|
up_f = 4
|
||||||
c = selected_path.convert('RGB')
|
c = selected_path.convert('RGB')
|
||||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||||
@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||||
log = dict()
|
log = {}
|
||||||
|
|
||||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
|
|||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
log["sample_noquant"] = x_sample_noquant
|
log["sample_noquant"] = x_sample_noquant
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log["sample"] = x_sample
|
log["sample"] = x_sample
|
||||||
|
@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks
|
||||||
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
import sd_hijack_autoencoder # noqa: F401
|
||||||
|
import sd_hijack_ddpm_v1 # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
import ldm.models.autoencoder
|
import ldm.models.autoencoder
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
class VQModel(pl.LightningModule):
|
class VQModel(pl.LightningModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
|
|||||||
n_embed,
|
n_embed,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
image_key="image",
|
image_key="image",
|
||||||
colorize_nlabels=None,
|
colorize_nlabels=None,
|
||||||
monitor=None,
|
monitor=None,
|
||||||
@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.lr_g_factor = lr_g_factor
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
def init_from_ckpt(self, path, ignore_keys=None):
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
|
|||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
log_dict = self._validation_step(batch, batch_idx)
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
self._validation_step(batch, batch_idx, suffix="_ema")
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||||
@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
|
|||||||
return self.decoder.conv_out.weight
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.image_key)
|
x = self.get_input(batch, self.image_key)
|
||||||
x = x.to(self.device)
|
x = x.to(self.device)
|
||||||
if only_inputs:
|
if only_inputs:
|
||||||
@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
|
|||||||
if plot_ema:
|
if plot_ema:
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
xrec_ema, _ = self(x)
|
xrec_ema, _ = self(x)
|
||||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
if x.shape[1] > 3:
|
||||||
|
xrec_ema = self.to_rgb(xrec_ema)
|
||||||
log["reconstructions_ema"] = xrec_ema
|
log["reconstructions_ema"] = xrec_ema
|
||||||
return log
|
return log
|
||||||
|
|
||||||
@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
|
|||||||
|
|
||||||
class VQModelInterface(VQModel):
|
class VQModelInterface(VQModel):
|
||||||
def __init__(self, embed_dim, *args, **kwargs):
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
|
|||||||
dec = self.decoder(quant)
|
dec = self.decoder(quant)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
ldm.models.autoencoder.VQModel = VQModel
|
||||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
ldm.models.autoencoder.VQModelInterface = VQModelInterface
|
||||||
|
@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
||||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||||
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
self.instantiate_cond_stage(cond_stage_config)
|
self.instantiate_cond_stage(cond_stage_config)
|
||||||
self.cond_stage_forward = cond_stage_forward
|
self.cond_stage_forward = cond_stage_forward
|
||||||
self.clip_denoised = False
|
self.clip_denoised = False
|
||||||
self.bbox_tokenizer = None
|
self.bbox_tokenizer = None
|
||||||
|
|
||||||
self.restarted_from_ckpt = False
|
self.restarted_from_ckpt = False
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
||||||
|
|
||||||
# 2. apply model loop over last dim
|
# 2. apply model loop over last dim
|
||||||
if isinstance(self.first_stage_model, VQModelInterface):
|
if isinstance(self.first_stage_model, VQModelInterface):
|
||||||
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
||||||
force_not_quantize=predict_cids or force_not_quantize)
|
force_not_quantize=predict_cids or force_not_quantize)
|
||||||
for i in range(z.shape[-1])]
|
for i in range(z.shape[-1])]
|
||||||
@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if hasattr(self, "split_input_params"):
|
if hasattr(self, "split_input_params"):
|
||||||
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
||||||
assert not return_ids
|
assert not return_ids
|
||||||
ks = self.split_input_params["ks"] # eg. (128, 128)
|
ks = self.split_input_params["ks"] # eg. (128, 128)
|
||||||
stride = self.split_input_params["stride"] # eg. (64, 64)
|
stride = self.split_input_params["stride"] # eg. (64, 64)
|
||||||
|
|
||||||
@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
use_ddim = ddim_steps is not None
|
use_ddim = ddim_steps is not None
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
logs['bbox_image'] = cond_img
|
logs['bbox_image'] = cond_img
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
|
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
|
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
|
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
|
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
@ -177,7 +176,7 @@ def load_lora(name, filename):
|
|||||||
else:
|
else:
|
||||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
continue
|
continue
|
||||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
module.weight.copy_(weight)
|
module.weight.copy_(weight)
|
||||||
@ -189,7 +188,7 @@ def load_lora(name, filename):
|
|||||||
elif lora_key == "lora_down.weight":
|
elif lora_key == "lora_down.weight":
|
||||||
lora_module.down = module
|
lora_module.down = module
|
||||||
else:
|
else:
|
||||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||||
|
|
||||||
if len(keys_failed_to_match) > 0:
|
if len(keys_failed_to_match) > 0:
|
||||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||||
@ -207,7 +206,7 @@ def load_loras(names, multipliers=None):
|
|||||||
loaded_loras.clear()
|
loaded_loras.clear()
|
||||||
|
|
||||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
if any([x is None for x in loras_on_disk]):
|
if any(x is None for x in loras_on_disk):
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
|
||||||
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
|
|
||||||
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||||
|
|
||||||
setattr(self, "lora_current_names", wanted_names)
|
self.lora_current_names = wanted_names
|
||||||
|
|
||||||
|
|
||||||
def lora_forward(module, input, original_forward):
|
def lora_forward(module, input, original_forward):
|
||||||
@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward):
|
|||||||
|
|
||||||
|
|
||||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||||
setattr(self, "lora_current_names", ())
|
self.lora_current_names = ()
|
||||||
setattr(self, "lora_weights_backup", None)
|
self.lora_weights_backup = None
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
@ -428,7 +427,7 @@ def infotext_pasted(infotext, params):
|
|||||||
|
|
||||||
added = []
|
added = []
|
||||||
|
|
||||||
for k, v in params.items():
|
for k in params:
|
||||||
if not k.startswith("AddNet Model "):
|
if not k.startswith("AddNet Model "):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
|||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -10,10 +10,9 @@ from tqdm import tqdm
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader, script_callbacks
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules import images
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for k, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
v.requires_grad = False
|
v.requires_grad = False
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
import gradio as gr
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
||||||
|
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
@ -61,7 +61,9 @@ class WMSA(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output: tensor shape [b h w c]
|
output: tensor shape [b h w c]
|
||||||
"""
|
"""
|
||||||
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
if self.type != 'W':
|
||||||
|
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
|
||||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
h_windows = x.size(1)
|
h_windows = x.size(1)
|
||||||
w_windows = x.size(2)
|
w_windows = x.size(2)
|
||||||
@ -85,8 +87,9 @@ class WMSA(nn.Module):
|
|||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
if self.type != 'W':
|
||||||
dims=(1, 2))
|
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def relative_embedding(self):
|
def relative_embedding(self):
|
||||||
@ -262,4 +265,4 @@ class SCUNet(nn.Module):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
nn.init.constant_(m.weight, 1.0)
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import cmd_opts, opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR as net
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||||||
for w_idx in w_idx_list:
|
for w_idx in w_idx_list:
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
out_patch = model(in_patch)
|
out_patch = model(in_patch)
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
@ -644,7 +644,7 @@ class SwinIR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
@ -805,7 +805,7 @@ class SwinIR(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
x = self.check_image_size(x)
|
x = self.check_image_size(x)
|
||||||
|
|
||||||
self.mean = self.mean.type_as(x)
|
self.mean = self.mean.type_as(x)
|
||||||
x = (x - self.mean) * self.img_range
|
x = (x - self.mean) * self.img_range
|
||||||
|
|
||||||
@ -844,7 +844,7 @@ class SwinIR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
|
@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||||
pretrained_window_size=[0, 0]):
|
pretrained_window_size=(0, 0)):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
self.register_buffer("attn_mask", attn_mask)
|
self.register_buffer("attn_mask", attn_mask)
|
||||||
|
|
||||||
def calculate_mask(self, x_size):
|
def calculate_mask(self, x_size):
|
||||||
# calculate attention mask for SW-MSA
|
# calculate attention mask for SW-MSA
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||||
|
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
def forward(self, x, x_size):
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
||||||
else:
|
else:
|
||||||
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
||||||
|
|
||||||
# merge windows
|
# merge windows
|
||||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
||||||
@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||||
flops += H * W * self.dim // 2
|
flops += H * W * self.dim // 2
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class BasicLayer(nn.Module):
|
class BasicLayer(nn.Module):
|
||||||
""" A basic Swin Transformer layer for one stage.
|
""" A basic Swin Transformer layer for one stage.
|
||||||
@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
|
|||||||
nn.init.constant_(blk.norm1.weight, 0)
|
nn.init.constant_(blk.norm1.weight, 0)
|
||||||
nn.init.constant_(blk.norm2.bias, 0)
|
nn.init.constant_(blk.norm2.bias, 0)
|
||||||
nn.init.constant_(blk.norm2.weight, 0)
|
nn.init.constant_(blk.norm2.weight, 0)
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
class PatchEmbed(nn.Module):
|
||||||
r""" Image to Patch Embedding
|
r""" Image to Patch Embedding
|
||||||
Args:
|
Args:
|
||||||
@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
|
|||||||
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
flops += Ho * Wo * self.embed_dim
|
flops += Ho * Wo * self.embed_dim
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class RSTB(nn.Module):
|
class RSTB(nn.Module):
|
||||||
"""Residual Swin Transformer Block (RSTB).
|
"""Residual Swin Transformer Block (RSTB).
|
||||||
@ -531,7 +531,7 @@ class RSTB(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop, attn_drop=attn_drop,
|
drop=drop, attn_drop=attn_drop,
|
||||||
drop_path=drop_path,
|
drop_path=drop_path,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample, self).__init__(*m)
|
super(Upsample, self).__init__(*m)
|
||||||
|
|
||||||
class Upsample_hf(nn.Sequential):
|
class Upsample_hf(nn.Sequential):
|
||||||
"""Upsample module.
|
"""Upsample module.
|
||||||
|
|
||||||
@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
|
|||||||
m.append(nn.PixelShuffle(3))
|
m.append(nn.PixelShuffle(3))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample_hf, self).__init__(*m)
|
super(Upsample_hf, self).__init__(*m)
|
||||||
|
|
||||||
|
|
||||||
class UpsampleOneStep(nn.Sequential):
|
class UpsampleOneStep(nn.Sequential):
|
||||||
@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = H * W * self.num_feat * 3 * 9
|
flops = H * W * self.num_feat * 3 * 9
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Swin2SR(nn.Module):
|
class Swin2SR(nn.Module):
|
||||||
r""" Swin2SR
|
r""" Swin2SR
|
||||||
@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers.append(layer)
|
self.layers.append(layer)
|
||||||
|
|
||||||
if self.upsampler == 'pixelshuffle_hf':
|
if self.upsampler == 'pixelshuffle_hf':
|
||||||
self.layers_hf = nn.ModuleList()
|
self.layers_hf = nn.ModuleList()
|
||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers_hf.append(layer)
|
self.layers_hf.append(layer)
|
||||||
|
|
||||||
self.norm = norm_layer(self.num_features)
|
self.norm = norm_layer(self.num_features)
|
||||||
|
|
||||||
# build the last conv layer in deep feature extraction
|
# build the last conv layer in deep feature extraction
|
||||||
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
|
|||||||
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
self.conv_after_aux = nn.Sequential(
|
self.conv_after_aux = nn.Sequential(
|
||||||
nn.Conv2d(3, num_feat, 3, 1, 1),
|
nn.Conv2d(3, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.upsample = Upsample(upscale, num_feat)
|
self.upsample = Upsample(upscale, num_feat)
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffle_hf':
|
elif self.upsampler == 'pixelshuffle_hf':
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
|
|||||||
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
elif self.upsampler == 'pixelshuffledirect':
|
||||||
# for lightweight SR (to save parameters)
|
# for lightweight SR (to save parameters)
|
||||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
||||||
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_features_hf(self, x):
|
def forward_features_hf(self, x):
|
||||||
x_size = (x.shape[2], x.shape[3])
|
x_size = (x.shape[2], x.shape[3])
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.norm(x) # B L C
|
x = self.norm(x) # B L C
|
||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
x = self.conv_after_body(self.forward_features(x)) + x
|
||||||
x_before = self.conv_before_upsample(x)
|
x_before = self.conv_before_upsample(x)
|
||||||
x_out = self.conv_last(self.upsample(x_before))
|
x_out = self.conv_last(self.upsample(x_before))
|
||||||
|
|
||||||
x_hf = self.conv_first_hf(x_before)
|
x_hf = self.conv_first_hf(x_before)
|
||||||
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
||||||
x_hf = self.conv_before_upsample_hf(x_hf)
|
x_hf = self.conv_before_upsample_hf(x_hf)
|
||||||
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
|
|||||||
x_first = self.conv_first(x)
|
x_first = self.conv_first(x)
|
||||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
||||||
x = x + self.conv_last(res)
|
x = x + self.conv_last(res)
|
||||||
|
|
||||||
x = x / self.img_range + self.mean
|
x = x / self.img_range + self.mean
|
||||||
if self.upsampler == "pixelshuffle_aux":
|
if self.upsampler == "pixelshuffle_aux":
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
||||||
|
|
||||||
elif self.upsampler == "pixelshuffle_hf":
|
elif self.upsampler == "pixelshuffle_hf":
|
||||||
x_out = x_out / self.img_range + self.mean
|
x_out = x_out / self.img_range + self.mean
|
||||||
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
return x[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
x = torch.randn((1, 3, height, width))
|
x = torch.randn((1, 3, height, width))
|
||||||
x = model(x)
|
x = model(x)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
|
@ -4,39 +4,39 @@
|
|||||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||||
|
|
||||||
function checkBrackets(textArea, counterElt) {
|
function checkBrackets(textArea, counterElt) {
|
||||||
var counts = {};
|
var counts = {};
|
||||||
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
|
(textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => {
|
||||||
counts[bracket] = (counts[bracket] || 0) + 1;
|
counts[bracket] = (counts[bracket] || 0) + 1;
|
||||||
});
|
});
|
||||||
var errors = [];
|
var errors = [];
|
||||||
|
|
||||||
function checkPair(open, close, kind) {
|
function checkPair(open, close, kind) {
|
||||||
if (counts[open] !== counts[close]) {
|
if (counts[open] !== counts[close]) {
|
||||||
errors.push(
|
errors.push(
|
||||||
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
|
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
|
||||||
);
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
checkPair('(', ')', 'round brackets');
|
checkPair('(', ')', 'round brackets');
|
||||||
checkPair('[', ']', 'square brackets');
|
checkPair('[', ']', 'square brackets');
|
||||||
checkPair('{', '}', 'curly brackets');
|
checkPair('{', '}', 'curly brackets');
|
||||||
counterElt.title = errors.join('\n');
|
counterElt.title = errors.join('\n');
|
||||||
counterElt.classList.toggle('error', errors.length !== 0);
|
counterElt.classList.toggle('error', errors.length !== 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
function setupBracketChecking(id_prompt, id_counter) {
|
function setupBracketChecking(id_prompt, id_counter) {
|
||||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||||
var counter = gradioApp().getElementById(id_counter)
|
var counter = gradioApp().getElementById(id_counter);
|
||||||
|
|
||||||
if (textarea && counter) {
|
if (textarea && counter) {
|
||||||
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
|
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiLoaded(function () {
|
onUiLoaded(function() {
|
||||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
|
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
|
||||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
|
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
|
||||||
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
|
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
|
||||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
|
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
|
||||||
});
|
});
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
<ul>
|
<ul>
|
||||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||||
</ul>
|
</ul>
|
||||||
<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
|
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
|
||||||
</div>
|
</div>
|
||||||
<span class='name'>{name}</span>
|
<span class='name'>{name}</span>
|
||||||
<span class='description'>{description}</span>
|
<span class='description'>{description}</span>
|
||||||
|
@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
</pre>
|
||||||
|
|
||||||
|
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
|
||||||
|
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
|
||||||
|
<pre>
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Ollin Boer Bohan
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
</pre>
|
</pre>
|
@ -1,111 +1,113 @@
|
|||||||
|
|
||||||
let currentWidth = null;
|
let currentWidth = null;
|
||||||
let currentHeight = null;
|
let currentHeight = null;
|
||||||
let arFrameTimeout = setTimeout(function(){},0);
|
let arFrameTimeout = setTimeout(function() {}, 0);
|
||||||
|
|
||||||
function dimensionChange(e, is_width, is_height){
|
function dimensionChange(e, is_width, is_height) {
|
||||||
|
|
||||||
if(is_width){
|
if (is_width) {
|
||||||
currentWidth = e.target.value*1.0
|
currentWidth = e.target.value * 1.0;
|
||||||
}
|
}
|
||||||
if(is_height){
|
if (is_height) {
|
||||||
currentHeight = e.target.value*1.0
|
currentHeight = e.target.value * 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
|
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
|
||||||
|
|
||||||
if(!inImg2img){
|
if (!inImg2img) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var targetElement = null;
|
var targetElement = null;
|
||||||
|
|
||||||
var tabIndex = get_tab_index('mode_img2img')
|
var tabIndex = get_tab_index('mode_img2img');
|
||||||
if(tabIndex == 0){ // img2img
|
if (tabIndex == 0) { // img2img
|
||||||
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
|
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
|
||||||
} else if(tabIndex == 1){ //Sketch
|
} else if (tabIndex == 1) { //Sketch
|
||||||
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
|
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
|
||||||
} else if(tabIndex == 2){ // Inpaint
|
} else if (tabIndex == 2) { // Inpaint
|
||||||
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
|
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
|
||||||
} else if(tabIndex == 3){ // Inpaint sketch
|
} else if (tabIndex == 3) { // Inpaint sketch
|
||||||
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
|
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if(targetElement){
|
if (targetElement) {
|
||||||
|
|
||||||
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
||||||
if(!arPreviewRect){
|
if (!arPreviewRect) {
|
||||||
arPreviewRect = document.createElement('div')
|
arPreviewRect = document.createElement('div');
|
||||||
arPreviewRect.id = "imageARPreview";
|
arPreviewRect.id = "imageARPreview";
|
||||||
gradioApp().appendChild(arPreviewRect)
|
gradioApp().appendChild(arPreviewRect);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
var viewportOffset = targetElement.getBoundingClientRect();
|
var viewportOffset = targetElement.getBoundingClientRect();
|
||||||
|
|
||||||
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);
|
||||||
|
|
||||||
var scaledx = targetElement.naturalWidth*viewportscale
|
var scaledx = targetElement.naturalWidth * viewportscale;
|
||||||
var scaledy = targetElement.naturalHeight*viewportscale
|
var scaledy = targetElement.naturalHeight * viewportscale;
|
||||||
|
|
||||||
var cleintRectTop = (viewportOffset.top+window.scrollY)
|
var cleintRectTop = (viewportOffset.top + window.scrollY);
|
||||||
var cleintRectLeft = (viewportOffset.left+window.scrollX)
|
var cleintRectLeft = (viewportOffset.left + window.scrollX);
|
||||||
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight / 2);
|
||||||
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth / 2);
|
||||||
|
|
||||||
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
|
var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);
|
||||||
var arscaledx = currentWidth*arscale
|
var arscaledx = currentWidth * arscale;
|
||||||
var arscaledy = currentHeight*arscale
|
var arscaledy = currentHeight * arscale;
|
||||||
|
|
||||||
var arRectTop = cleintRectCentreY-(arscaledy/2)
|
var arRectTop = cleintRectCentreY - (arscaledy / 2);
|
||||||
var arRectLeft = cleintRectCentreX-(arscaledx/2)
|
var arRectLeft = cleintRectCentreX - (arscaledx / 2);
|
||||||
var arRectWidth = arscaledx
|
var arRectWidth = arscaledx;
|
||||||
var arRectHeight = arscaledy
|
var arRectHeight = arscaledy;
|
||||||
|
|
||||||
arPreviewRect.style.top = arRectTop+'px';
|
arPreviewRect.style.top = arRectTop + 'px';
|
||||||
arPreviewRect.style.left = arRectLeft+'px';
|
arPreviewRect.style.left = arRectLeft + 'px';
|
||||||
arPreviewRect.style.width = arRectWidth+'px';
|
arPreviewRect.style.width = arRectWidth + 'px';
|
||||||
arPreviewRect.style.height = arRectHeight+'px';
|
arPreviewRect.style.height = arRectHeight + 'px';
|
||||||
|
|
||||||
clearTimeout(arFrameTimeout);
|
clearTimeout(arFrameTimeout);
|
||||||
arFrameTimeout = setTimeout(function(){
|
arFrameTimeout = setTimeout(function() {
|
||||||
arPreviewRect.style.display = 'none';
|
arPreviewRect.style.display = 'none';
|
||||||
},2000);
|
}, 2000);
|
||||||
|
|
||||||
arPreviewRect.style.display = 'block';
|
arPreviewRect.style.display = 'block';
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function() {
|
||||||
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
var arPreviewRect = gradioApp().querySelector('#imageARPreview');
|
||||||
if(arPreviewRect){
|
if (arPreviewRect) {
|
||||||
arPreviewRect.style.display = 'none';
|
arPreviewRect.style.display = 'none';
|
||||||
}
|
}
|
||||||
var tabImg2img = gradioApp().querySelector("#tab_img2img");
|
var tabImg2img = gradioApp().querySelector("#tab_img2img");
|
||||||
if (tabImg2img) {
|
if (tabImg2img) {
|
||||||
var inImg2img = tabImg2img.style.display == "block";
|
var inImg2img = tabImg2img.style.display == "block";
|
||||||
if(inImg2img){
|
if (inImg2img) {
|
||||||
let inputs = gradioApp().querySelectorAll('input');
|
let inputs = gradioApp().querySelectorAll('input');
|
||||||
inputs.forEach(function(e){
|
inputs.forEach(function(e) {
|
||||||
var is_width = e.parentElement.id == "img2img_width"
|
var is_width = e.parentElement.id == "img2img_width";
|
||||||
var is_height = e.parentElement.id == "img2img_height"
|
var is_height = e.parentElement.id == "img2img_height";
|
||||||
|
|
||||||
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
|
if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {
|
||||||
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
|
e.addEventListener('input', function(e) {
|
||||||
e.classList.add('scrollwatch')
|
dimensionChange(e, is_width, is_height);
|
||||||
}
|
});
|
||||||
if(is_width){
|
e.classList.add('scrollwatch');
|
||||||
currentWidth = e.value*1.0
|
}
|
||||||
}
|
if (is_width) {
|
||||||
if(is_height){
|
currentWidth = e.value * 1.0;
|
||||||
currentHeight = e.value*1.0
|
}
|
||||||
}
|
if (is_height) {
|
||||||
})
|
currentHeight = e.value * 1.0;
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
});
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
@ -1,166 +1,172 @@
|
|||||||
|
|
||||||
contextMenuInit = function(){
|
var contextMenuInit = function() {
|
||||||
let eventListenerApplied=false;
|
let eventListenerApplied = false;
|
||||||
let menuSpecs = new Map();
|
let menuSpecs = new Map();
|
||||||
|
|
||||||
const uid = function(){
|
const uid = function() {
|
||||||
return Date.now().toString(36) + Math.random().toString(36).substring(2);
|
return Date.now().toString(36) + Math.random().toString(36).substring(2);
|
||||||
}
|
};
|
||||||
|
|
||||||
function showContextMenu(event,element,menuEntries){
|
function showContextMenu(event, element, menuEntries) {
|
||||||
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
|
||||||
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
|
||||||
|
|
||||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
let oldMenu = gradioApp().querySelector('#context-menu');
|
||||||
if(oldMenu){
|
if (oldMenu) {
|
||||||
oldMenu.remove()
|
oldMenu.remove();
|
||||||
}
|
}
|
||||||
|
|
||||||
let baseStyle = window.getComputedStyle(uiCurrentTab)
|
let baseStyle = window.getComputedStyle(uiCurrentTab);
|
||||||
|
|
||||||
const contextMenu = document.createElement('nav')
|
const contextMenu = document.createElement('nav');
|
||||||
contextMenu.id = "context-menu"
|
contextMenu.id = "context-menu";
|
||||||
contextMenu.style.background = baseStyle.background
|
contextMenu.style.background = baseStyle.background;
|
||||||
contextMenu.style.color = baseStyle.color
|
contextMenu.style.color = baseStyle.color;
|
||||||
contextMenu.style.fontFamily = baseStyle.fontFamily
|
contextMenu.style.fontFamily = baseStyle.fontFamily;
|
||||||
contextMenu.style.top = posy+'px'
|
contextMenu.style.top = posy + 'px';
|
||||||
contextMenu.style.left = posx+'px'
|
contextMenu.style.left = posx + 'px';
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
const contextMenuList = document.createElement('ul')
|
const contextMenuList = document.createElement('ul');
|
||||||
contextMenuList.className = 'context-menu-items';
|
contextMenuList.className = 'context-menu-items';
|
||||||
contextMenu.append(contextMenuList);
|
contextMenu.append(contextMenuList);
|
||||||
|
|
||||||
menuEntries.forEach(function(entry){
|
menuEntries.forEach(function(entry) {
|
||||||
let contextMenuEntry = document.createElement('a')
|
let contextMenuEntry = document.createElement('a');
|
||||||
contextMenuEntry.innerHTML = entry['name']
|
contextMenuEntry.innerHTML = entry['name'];
|
||||||
contextMenuEntry.addEventListener("click", function() {
|
contextMenuEntry.addEventListener("click", function() {
|
||||||
entry['func']();
|
entry['func']();
|
||||||
})
|
});
|
||||||
contextMenuList.append(contextMenuEntry);
|
contextMenuList.append(contextMenuEntry);
|
||||||
|
|
||||||
})
|
});
|
||||||
|
|
||||||
gradioApp().appendChild(contextMenu)
|
gradioApp().appendChild(contextMenu);
|
||||||
|
|
||||||
let menuWidth = contextMenu.offsetWidth + 4;
|
let menuWidth = contextMenu.offsetWidth + 4;
|
||||||
let menuHeight = contextMenu.offsetHeight + 4;
|
let menuHeight = contextMenu.offsetHeight + 4;
|
||||||
|
|
||||||
let windowWidth = window.innerWidth;
|
let windowWidth = window.innerWidth;
|
||||||
let windowHeight = window.innerHeight;
|
let windowHeight = window.innerHeight;
|
||||||
|
|
||||||
if ( (windowWidth - posx) < menuWidth ) {
|
if ((windowWidth - posx) < menuWidth) {
|
||||||
contextMenu.style.left = windowWidth - menuWidth + "px";
|
contextMenu.style.left = windowWidth - menuWidth + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
if ( (windowHeight - posy) < menuHeight ) {
|
if ((windowHeight - posy) < menuHeight) {
|
||||||
contextMenu.style.top = windowHeight - menuHeight + "px";
|
contextMenu.style.top = windowHeight - menuHeight + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
function appendContextMenuOption(targetElementSelector, entryName, entryFunction) {
|
||||||
|
|
||||||
var currentItems = menuSpecs.get(targetElementSelector)
|
var currentItems = menuSpecs.get(targetElementSelector);
|
||||||
|
|
||||||
if(!currentItems){
|
if (!currentItems) {
|
||||||
currentItems = []
|
currentItems = [];
|
||||||
menuSpecs.set(targetElementSelector,currentItems);
|
menuSpecs.set(targetElementSelector, currentItems);
|
||||||
}
|
}
|
||||||
let newItem = {'id':targetElementSelector+'_'+uid(),
|
let newItem = {
|
||||||
'name':entryName,
|
id: targetElementSelector + '_' + uid(),
|
||||||
'func':entryFunction,
|
name: entryName,
|
||||||
'isNew':true}
|
func: entryFunction,
|
||||||
|
isNew: true
|
||||||
currentItems.push(newItem)
|
};
|
||||||
return newItem['id']
|
|
||||||
}
|
currentItems.push(newItem);
|
||||||
|
return newItem['id'];
|
||||||
function removeContextMenuOption(uid){
|
}
|
||||||
menuSpecs.forEach(function(v) {
|
|
||||||
let index = -1
|
function removeContextMenuOption(uid) {
|
||||||
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
menuSpecs.forEach(function(v) {
|
||||||
if(index>=0){
|
let index = -1;
|
||||||
v.splice(index, 1);
|
v.forEach(function(e, ei) {
|
||||||
}
|
if (e['id'] == uid) {
|
||||||
})
|
index = ei;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
function addContextMenuEventListener(){
|
if (index >= 0) {
|
||||||
if(eventListenerApplied){
|
v.splice(index, 1);
|
||||||
return;
|
}
|
||||||
}
|
});
|
||||||
gradioApp().addEventListener("click", function(e) {
|
}
|
||||||
if(! e.isTrusted){
|
|
||||||
return
|
function addContextMenuEventListener() {
|
||||||
}
|
if (eventListenerApplied) {
|
||||||
|
return;
|
||||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
}
|
||||||
if(oldMenu){
|
gradioApp().addEventListener("click", function(e) {
|
||||||
oldMenu.remove()
|
if (!e.isTrusted) {
|
||||||
}
|
return;
|
||||||
});
|
}
|
||||||
gradioApp().addEventListener("contextmenu", function(e) {
|
|
||||||
let oldMenu = gradioApp().querySelector('#context-menu')
|
let oldMenu = gradioApp().querySelector('#context-menu');
|
||||||
if(oldMenu){
|
if (oldMenu) {
|
||||||
oldMenu.remove()
|
oldMenu.remove();
|
||||||
}
|
}
|
||||||
menuSpecs.forEach(function(v,k) {
|
});
|
||||||
if(e.composedPath()[0].matches(k)){
|
gradioApp().addEventListener("contextmenu", function(e) {
|
||||||
showContextMenu(e,e.composedPath()[0],v)
|
let oldMenu = gradioApp().querySelector('#context-menu');
|
||||||
e.preventDefault()
|
if (oldMenu) {
|
||||||
}
|
oldMenu.remove();
|
||||||
})
|
}
|
||||||
});
|
menuSpecs.forEach(function(v, k) {
|
||||||
eventListenerApplied=true
|
if (e.composedPath()[0].matches(k)) {
|
||||||
|
showContextMenu(e, e.composedPath()[0], v);
|
||||||
}
|
e.preventDefault();
|
||||||
|
}
|
||||||
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
|
});
|
||||||
}
|
});
|
||||||
|
eventListenerApplied = true;
|
||||||
initResponse = contextMenuInit();
|
|
||||||
appendContextMenuOption = initResponse[0];
|
}
|
||||||
removeContextMenuOption = initResponse[1];
|
|
||||||
addContextMenuEventListener = initResponse[2];
|
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener];
|
||||||
|
};
|
||||||
(function(){
|
|
||||||
//Start example Context Menu Items
|
var initResponse = contextMenuInit();
|
||||||
let generateOnRepeat = function(genbuttonid,interruptbuttonid){
|
var appendContextMenuOption = initResponse[0];
|
||||||
let genbutton = gradioApp().querySelector(genbuttonid);
|
var removeContextMenuOption = initResponse[1];
|
||||||
let interruptbutton = gradioApp().querySelector(interruptbuttonid);
|
var addContextMenuEventListener = initResponse[2];
|
||||||
if(!interruptbutton.offsetParent){
|
|
||||||
genbutton.click();
|
(function() {
|
||||||
}
|
//Start example Context Menu Items
|
||||||
clearInterval(window.generateOnRepeatInterval)
|
let generateOnRepeat = function(genbuttonid, interruptbuttonid) {
|
||||||
window.generateOnRepeatInterval = setInterval(function(){
|
let genbutton = gradioApp().querySelector(genbuttonid);
|
||||||
if(!interruptbutton.offsetParent){
|
let interruptbutton = gradioApp().querySelector(interruptbuttonid);
|
||||||
genbutton.click();
|
if (!interruptbutton.offsetParent) {
|
||||||
}
|
genbutton.click();
|
||||||
},
|
}
|
||||||
500)
|
clearInterval(window.generateOnRepeatInterval);
|
||||||
}
|
window.generateOnRepeatInterval = setInterval(function() {
|
||||||
|
if (!interruptbutton.offsetParent) {
|
||||||
appendContextMenuOption('#txt2img_generate','Generate forever',function(){
|
genbutton.click();
|
||||||
generateOnRepeat('#txt2img_generate','#txt2img_interrupt');
|
}
|
||||||
})
|
},
|
||||||
appendContextMenuOption('#img2img_generate','Generate forever',function(){
|
500);
|
||||||
generateOnRepeat('#img2img_generate','#img2img_interrupt');
|
};
|
||||||
})
|
|
||||||
|
appendContextMenuOption('#txt2img_generate', 'Generate forever', function() {
|
||||||
let cancelGenerateForever = function(){
|
generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');
|
||||||
clearInterval(window.generateOnRepeatInterval)
|
});
|
||||||
}
|
appendContextMenuOption('#img2img_generate', 'Generate forever', function() {
|
||||||
|
generateOnRepeat('#img2img_generate', '#img2img_interrupt');
|
||||||
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
});
|
||||||
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
|
||||||
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
let cancelGenerateForever = function() {
|
||||||
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
clearInterval(window.generateOnRepeatInterval);
|
||||||
|
};
|
||||||
})();
|
|
||||||
//End example Context Menu Items
|
appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
|
||||||
|
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever);
|
||||||
onUiUpdate(function(){
|
appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever);
|
||||||
addContextMenuEventListener()
|
appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever);
|
||||||
});
|
|
||||||
|
})();
|
||||||
|
//End example Context Menu Items
|
||||||
|
|
||||||
|
onUiUpdate(function() {
|
||||||
|
addContextMenuEventListener();
|
||||||
|
});
|
||||||
|
47
javascript/dragdrop.js
vendored
47
javascript/dragdrop.js
vendored
@ -1,11 +1,11 @@
|
|||||||
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
|
// allows drag-dropping files into gradio image elements, and also pasting images from clipboard
|
||||||
|
|
||||||
function isValidImageList( files ) {
|
function isValidImageList(files) {
|
||||||
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
|
return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);
|
||||||
}
|
}
|
||||||
|
|
||||||
function dropReplaceImage( imgWrap, files ) {
|
function dropReplaceImage(imgWrap, files) {
|
||||||
if ( ! isValidImageList( files ) ) {
|
if (!isValidImageList(files)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -14,44 +14,44 @@ function dropReplaceImage( imgWrap, files ) {
|
|||||||
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
|
imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();
|
||||||
const callback = () => {
|
const callback = () => {
|
||||||
const fileInput = imgWrap.querySelector('input[type="file"]');
|
const fileInput = imgWrap.querySelector('input[type="file"]');
|
||||||
if ( fileInput ) {
|
if (fileInput) {
|
||||||
if ( files.length === 0 ) {
|
if (files.length === 0) {
|
||||||
files = new DataTransfer();
|
files = new DataTransfer();
|
||||||
files.items.add(tmpFile);
|
files.items.add(tmpFile);
|
||||||
fileInput.files = files.files;
|
fileInput.files = files.files;
|
||||||
} else {
|
} else {
|
||||||
fileInput.files = files;
|
fileInput.files = files;
|
||||||
}
|
}
|
||||||
fileInput.dispatchEvent(new Event('change'));
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if ( imgWrap.closest('#pnginfo_image') ) {
|
if (imgWrap.closest('#pnginfo_image')) {
|
||||||
// special treatment for PNG Info tab, wait for fetch request to finish
|
// special treatment for PNG Info tab, wait for fetch request to finish
|
||||||
const oldFetch = window.fetch;
|
const oldFetch = window.fetch;
|
||||||
window.fetch = async (input, options) => {
|
window.fetch = async(input, options) => {
|
||||||
const response = await oldFetch(input, options);
|
const response = await oldFetch(input, options);
|
||||||
if ( 'api/predict/' === input ) {
|
if ('api/predict/' === input) {
|
||||||
const content = await response.text();
|
const content = await response.text();
|
||||||
window.fetch = oldFetch;
|
window.fetch = oldFetch;
|
||||||
window.requestAnimationFrame( () => callback() );
|
window.requestAnimationFrame(() => callback());
|
||||||
return new Response(content, {
|
return new Response(content, {
|
||||||
status: response.status,
|
status: response.status,
|
||||||
statusText: response.statusText,
|
statusText: response.statusText,
|
||||||
headers: response.headers
|
headers: response.headers
|
||||||
})
|
});
|
||||||
}
|
}
|
||||||
return response;
|
return response;
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
window.requestAnimationFrame( () => callback() );
|
window.requestAnimationFrame(() => callback());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
window.document.addEventListener('dragover', e => {
|
window.document.addEventListener('dragover', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
const imgWrap = target.closest('[data-testid="image"]');
|
||||||
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
@ -65,33 +65,34 @@ window.document.addEventListener('drop', e => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
const imgWrap = target.closest('[data-testid="image"]');
|
||||||
if ( !imgWrap ) {
|
if (!imgWrap) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
const files = e.dataTransfer.files;
|
const files = e.dataTransfer.files;
|
||||||
dropReplaceImage( imgWrap, files );
|
dropReplaceImage(imgWrap, files);
|
||||||
});
|
});
|
||||||
|
|
||||||
window.addEventListener('paste', e => {
|
window.addEventListener('paste', e => {
|
||||||
const files = e.clipboardData.files;
|
const files = e.clipboardData.files;
|
||||||
if ( ! isValidImageList( files ) ) {
|
if (!isValidImageList(files)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
|
const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid="image"]')]
|
||||||
.filter(el => uiElementIsVisible(el));
|
.filter(el => uiElementIsVisible(el));
|
||||||
if ( ! visibleImageFields.length ) {
|
if (!visibleImageFields.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstFreeImageField = visibleImageFields
|
const firstFreeImageField = visibleImageFields
|
||||||
.filter(el => el.querySelector('input[type=file]'))?.[0];
|
.filter(el => el.querySelector('input[type=file]'))?.[0];
|
||||||
|
|
||||||
dropReplaceImage(
|
dropReplaceImage(
|
||||||
firstFreeImageField ?
|
firstFreeImageField ?
|
||||||
firstFreeImageField :
|
firstFreeImageField :
|
||||||
visibleImageFields[visibleImageFields.length - 1]
|
visibleImageFields[visibleImageFields.length - 1]
|
||||||
, files );
|
, files
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
@ -1,120 +1,120 @@
|
|||||||
function keyupEditAttention(event){
|
function keyupEditAttention(event) {
|
||||||
let target = event.originalTarget || event.composedPath()[0];
|
let target = event.originalTarget || event.composedPath()[0];
|
||||||
if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
|
if (!target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
|
||||||
if (! (event.metaKey || event.ctrlKey)) return;
|
if (!(event.metaKey || event.ctrlKey)) return;
|
||||||
|
|
||||||
let isPlus = event.key == "ArrowUp"
|
let isPlus = event.key == "ArrowUp";
|
||||||
let isMinus = event.key == "ArrowDown"
|
let isMinus = event.key == "ArrowDown";
|
||||||
if (!isPlus && !isMinus) return;
|
if (!isPlus && !isMinus) return;
|
||||||
|
|
||||||
let selectionStart = target.selectionStart;
|
let selectionStart = target.selectionStart;
|
||||||
let selectionEnd = target.selectionEnd;
|
let selectionEnd = target.selectionEnd;
|
||||||
let text = target.value;
|
let text = target.value;
|
||||||
|
|
||||||
function selectCurrentParenthesisBlock(OPEN, CLOSE){
|
function selectCurrentParenthesisBlock(OPEN, CLOSE) {
|
||||||
if (selectionStart !== selectionEnd) return false;
|
if (selectionStart !== selectionEnd) return false;
|
||||||
|
|
||||||
// Find opening parenthesis around current cursor
|
// Find opening parenthesis around current cursor
|
||||||
const before = text.substring(0, selectionStart);
|
const before = text.substring(0, selectionStart);
|
||||||
let beforeParen = before.lastIndexOf(OPEN);
|
let beforeParen = before.lastIndexOf(OPEN);
|
||||||
if (beforeParen == -1) return false;
|
if (beforeParen == -1) return false;
|
||||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||||
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find closing parenthesis around current cursor
|
// Find closing parenthesis around current cursor
|
||||||
const after = text.substring(selectionStart);
|
const after = text.substring(selectionStart);
|
||||||
let afterParen = after.indexOf(CLOSE);
|
let afterParen = after.indexOf(CLOSE);
|
||||||
if (afterParen == -1) return false;
|
if (afterParen == -1) return false;
|
||||||
let afterParenOpen = after.indexOf(OPEN);
|
let afterParenOpen = after.indexOf(OPEN);
|
||||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||||
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
||||||
}
|
}
|
||||||
if (beforeParen === -1 || afterParen === -1) return false;
|
if (beforeParen === -1 || afterParen === -1) return false;
|
||||||
|
|
||||||
// Set the selection to the text between the parenthesis
|
// Set the selection to the text between the parenthesis
|
||||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||||
const lastColon = parenContent.lastIndexOf(":");
|
const lastColon = parenContent.lastIndexOf(":");
|
||||||
selectionStart = beforeParen + 1;
|
selectionStart = beforeParen + 1;
|
||||||
selectionEnd = selectionStart + lastColon;
|
selectionEnd = selectionStart + lastColon;
|
||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
function selectCurrentWord(){
|
function selectCurrentWord() {
|
||||||
if (selectionStart !== selectionEnd) return false;
|
if (selectionStart !== selectionEnd) return false;
|
||||||
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
||||||
|
|
||||||
// seek backward until to find beggining
|
// seek backward until to find beggining
|
||||||
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
||||||
selectionStart--;
|
selectionStart--;
|
||||||
}
|
}
|
||||||
|
|
||||||
// seek forward to find end
|
// seek forward to find end
|
||||||
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
|
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
|
||||||
selectionEnd++;
|
selectionEnd++;
|
||||||
}
|
}
|
||||||
|
|
||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||||
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
||||||
selectCurrentWord();
|
selectCurrentWord();
|
||||||
}
|
}
|
||||||
|
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
var closeCharacter = ')'
|
var closeCharacter = ')';
|
||||||
var delta = opts.keyedit_precision_attention
|
var delta = opts.keyedit_precision_attention;
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
if (selectionStart > 0 && text[selectionStart - 1] == '<') {
|
||||||
closeCharacter = '>'
|
closeCharacter = '>';
|
||||||
delta = opts.keyedit_precision_extra
|
delta = opts.keyedit_precision_extra;
|
||||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
||||||
|
|
||||||
// do not include spaces at the end
|
// do not include spaces at the end
|
||||||
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
|
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
||||||
selectionEnd -= 1;
|
selectionEnd -= 1;
|
||||||
}
|
}
|
||||||
if(selectionStart == selectionEnd){
|
if (selectionStart == selectionEnd) {
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||||
|
|
||||||
selectionStart += 1;
|
selectionStart += 1;
|
||||||
selectionEnd += 1;
|
selectionEnd += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
weight = parseFloat(weight.toPrecision(12));
|
weight = parseFloat(weight.toPrecision(12));
|
||||||
if(String(weight).length == 1) weight += ".0"
|
if (String(weight).length == 1) weight += ".0";
|
||||||
|
|
||||||
if (closeCharacter == ')' && weight == 1) {
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
|
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
|
||||||
selectionStart--;
|
selectionStart--;
|
||||||
selectionEnd--;
|
selectionEnd--;
|
||||||
} else {
|
} else {
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
target.value = text;
|
target.value = text;
|
||||||
target.selectionStart = selectionStart;
|
target.selectionStart = selectionStart;
|
||||||
target.selectionEnd = selectionEnd;
|
target.selectionEnd = selectionEnd;
|
||||||
|
|
||||||
updateInput(target)
|
updateInput(target);
|
||||||
}
|
}
|
||||||
|
|
||||||
addEventListener('keydown', (event) => {
|
addEventListener('keydown', (event) => {
|
||||||
keyupEditAttention(event);
|
keyupEditAttention(event);
|
||||||
});
|
});
|
||||||
|
@ -1,71 +1,74 @@
|
|||||||
|
|
||||||
function extensions_apply(_disabled_list, _update_list, disable_all){
|
function extensions_apply(_disabled_list, _update_list, disable_all) {
|
||||||
var disable = []
|
var disable = [];
|
||||||
var update = []
|
var update = [];
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if (x.name.startsWith("enable_") && !x.checked) {
|
||||||
disable.push(x.name.substring(7))
|
disable.push(x.name.substring(7));
|
||||||
|
}
|
||||||
if(x.name.startsWith("update_") && x.checked)
|
|
||||||
update.push(x.name.substring(7))
|
if (x.name.startsWith("update_") && x.checked) {
|
||||||
})
|
update.push(x.name.substring(7));
|
||||||
|
}
|
||||||
restart_reload()
|
});
|
||||||
|
|
||||||
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
restart_reload();
|
||||||
}
|
|
||||||
|
return [JSON.stringify(disable), JSON.stringify(update), disable_all];
|
||||||
function extensions_check(){
|
}
|
||||||
var disable = []
|
|
||||||
|
function extensions_check() {
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
var disable = [];
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
|
||||||
disable.push(x.name.substring(7))
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x) {
|
||||||
})
|
if (x.name.startsWith("enable_") && !x.checked) {
|
||||||
|
disable.push(x.name.substring(7));
|
||||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
}
|
||||||
x.innerHTML = "Loading..."
|
});
|
||||||
})
|
|
||||||
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
|
||||||
|
x.innerHTML = "Loading...";
|
||||||
var id = randomId()
|
});
|
||||||
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function(){
|
|
||||||
|
|
||||||
})
|
var id = randomId();
|
||||||
|
requestProgress(id, gradioApp().getElementById('extensions_installed_top'), null, function() {
|
||||||
return [id, JSON.stringify(disable)]
|
|
||||||
}
|
});
|
||||||
|
|
||||||
function install_extension_from_index(button, url){
|
return [id, JSON.stringify(disable)];
|
||||||
button.disabled = "disabled"
|
}
|
||||||
button.value = "Installing..."
|
|
||||||
|
function install_extension_from_index(button, url) {
|
||||||
var textarea = gradioApp().querySelector('#extension_to_install textarea')
|
button.disabled = "disabled";
|
||||||
textarea.value = url
|
button.value = "Installing...";
|
||||||
updateInput(textarea)
|
|
||||||
|
var textarea = gradioApp().querySelector('#extension_to_install textarea');
|
||||||
gradioApp().querySelector('#install_extension_button').click()
|
textarea.value = url;
|
||||||
}
|
updateInput(textarea);
|
||||||
|
|
||||||
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
|
gradioApp().querySelector('#install_extension_button').click();
|
||||||
if (config_state_name == "Current") {
|
}
|
||||||
return [false, config_state_name, config_restore_type];
|
|
||||||
}
|
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
|
||||||
let restored = "";
|
if (config_state_name == "Current") {
|
||||||
if (config_restore_type == "extensions") {
|
return [false, config_state_name, config_restore_type];
|
||||||
restored = "all saved extension versions";
|
}
|
||||||
} else if (config_restore_type == "webui") {
|
let restored = "";
|
||||||
restored = "the webui version";
|
if (config_restore_type == "extensions") {
|
||||||
} else {
|
restored = "all saved extension versions";
|
||||||
restored = "the webui version and all saved extension versions";
|
} else if (config_restore_type == "webui") {
|
||||||
}
|
restored = "the webui version";
|
||||||
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
|
} else {
|
||||||
if (confirmed) {
|
restored = "the webui version and all saved extension versions";
|
||||||
restart_reload();
|
}
|
||||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
|
||||||
x.innerHTML = "Loading..."
|
if (confirmed) {
|
||||||
})
|
restart_reload();
|
||||||
}
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {
|
||||||
return [confirmed, config_state_name, config_restore_type];
|
x.innerHTML = "Loading...";
|
||||||
}
|
});
|
||||||
|
}
|
||||||
|
return [confirmed, config_state_name, config_restore_type];
|
||||||
|
}
|
||||||
|
@ -1,196 +1,215 @@
|
|||||||
function setupExtraNetworksForTab(tabname){
|
function setupExtraNetworksForTab(tabname) {
|
||||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
||||||
|
|
||||||
var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div')
|
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
||||||
var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea')
|
var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea');
|
||||||
var refresh = gradioApp().getElementById(tabname+'_extra_refresh')
|
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||||
|
|
||||||
search.classList.add('search')
|
search.classList.add('search');
|
||||||
tabs.appendChild(search)
|
tabs.appendChild(search);
|
||||||
tabs.appendChild(refresh)
|
tabs.appendChild(refresh);
|
||||||
|
|
||||||
var applyFilter = function(){
|
var applyFilter = function() {
|
||||||
var searchTerm = search.value.toLowerCase()
|
var searchTerm = search.value.toLowerCase();
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
||||||
var searchOnly = elem.querySelector('.search_only')
|
var searchOnly = elem.querySelector('.search_only');
|
||||||
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase();
|
||||||
|
|
||||||
var visible = text.indexOf(searchTerm) != -1
|
var visible = text.indexOf(searchTerm) != -1;
|
||||||
|
|
||||||
if(searchOnly && searchTerm.length < 4){
|
if (searchOnly && searchTerm.length < 4) {
|
||||||
visible = false
|
visible = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
elem.style.display = visible ? "" : "none"
|
elem.style.display = visible ? "" : "none";
|
||||||
})
|
});
|
||||||
}
|
};
|
||||||
|
|
||||||
search.addEventListener("input", applyFilter);
|
search.addEventListener("input", applyFilter);
|
||||||
applyFilter();
|
applyFilter();
|
||||||
|
|
||||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
}
|
}
|
||||||
|
|
||||||
function applyExtraNetworkFilter(tabname){
|
function applyExtraNetworkFilter(tabname) {
|
||||||
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
var extraNetworksApplyFilter = {}
|
var extraNetworksApplyFilter = {};
|
||||||
var activePromptTextarea = {};
|
var activePromptTextarea = {};
|
||||||
|
|
||||||
function setupExtraNetworks(){
|
function setupExtraNetworks() {
|
||||||
setupExtraNetworksForTab('txt2img')
|
setupExtraNetworksForTab('txt2img');
|
||||||
setupExtraNetworksForTab('img2img')
|
setupExtraNetworksForTab('img2img');
|
||||||
|
|
||||||
function registerPrompt(tabname, id){
|
function registerPrompt(tabname, id) {
|
||||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||||
|
|
||||||
if (! activePromptTextarea[tabname]){
|
if (!activePromptTextarea[tabname]) {
|
||||||
activePromptTextarea[tabname] = textarea
|
activePromptTextarea[tabname] = textarea;
|
||||||
}
|
}
|
||||||
|
|
||||||
textarea.addEventListener("focus", function(){
|
textarea.addEventListener("focus", function() {
|
||||||
activePromptTextarea[tabname] = textarea;
|
activePromptTextarea[tabname] = textarea;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
registerPrompt('txt2img', 'txt2img_prompt')
|
registerPrompt('txt2img', 'txt2img_prompt');
|
||||||
registerPrompt('txt2img', 'txt2img_neg_prompt')
|
registerPrompt('txt2img', 'txt2img_neg_prompt');
|
||||||
registerPrompt('img2img', 'img2img_prompt')
|
registerPrompt('img2img', 'img2img_prompt');
|
||||||
registerPrompt('img2img', 'img2img_neg_prompt')
|
registerPrompt('img2img', 'img2img_neg_prompt');
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiLoaded(setupExtraNetworks)
|
onUiLoaded(setupExtraNetworks);
|
||||||
|
|
||||||
var re_extranet = /<([^:]+:[^:]+):[\d\.]+>/;
|
var re_extranet = /<([^:]+:[^:]+):[\d.]+>/;
|
||||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
|
||||||
|
|
||||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||||
var m = text.match(re_extranet)
|
var m = text.match(re_extranet);
|
||||||
if(! m) return false
|
var replaced = false;
|
||||||
|
var newTextareaText;
|
||||||
var partToSearch = m[1]
|
if (m) {
|
||||||
var replaced = false
|
var partToSearch = m[1];
|
||||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
|
newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found) {
|
||||||
m = found.match(re_extranet);
|
m = found.match(re_extranet);
|
||||||
if(m[1] == partToSearch){
|
if (m[1] == partToSearch) {
|
||||||
replaced = true;
|
replaced = true;
|
||||||
return ""
|
return "";
|
||||||
}
|
}
|
||||||
return found;
|
return found;
|
||||||
})
|
});
|
||||||
|
} else {
|
||||||
if(replaced){
|
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
|
||||||
textarea.value = newTextareaText
|
if (found == text) {
|
||||||
return true;
|
replaced = true;
|
||||||
}
|
return "";
|
||||||
|
}
|
||||||
return false
|
return found;
|
||||||
}
|
});
|
||||||
|
}
|
||||||
function cardClicked(tabname, textToAdd, allowNegativePrompt){
|
|
||||||
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")
|
if (replaced) {
|
||||||
|
textarea.value = newTextareaText;
|
||||||
if(! tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)){
|
return true;
|
||||||
textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd
|
}
|
||||||
}
|
|
||||||
|
return false;
|
||||||
updateInput(textarea)
|
}
|
||||||
}
|
|
||||||
|
function cardClicked(tabname, textToAdd, allowNegativePrompt) {
|
||||||
function saveCardPreview(event, tabname, filename){
|
var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
|
||||||
var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea')
|
|
||||||
var button = gradioApp().getElementById(tabname + '_save_preview')
|
if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
|
||||||
|
textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
|
||||||
textarea.value = filename
|
}
|
||||||
updateInput(textarea)
|
|
||||||
|
updateInput(textarea);
|
||||||
button.click()
|
}
|
||||||
|
|
||||||
event.stopPropagation()
|
function saveCardPreview(event, tabname, filename) {
|
||||||
event.preventDefault()
|
var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea');
|
||||||
}
|
var button = gradioApp().getElementById(tabname + '_save_preview');
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event){
|
textarea.value = filename;
|
||||||
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
updateInput(textarea);
|
||||||
var button = event.target
|
|
||||||
var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
button.click();
|
||||||
|
|
||||||
searchTextarea.value = text
|
event.stopPropagation();
|
||||||
updateInput(searchTextarea)
|
event.preventDefault();
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalPopup = null;
|
function extraNetworksSearchButton(tabs_id, event) {
|
||||||
var globalPopupInner = null;
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea');
|
||||||
function popup(contents){
|
var button = event.target;
|
||||||
if(! globalPopup){
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim();
|
||||||
globalPopup = document.createElement('div')
|
|
||||||
globalPopup.onclick = function(){ globalPopup.style.display = "none"; };
|
searchTextarea.value = text;
|
||||||
globalPopup.classList.add('global-popup');
|
updateInput(searchTextarea);
|
||||||
|
}
|
||||||
var close = document.createElement('div')
|
|
||||||
close.classList.add('global-popup-close');
|
var globalPopup = null;
|
||||||
close.onclick = function(){ globalPopup.style.display = "none"; };
|
var globalPopupInner = null;
|
||||||
close.title = "Close";
|
function popup(contents) {
|
||||||
globalPopup.appendChild(close)
|
if (!globalPopup) {
|
||||||
|
globalPopup = document.createElement('div');
|
||||||
globalPopupInner = document.createElement('div')
|
globalPopup.onclick = function() {
|
||||||
globalPopupInner.onclick = function(event){ event.stopPropagation(); return false; };
|
globalPopup.style.display = "none";
|
||||||
globalPopupInner.classList.add('global-popup-inner');
|
};
|
||||||
globalPopup.appendChild(globalPopupInner)
|
globalPopup.classList.add('global-popup');
|
||||||
|
|
||||||
gradioApp().appendChild(globalPopup);
|
var close = document.createElement('div');
|
||||||
}
|
close.classList.add('global-popup-close');
|
||||||
|
close.onclick = function() {
|
||||||
globalPopupInner.innerHTML = '';
|
globalPopup.style.display = "none";
|
||||||
globalPopupInner.appendChild(contents);
|
};
|
||||||
|
close.title = "Close";
|
||||||
globalPopup.style.display = "flex";
|
globalPopup.appendChild(close);
|
||||||
}
|
|
||||||
|
globalPopupInner = document.createElement('div');
|
||||||
function extraNetworksShowMetadata(text){
|
globalPopupInner.onclick = function(event) {
|
||||||
var elem = document.createElement('pre')
|
event.stopPropagation(); return false;
|
||||||
elem.classList.add('popup-metadata');
|
};
|
||||||
elem.textContent = text;
|
globalPopupInner.classList.add('global-popup-inner');
|
||||||
|
globalPopup.appendChild(globalPopupInner);
|
||||||
popup(elem);
|
|
||||||
}
|
gradioApp().appendChild(globalPopup);
|
||||||
|
}
|
||||||
function requestGet(url, data, handler, errorHandler){
|
|
||||||
var xhr = new XMLHttpRequest();
|
globalPopupInner.innerHTML = '';
|
||||||
var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
|
globalPopupInner.appendChild(contents);
|
||||||
xhr.open("GET", url + "?" + args, true);
|
|
||||||
|
globalPopup.style.display = "flex";
|
||||||
xhr.onreadystatechange = function () {
|
}
|
||||||
if (xhr.readyState === 4) {
|
|
||||||
if (xhr.status === 200) {
|
function extraNetworksShowMetadata(text) {
|
||||||
try {
|
var elem = document.createElement('pre');
|
||||||
var js = JSON.parse(xhr.responseText);
|
elem.classList.add('popup-metadata');
|
||||||
handler(js)
|
elem.textContent = text;
|
||||||
} catch (error) {
|
|
||||||
console.error(error);
|
popup(elem);
|
||||||
errorHandler()
|
}
|
||||||
}
|
|
||||||
} else{
|
function requestGet(url, data, handler, errorHandler) {
|
||||||
errorHandler()
|
var xhr = new XMLHttpRequest();
|
||||||
}
|
var args = Object.keys(data).map(function(k) {
|
||||||
}
|
return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]);
|
||||||
};
|
}).join('&');
|
||||||
var js = JSON.stringify(data);
|
xhr.open("GET", url + "?" + args, true);
|
||||||
xhr.send(js);
|
|
||||||
}
|
xhr.onreadystatechange = function() {
|
||||||
|
if (xhr.readyState === 4) {
|
||||||
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
if (xhr.status === 200) {
|
||||||
var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
try {
|
||||||
|
var js = JSON.parse(xhr.responseText);
|
||||||
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
handler(js);
|
||||||
if(data && data.metadata){
|
} catch (error) {
|
||||||
extraNetworksShowMetadata(data.metadata)
|
console.error(error);
|
||||||
} else{
|
errorHandler();
|
||||||
showError()
|
}
|
||||||
}
|
} else {
|
||||||
}, showError)
|
errorHandler();
|
||||||
|
}
|
||||||
event.stopPropagation()
|
}
|
||||||
}
|
};
|
||||||
|
var js = JSON.stringify(data);
|
||||||
|
xhr.send(js);
|
||||||
|
}
|
||||||
|
|
||||||
|
function extraNetworksRequestMetadata(event, extraPage, cardName) {
|
||||||
|
var showError = function() {
|
||||||
|
extraNetworksShowMetadata("there was an error getting metadata");
|
||||||
|
};
|
||||||
|
|
||||||
|
requestGet("./sd_extra_networks/metadata", {page: extraPage, item: cardName}, function(data) {
|
||||||
|
if (data && data.metadata) {
|
||||||
|
extraNetworksShowMetadata(data.metadata);
|
||||||
|
} else {
|
||||||
|
showError();
|
||||||
|
}
|
||||||
|
}, showError);
|
||||||
|
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
@ -1,33 +1,35 @@
|
|||||||
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
|
||||||
|
|
||||||
let txt2img_gallery, img2img_gallery, modal = undefined;
|
let txt2img_gallery, img2img_gallery, modal = undefined;
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function() {
|
||||||
if (!txt2img_gallery) {
|
if (!txt2img_gallery) {
|
||||||
txt2img_gallery = attachGalleryListeners("txt2img")
|
txt2img_gallery = attachGalleryListeners("txt2img");
|
||||||
}
|
}
|
||||||
if (!img2img_gallery) {
|
if (!img2img_gallery) {
|
||||||
img2img_gallery = attachGalleryListeners("img2img")
|
img2img_gallery = attachGalleryListeners("img2img");
|
||||||
}
|
}
|
||||||
if (!modal) {
|
if (!modal) {
|
||||||
modal = gradioApp().getElementById('lightboxModal')
|
modal = gradioApp().getElementById('lightboxModal');
|
||||||
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
|
modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
let modalObserver = new MutationObserver(function(mutations) {
|
let modalObserver = new MutationObserver(function(mutations) {
|
||||||
mutations.forEach(function(mutationRecord) {
|
mutations.forEach(function(mutationRecord) {
|
||||||
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
|
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;
|
||||||
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
|
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {
|
||||||
gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
|
gradioApp().getElementById(selectedTab + "_generation_info_button")?.click();
|
||||||
});
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
function attachGalleryListeners(tab_name) {
|
function attachGalleryListeners(tab_name) {
|
||||||
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');
|
||||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click());
|
||||||
gallery?.addEventListener('keydown', (e) => {
|
gallery?.addEventListener('keydown', (e) => {
|
||||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow
|
||||||
gradioApp().getElementById(tab_name+"_generation_info_button").click()
|
gradioApp().getElementById(tab_name + "_generation_info_button").click();
|
||||||
});
|
}
|
||||||
return gallery;
|
});
|
||||||
|
return gallery;
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
// mouseover tooltips for various UI elements
|
// mouseover tooltips for various UI elements
|
||||||
|
|
||||||
titles = {
|
var titles = {
|
||||||
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
|
"Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
|
||||||
"Sampling method": "Which algorithm to use to produce the image",
|
"Sampling method": "Which algorithm to use to produce the image",
|
||||||
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
|
||||||
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
|
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help",
|
||||||
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
|
||||||
"UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
|
"UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models",
|
||||||
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
|
||||||
|
|
||||||
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
|
"\u{1F4D0}": "Auto detect size from img2img",
|
||||||
"Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
|
"Batch count": "How many batches of images to create (has no impact on generation performance or VRAM usage)",
|
||||||
|
"Batch size": "How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)",
|
||||||
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
"CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results",
|
||||||
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
|
||||||
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
|
||||||
@ -40,7 +41,7 @@ titles = {
|
|||||||
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
|
||||||
|
|
||||||
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
|
||||||
|
|
||||||
"Skip": "Stop processing current image and continue processing.",
|
"Skip": "Stop processing current image and continue processing.",
|
||||||
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
"Interrupt": "Stop processing images and return any results accumulated so far.",
|
||||||
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
|
||||||
@ -66,8 +67,8 @@ titles = {
|
|||||||
|
|
||||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||||
|
|
||||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
"Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
|
||||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
"Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
|
||||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||||
|
|
||||||
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
||||||
@ -96,7 +97,7 @@ titles = {
|
|||||||
"Add difference": "Result = A + (B - C) * M",
|
"Add difference": "Result = A + (B - C) * M",
|
||||||
"No interpolation": "Result = A",
|
"No interpolation": "Result = A",
|
||||||
|
|
||||||
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
|
||||||
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
"Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
|
||||||
|
|
||||||
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
|
||||||
@ -113,38 +114,55 @@ titles = {
|
|||||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
|
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
|
||||||
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
||||||
|
};
|
||||||
|
|
||||||
|
function updateTooltipForSpan(span) {
|
||||||
|
if (span.title) return; // already has a title
|
||||||
|
|
||||||
|
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
||||||
|
|
||||||
|
if (!tooltip) {
|
||||||
|
tooltip = localization[titles[span.value]] || titles[span.value];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tooltip) {
|
||||||
|
for (const c of span.classList) {
|
||||||
|
if (c in titles) {
|
||||||
|
tooltip = localization[titles[c]] || titles[c];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tooltip) {
|
||||||
|
span.title = tooltip;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function updateTooltipForSelect(select) {
|
||||||
|
if (select.onchange != null) return;
|
||||||
|
|
||||||
onUiUpdate(function(){
|
select.onchange = function() {
|
||||||
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
select.title = localization[titles[select.value]] || titles[select.value] || "";
|
||||||
if (span.title) return; // already has a title
|
};
|
||||||
|
}
|
||||||
|
|
||||||
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1};
|
||||||
|
|
||||||
if(!tooltip){
|
onUiUpdate(function(m) {
|
||||||
tooltip = localization[titles[span.value]] || titles[span.value];
|
m.forEach(function(record) {
|
||||||
}
|
record.addedNodes.forEach(function(node) {
|
||||||
|
if (observedTooltipElements[node.tagName]) {
|
||||||
|
updateTooltipForSpan(node);
|
||||||
|
}
|
||||||
|
if (node.tagName == "SELECT") {
|
||||||
|
updateTooltipForSelect(node);
|
||||||
|
}
|
||||||
|
|
||||||
if(!tooltip){
|
if (node.querySelectorAll) {
|
||||||
for (const c of span.classList) {
|
node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan);
|
||||||
if (c in titles) {
|
node.querySelectorAll('select').forEach(updateTooltipForSelect);
|
||||||
tooltip = localization[titles[c]] || titles[c];
|
}
|
||||||
break;
|
});
|
||||||
}
|
});
|
||||||
}
|
});
|
||||||
}
|
|
||||||
|
|
||||||
if(tooltip){
|
|
||||||
span.title = tooltip;
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
gradioApp().querySelectorAll('select').forEach(function(select){
|
|
||||||
if (select.onchange != null) return;
|
|
||||||
|
|
||||||
select.onchange = function(){
|
|
||||||
select.title = localization[titles[select.value]] || titles[select.value] || "";
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
@ -1,18 +1,18 @@
|
|||||||
|
|
||||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {
|
||||||
function setInactive(elem, inactive){
|
function setInactive(elem, inactive) {
|
||||||
elem.classList.toggle('inactive', !!inactive)
|
elem.classList.toggle('inactive', !!inactive);
|
||||||
}
|
}
|
||||||
|
|
||||||
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');
|
||||||
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');
|
||||||
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');
|
||||||
|
|
||||||
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "";
|
||||||
|
|
||||||
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0)
|
setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);
|
||||||
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0)
|
setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);
|
||||||
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0)
|
setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);
|
||||||
|
|
||||||
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];
|
||||||
}
|
}
|
||||||
|
@ -4,17 +4,16 @@
|
|||||||
*/
|
*/
|
||||||
function imageMaskResize() {
|
function imageMaskResize() {
|
||||||
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
||||||
if ( ! canvases.length ) {
|
if (!canvases.length) {
|
||||||
canvases_fixed = false; // TODO: this is unused..?
|
window.removeEventListener('resize', imageMaskResize);
|
||||||
window.removeEventListener( 'resize', imageMaskResize );
|
return;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const wrapper = canvases[0].closest('.touch-none');
|
const wrapper = canvases[0].closest('.touch-none');
|
||||||
const previewImage = wrapper.previousElementSibling;
|
const previewImage = wrapper.previousElementSibling;
|
||||||
|
|
||||||
if ( ! previewImage.complete ) {
|
if (!previewImage.complete) {
|
||||||
previewImage.addEventListener( 'load', imageMaskResize);
|
previewImage.addEventListener('load', imageMaskResize);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,15 +23,15 @@ function imageMaskResize() {
|
|||||||
const nh = previewImage.naturalHeight;
|
const nh = previewImage.naturalHeight;
|
||||||
const portrait = nh > nw;
|
const portrait = nh > nw;
|
||||||
|
|
||||||
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);
|
||||||
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);
|
||||||
|
|
||||||
wrapper.style.width = `${wW}px`;
|
wrapper.style.width = `${wW}px`;
|
||||||
wrapper.style.height = `${wH}px`;
|
wrapper.style.height = `${wH}px`;
|
||||||
wrapper.style.left = `0px`;
|
wrapper.style.left = `0px`;
|
||||||
wrapper.style.top = `0px`;
|
wrapper.style.top = `0px`;
|
||||||
|
|
||||||
canvases.forEach( c => {
|
canvases.forEach(c => {
|
||||||
c.style.width = c.style.height = '';
|
c.style.width = c.style.height = '';
|
||||||
c.style.maxWidth = '100%';
|
c.style.maxWidth = '100%';
|
||||||
c.style.maxHeight = '100%';
|
c.style.maxHeight = '100%';
|
||||||
@ -41,4 +40,4 @@ function imageMaskResize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(imageMaskResize);
|
onUiUpdate(imageMaskResize);
|
||||||
window.addEventListener( 'resize', imageMaskResize);
|
window.addEventListener('resize', imageMaskResize);
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
window.onload = (function(){
|
window.onload = (function() {
|
||||||
window.addEventListener('drop', e => {
|
window.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||||
@ -10,7 +10,7 @@ window.onload = (function(){
|
|||||||
const imgParent = gradioApp().getElementById(prompt_target);
|
const imgParent = gradioApp().getElementById(prompt_target);
|
||||||
const files = e.dataTransfer.files;
|
const files = e.dataTransfer.files;
|
||||||
const fileInput = imgParent.querySelector('input[type="file"]');
|
const fileInput = imgParent.querySelector('input[type="file"]');
|
||||||
if ( fileInput ) {
|
if (fileInput) {
|
||||||
fileInput.files = files;
|
fileInput.files = files;
|
||||||
fileInput.dispatchEvent(new Event('change'));
|
fileInput.dispatchEvent(new Event('change'));
|
||||||
}
|
}
|
||||||
|
@ -5,24 +5,24 @@ function closeModal() {
|
|||||||
|
|
||||||
function showModal(event) {
|
function showModal(event) {
|
||||||
const source = event.target || event.srcElement;
|
const source = event.target || event.srcElement;
|
||||||
const modalImage = gradioApp().getElementById("modalImage")
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const lb = gradioApp().getElementById("lightboxModal")
|
const lb = gradioApp().getElementById("lightboxModal");
|
||||||
modalImage.src = source.src
|
modalImage.src = source.src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
lb.style.setProperty('background-image', 'url(' + source.src + ')');
|
||||||
}
|
}
|
||||||
lb.style.display = "flex";
|
lb.style.display = "flex";
|
||||||
lb.focus()
|
lb.focus();
|
||||||
|
|
||||||
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
|
||||||
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
const tabImg2Img = gradioApp().getElementById("tab_img2img");
|
||||||
// show the save button in modal only on txt2img or img2img tabs
|
// show the save button in modal only on txt2img or img2img tabs
|
||||||
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
|
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
|
||||||
gradioApp().getElementById("modal_save").style.display = "inline"
|
gradioApp().getElementById("modal_save").style.display = "inline";
|
||||||
} else {
|
} else {
|
||||||
gradioApp().getElementById("modal_save").style.display = "none"
|
gradioApp().getElementById("modal_save").style.display = "none";
|
||||||
}
|
}
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function negmod(n, m) {
|
function negmod(n, m) {
|
||||||
@ -30,14 +30,15 @@ function negmod(n, m) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function updateOnBackgroundChange() {
|
function updateOnBackgroundChange() {
|
||||||
const modalImage = gradioApp().getElementById("modalImage")
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
if (modalImage && modalImage.offsetParent) {
|
if (modalImage && modalImage.offsetParent) {
|
||||||
let currentButton = selected_gallery_button();
|
let currentButton = selected_gallery_button();
|
||||||
|
|
||||||
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||||
modalImage.src = currentButton.children[0].src;
|
modalImage.src = currentButton.children[0].src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
|
modal.style.setProperty('background-image', `url(${modalImage.src})`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -49,108 +50,109 @@ function modalImageSwitch(offset) {
|
|||||||
if (galleryButtons.length > 1) {
|
if (galleryButtons.length > 1) {
|
||||||
var currentButton = selected_gallery_button();
|
var currentButton = selected_gallery_button();
|
||||||
|
|
||||||
var result = -1
|
var result = -1;
|
||||||
galleryButtons.forEach(function(v, i) {
|
galleryButtons.forEach(function(v, i) {
|
||||||
if (v == currentButton) {
|
if (v == currentButton) {
|
||||||
result = i
|
result = i;
|
||||||
}
|
}
|
||||||
})
|
});
|
||||||
|
|
||||||
if (result != -1) {
|
if (result != -1) {
|
||||||
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)];
|
||||||
nextButton.click()
|
nextButton.click();
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
modalImage.src = nextButton.children[0].src;
|
modalImage.src = nextButton.children[0].src;
|
||||||
if (modalImage.style.display === 'none') {
|
if (modalImage.style.display === 'none') {
|
||||||
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
modal.style.setProperty('background-image', `url(${modalImage.src})`);
|
||||||
}
|
}
|
||||||
setTimeout(function() {
|
setTimeout(function() {
|
||||||
modal.focus()
|
modal.focus();
|
||||||
}, 10)
|
}, 10);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function saveImage(){
|
function saveImage() {
|
||||||
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
|
const tabTxt2Img = gradioApp().getElementById("tab_txt2img");
|
||||||
const tabImg2Img = gradioApp().getElementById("tab_img2img")
|
const tabImg2Img = gradioApp().getElementById("tab_img2img");
|
||||||
const saveTxt2Img = "save_txt2img"
|
const saveTxt2Img = "save_txt2img";
|
||||||
const saveImg2Img = "save_img2img"
|
const saveImg2Img = "save_img2img";
|
||||||
if (tabTxt2Img.style.display != "none") {
|
if (tabTxt2Img.style.display != "none") {
|
||||||
gradioApp().getElementById(saveTxt2Img).click()
|
gradioApp().getElementById(saveTxt2Img).click();
|
||||||
} else if (tabImg2Img.style.display != "none") {
|
} else if (tabImg2Img.style.display != "none") {
|
||||||
gradioApp().getElementById(saveImg2Img).click()
|
gradioApp().getElementById(saveImg2Img).click();
|
||||||
} else {
|
} else {
|
||||||
console.error("missing implementation for saving modal of this type")
|
console.error("missing implementation for saving modal of this type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalSaveImage(event) {
|
function modalSaveImage(event) {
|
||||||
saveImage()
|
saveImage();
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalNextImage(event) {
|
function modalNextImage(event) {
|
||||||
modalImageSwitch(1)
|
modalImageSwitch(1);
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalPrevImage(event) {
|
function modalPrevImage(event) {
|
||||||
modalImageSwitch(-1)
|
modalImageSwitch(-1);
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalKeyHandler(event) {
|
function modalKeyHandler(event) {
|
||||||
switch (event.key) {
|
switch (event.key) {
|
||||||
case "s":
|
case "s":
|
||||||
saveImage()
|
saveImage();
|
||||||
break;
|
break;
|
||||||
case "ArrowLeft":
|
case "ArrowLeft":
|
||||||
modalPrevImage(event)
|
modalPrevImage(event);
|
||||||
break;
|
break;
|
||||||
case "ArrowRight":
|
case "ArrowRight":
|
||||||
modalNextImage(event)
|
modalNextImage(event);
|
||||||
break;
|
break;
|
||||||
case "Escape":
|
case "Escape":
|
||||||
closeModal();
|
closeModal();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function setupImageForLightbox(e) {
|
function setupImageForLightbox(e) {
|
||||||
if (e.dataset.modded)
|
if (e.dataset.modded) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
e.dataset.modded = true;
|
e.dataset.modded = true;
|
||||||
e.style.cursor='pointer'
|
e.style.cursor = 'pointer';
|
||||||
e.style.userSelect='none'
|
e.style.userSelect = 'none';
|
||||||
|
|
||||||
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
|
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1;
|
||||||
|
|
||||||
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
// For Firefox, listening on click first switched to next image then shows the lightbox.
|
||||||
// If you know how to fix this without switching to mousedown event, please.
|
// If you know how to fix this without switching to mousedown event, please.
|
||||||
// For other browsers the event is click to make it possiblr to drag picture.
|
// For other browsers the event is click to make it possiblr to drag picture.
|
||||||
var event = isFirefox ? 'mousedown' : 'click'
|
var event = isFirefox ? 'mousedown' : 'click';
|
||||||
|
|
||||||
e.addEventListener(event, function (evt) {
|
e.addEventListener(event, function(evt) {
|
||||||
if(!opts.js_modal_lightbox || evt.button != 0) return;
|
if (!opts.js_modal_lightbox || evt.button != 0) return;
|
||||||
|
|
||||||
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
|
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);
|
||||||
evt.preventDefault()
|
evt.preventDefault();
|
||||||
showModal(evt)
|
showModal(evt);
|
||||||
}, true);
|
}, true);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomSet(modalImage, enable) {
|
function modalZoomSet(modalImage, enable) {
|
||||||
if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
|
if (modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomToggle(event) {
|
function modalZoomToggle(event) {
|
||||||
var modalImage = gradioApp().getElementById("modalImage");
|
var modalImage = gradioApp().getElementById("modalImage");
|
||||||
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'));
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalTileImageToggle(event) {
|
function modalTileImageToggle(event) {
|
||||||
@ -159,99 +161,93 @@ function modalTileImageToggle(event) {
|
|||||||
const isTiling = modalImage.style.display === 'none';
|
const isTiling = modalImage.style.display === 'none';
|
||||||
if (isTiling) {
|
if (isTiling) {
|
||||||
modalImage.style.display = 'block';
|
modalImage.style.display = 'block';
|
||||||
modal.style.setProperty('background-image', 'none')
|
modal.style.setProperty('background-image', 'none');
|
||||||
} else {
|
} else {
|
||||||
modalImage.style.display = 'none';
|
modalImage.style.display = 'none';
|
||||||
modal.style.setProperty('background-image', `url(${modalImage.src})`)
|
modal.style.setProperty('background-image', `url(${modalImage.src})`);
|
||||||
}
|
}
|
||||||
|
|
||||||
event.stopPropagation()
|
event.stopPropagation();
|
||||||
}
|
|
||||||
|
|
||||||
function galleryImageHandler(e) {
|
|
||||||
//if (e && e.parentElement.tagName == 'BUTTON') {
|
|
||||||
e.onclick = showGalleryImage;
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onUiUpdate(function() {
|
||||||
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');
|
||||||
if (fullImg_preview != null) {
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(setupImageForLightbox);
|
fullImg_preview.forEach(setupImageForLightbox);
|
||||||
}
|
}
|
||||||
updateOnBackgroundChange();
|
updateOnBackgroundChange();
|
||||||
})
|
});
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
//const modalFragment = document.createDocumentFragment();
|
//const modalFragment = document.createDocumentFragment();
|
||||||
const modal = document.createElement('div')
|
const modal = document.createElement('div');
|
||||||
modal.onclick = closeModal;
|
modal.onclick = closeModal;
|
||||||
modal.id = "lightboxModal";
|
modal.id = "lightboxModal";
|
||||||
modal.tabIndex = 0
|
modal.tabIndex = 0;
|
||||||
modal.addEventListener('keydown', modalKeyHandler, true)
|
modal.addEventListener('keydown', modalKeyHandler, true);
|
||||||
|
|
||||||
const modalControls = document.createElement('div')
|
const modalControls = document.createElement('div');
|
||||||
modalControls.className = 'modalControls gradio-container';
|
modalControls.className = 'modalControls gradio-container';
|
||||||
modal.append(modalControls);
|
modal.append(modalControls);
|
||||||
|
|
||||||
const modalZoom = document.createElement('span')
|
const modalZoom = document.createElement('span');
|
||||||
modalZoom.className = 'modalZoom cursor';
|
modalZoom.className = 'modalZoom cursor';
|
||||||
modalZoom.innerHTML = '⤡'
|
modalZoom.innerHTML = '⤡';
|
||||||
modalZoom.addEventListener('click', modalZoomToggle, true)
|
modalZoom.addEventListener('click', modalZoomToggle, true);
|
||||||
modalZoom.title = "Toggle zoomed view";
|
modalZoom.title = "Toggle zoomed view";
|
||||||
modalControls.appendChild(modalZoom)
|
modalControls.appendChild(modalZoom);
|
||||||
|
|
||||||
const modalTileImage = document.createElement('span')
|
const modalTileImage = document.createElement('span');
|
||||||
modalTileImage.className = 'modalTileImage cursor';
|
modalTileImage.className = 'modalTileImage cursor';
|
||||||
modalTileImage.innerHTML = '⊞'
|
modalTileImage.innerHTML = '⊞';
|
||||||
modalTileImage.addEventListener('click', modalTileImageToggle, true)
|
modalTileImage.addEventListener('click', modalTileImageToggle, true);
|
||||||
modalTileImage.title = "Preview tiling";
|
modalTileImage.title = "Preview tiling";
|
||||||
modalControls.appendChild(modalTileImage)
|
modalControls.appendChild(modalTileImage);
|
||||||
|
|
||||||
const modalSave = document.createElement("span")
|
const modalSave = document.createElement("span");
|
||||||
modalSave.className = "modalSave cursor"
|
modalSave.className = "modalSave cursor";
|
||||||
modalSave.id = "modal_save"
|
modalSave.id = "modal_save";
|
||||||
modalSave.innerHTML = "🖫"
|
modalSave.innerHTML = "🖫";
|
||||||
modalSave.addEventListener("click", modalSaveImage, true)
|
modalSave.addEventListener("click", modalSaveImage, true);
|
||||||
modalSave.title = "Save Image(s)"
|
modalSave.title = "Save Image(s)";
|
||||||
modalControls.appendChild(modalSave)
|
modalControls.appendChild(modalSave);
|
||||||
|
|
||||||
const modalClose = document.createElement('span')
|
const modalClose = document.createElement('span');
|
||||||
modalClose.className = 'modalClose cursor';
|
modalClose.className = 'modalClose cursor';
|
||||||
modalClose.innerHTML = '×'
|
modalClose.innerHTML = '×';
|
||||||
modalClose.onclick = closeModal;
|
modalClose.onclick = closeModal;
|
||||||
modalClose.title = "Close image viewer";
|
modalClose.title = "Close image viewer";
|
||||||
modalControls.appendChild(modalClose)
|
modalControls.appendChild(modalClose);
|
||||||
|
|
||||||
const modalImage = document.createElement('img')
|
const modalImage = document.createElement('img');
|
||||||
modalImage.id = 'modalImage';
|
modalImage.id = 'modalImage';
|
||||||
modalImage.onclick = closeModal;
|
modalImage.onclick = closeModal;
|
||||||
modalImage.tabIndex = 0
|
modalImage.tabIndex = 0;
|
||||||
modalImage.addEventListener('keydown', modalKeyHandler, true)
|
modalImage.addEventListener('keydown', modalKeyHandler, true);
|
||||||
modal.appendChild(modalImage)
|
modal.appendChild(modalImage);
|
||||||
|
|
||||||
const modalPrev = document.createElement('a')
|
const modalPrev = document.createElement('a');
|
||||||
modalPrev.className = 'modalPrev';
|
modalPrev.className = 'modalPrev';
|
||||||
modalPrev.innerHTML = '❮'
|
modalPrev.innerHTML = '❮';
|
||||||
modalPrev.tabIndex = 0
|
modalPrev.tabIndex = 0;
|
||||||
modalPrev.addEventListener('click', modalPrevImage, true);
|
modalPrev.addEventListener('click', modalPrevImage, true);
|
||||||
modalPrev.addEventListener('keydown', modalKeyHandler, true)
|
modalPrev.addEventListener('keydown', modalKeyHandler, true);
|
||||||
modal.appendChild(modalPrev)
|
modal.appendChild(modalPrev);
|
||||||
|
|
||||||
const modalNext = document.createElement('a')
|
const modalNext = document.createElement('a');
|
||||||
modalNext.className = 'modalNext';
|
modalNext.className = 'modalNext';
|
||||||
modalNext.innerHTML = '❯'
|
modalNext.innerHTML = '❯';
|
||||||
modalNext.tabIndex = 0
|
modalNext.tabIndex = 0;
|
||||||
modalNext.addEventListener('click', modalNextImage, true);
|
modalNext.addEventListener('click', modalNextImage, true);
|
||||||
modalNext.addEventListener('keydown', modalKeyHandler, true)
|
modalNext.addEventListener('keydown', modalKeyHandler, true);
|
||||||
|
|
||||||
modal.appendChild(modalNext)
|
modal.appendChild(modalNext);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
gradioApp().appendChild(modal);
|
gradioApp().appendChild(modal);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
gradioApp().body.appendChild(modal);
|
gradioApp().body.appendChild(modal);
|
||||||
}
|
}
|
||||||
|
|
||||||
document.body.appendChild(modal);
|
document.body.appendChild(modal);
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
window.addEventListener('gamepadconnected', (e) => {
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
const index = e.gamepad.index;
|
const index = e.gamepad.index;
|
||||||
let isWaiting = false;
|
let isWaiting = false;
|
||||||
setInterval(async () => {
|
setInterval(async() => {
|
||||||
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
const gamepad = navigator.getGamepads()[index];
|
const gamepad = navigator.getGamepads()[index];
|
||||||
const xValue = gamepad.axes[0];
|
const xValue = gamepad.axes[0];
|
||||||
@ -14,7 +14,7 @@ window.addEventListener('gamepadconnected', (e) => {
|
|||||||
}
|
}
|
||||||
if (isWaiting) {
|
if (isWaiting) {
|
||||||
await sleepUntil(() => {
|
await sleepUntil(() => {
|
||||||
const xValue = navigator.getGamepads()[index].axes[0]
|
const xValue = navigator.getGamepads()[index].axes[0];
|
||||||
if (xValue < 0.3 && xValue > -0.3) {
|
if (xValue < 0.3 && xValue > -0.3) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1,177 +1,176 @@
|
|||||||
|
|
||||||
// localization = {} -- the dict with translations is created by the backend
|
// localization = {} -- the dict with translations is created by the backend
|
||||||
|
|
||||||
ignore_ids_for_localization={
|
var ignore_ids_for_localization = {
|
||||||
setting_sd_hypernetwork: 'OPTION',
|
setting_sd_hypernetwork: 'OPTION',
|
||||||
setting_sd_model_checkpoint: 'OPTION',
|
setting_sd_model_checkpoint: 'OPTION',
|
||||||
setting_realesrgan_enabled_models: 'OPTION',
|
modelmerger_primary_model_name: 'OPTION',
|
||||||
modelmerger_primary_model_name: 'OPTION',
|
modelmerger_secondary_model_name: 'OPTION',
|
||||||
modelmerger_secondary_model_name: 'OPTION',
|
modelmerger_tertiary_model_name: 'OPTION',
|
||||||
modelmerger_tertiary_model_name: 'OPTION',
|
train_embedding: 'OPTION',
|
||||||
train_embedding: 'OPTION',
|
train_hypernetwork: 'OPTION',
|
||||||
train_hypernetwork: 'OPTION',
|
txt2img_styles: 'OPTION',
|
||||||
txt2img_styles: 'OPTION',
|
img2img_styles: 'OPTION',
|
||||||
img2img_styles: 'OPTION',
|
setting_random_artist_categories: 'SPAN',
|
||||||
setting_random_artist_categories: 'SPAN',
|
setting_face_restoration_model: 'SPAN',
|
||||||
setting_face_restoration_model: 'SPAN',
|
setting_realesrgan_enabled_models: 'SPAN',
|
||||||
setting_realesrgan_enabled_models: 'SPAN',
|
extras_upscaler_1: 'SPAN',
|
||||||
extras_upscaler_1: 'SPAN',
|
extras_upscaler_2: 'SPAN',
|
||||||
extras_upscaler_2: 'SPAN',
|
};
|
||||||
}
|
|
||||||
|
var re_num = /^[.\d]+$/;
|
||||||
re_num = /^[\.\d]+$/
|
var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;
|
||||||
re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
|
||||||
|
var original_lines = {};
|
||||||
original_lines = {}
|
var translated_lines = {};
|
||||||
translated_lines = {}
|
|
||||||
|
function hasLocalization() {
|
||||||
function hasLocalization() {
|
return window.localization && Object.keys(window.localization).length > 0;
|
||||||
return window.localization && Object.keys(window.localization).length > 0;
|
}
|
||||||
}
|
|
||||||
|
function textNodesUnder(el) {
|
||||||
function textNodesUnder(el){
|
var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);
|
||||||
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
while ((n = walk.nextNode())) a.push(n);
|
||||||
while(n=walk.nextNode()) a.push(n);
|
return a;
|
||||||
return a;
|
}
|
||||||
}
|
|
||||||
|
function canBeTranslated(node, text) {
|
||||||
function canBeTranslated(node, text){
|
if (!text) return false;
|
||||||
if(! text) return false;
|
if (!node.parentElement) return false;
|
||||||
if(! node.parentElement) return false;
|
|
||||||
|
var parentType = node.parentElement.nodeName;
|
||||||
var parentType = node.parentElement.nodeName
|
if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;
|
||||||
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
|
||||||
|
if (parentType == 'OPTION' || parentType == 'SPAN') {
|
||||||
if (parentType=='OPTION' || parentType=='SPAN'){
|
var pnode = node;
|
||||||
var pnode = node
|
for (var level = 0; level < 4; level++) {
|
||||||
for(var level=0; level<4; level++){
|
pnode = pnode.parentElement;
|
||||||
pnode = pnode.parentElement
|
if (!pnode) break;
|
||||||
if(! pnode) break;
|
|
||||||
|
if (ignore_ids_for_localization[pnode.id] == parentType) return false;
|
||||||
if(ignore_ids_for_localization[pnode.id] == parentType) return false;
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
if (re_num.test(text)) return false;
|
||||||
if(re_num.test(text)) return false;
|
if (re_emoji.test(text)) return false;
|
||||||
if(re_emoji.test(text)) return false;
|
return true;
|
||||||
return true
|
}
|
||||||
}
|
|
||||||
|
function getTranslation(text) {
|
||||||
function getTranslation(text){
|
if (!text) return undefined;
|
||||||
if(! text) return undefined
|
|
||||||
|
if (translated_lines[text] === undefined) {
|
||||||
if(translated_lines[text] === undefined){
|
original_lines[text] = 1;
|
||||||
original_lines[text] = 1
|
}
|
||||||
}
|
|
||||||
|
var tl = localization[text];
|
||||||
tl = localization[text]
|
if (tl !== undefined) {
|
||||||
if(tl !== undefined){
|
translated_lines[tl] = 1;
|
||||||
translated_lines[tl] = 1
|
}
|
||||||
}
|
|
||||||
|
return tl;
|
||||||
return tl
|
}
|
||||||
}
|
|
||||||
|
function processTextNode(node) {
|
||||||
function processTextNode(node){
|
var text = node.textContent.trim();
|
||||||
var text = node.textContent.trim()
|
|
||||||
|
if (!canBeTranslated(node, text)) return;
|
||||||
if(! canBeTranslated(node, text)) return
|
|
||||||
|
var tl = getTranslation(text);
|
||||||
tl = getTranslation(text)
|
if (tl !== undefined) {
|
||||||
if(tl !== undefined){
|
node.textContent = tl;
|
||||||
node.textContent = tl
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
function processNode(node) {
|
||||||
function processNode(node){
|
if (node.nodeType == 3) {
|
||||||
if(node.nodeType == 3){
|
processTextNode(node);
|
||||||
processTextNode(node)
|
return;
|
||||||
return
|
}
|
||||||
}
|
|
||||||
|
if (node.title) {
|
||||||
if(node.title){
|
let tl = getTranslation(node.title);
|
||||||
tl = getTranslation(node.title)
|
if (tl !== undefined) {
|
||||||
if(tl !== undefined){
|
node.title = tl;
|
||||||
node.title = tl
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
if (node.placeholder) {
|
||||||
if(node.placeholder){
|
let tl = getTranslation(node.placeholder);
|
||||||
tl = getTranslation(node.placeholder)
|
if (tl !== undefined) {
|
||||||
if(tl !== undefined){
|
node.placeholder = tl;
|
||||||
node.placeholder = tl
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
textNodesUnder(node).forEach(function(node) {
|
||||||
textNodesUnder(node).forEach(function(node){
|
processTextNode(node);
|
||||||
processTextNode(node)
|
});
|
||||||
})
|
}
|
||||||
}
|
|
||||||
|
function dumpTranslations() {
|
||||||
function dumpTranslations(){
|
if (!hasLocalization()) {
|
||||||
if(!hasLocalization()) {
|
// If we don't have any localization,
|
||||||
// If we don't have any localization,
|
// we will not have traversed the app to find
|
||||||
// we will not have traversed the app to find
|
// original_lines, so do that now.
|
||||||
// original_lines, so do that now.
|
processNode(gradioApp());
|
||||||
processNode(gradioApp());
|
}
|
||||||
}
|
var dumped = {};
|
||||||
var dumped = {}
|
if (localization.rtl) {
|
||||||
if (localization.rtl) {
|
dumped.rtl = true;
|
||||||
dumped.rtl = true;
|
}
|
||||||
}
|
|
||||||
|
for (const text in original_lines) {
|
||||||
for (const text in original_lines) {
|
if (dumped[text] !== undefined) continue;
|
||||||
if(dumped[text] !== undefined) continue;
|
dumped[text] = localization[text] || text;
|
||||||
dumped[text] = localization[text] || text;
|
}
|
||||||
}
|
|
||||||
|
return dumped;
|
||||||
return dumped;
|
}
|
||||||
}
|
|
||||||
|
function download_localization() {
|
||||||
function download_localization() {
|
var text = JSON.stringify(dumpTranslations(), null, 4);
|
||||||
var text = JSON.stringify(dumpTranslations(), null, 4)
|
|
||||||
|
var element = document.createElement('a');
|
||||||
var element = document.createElement('a');
|
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||||
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
element.setAttribute('download', "localization.json");
|
||||||
element.setAttribute('download', "localization.json");
|
element.style.display = 'none';
|
||||||
element.style.display = 'none';
|
document.body.appendChild(element);
|
||||||
document.body.appendChild(element);
|
|
||||||
|
element.click();
|
||||||
element.click();
|
|
||||||
|
document.body.removeChild(element);
|
||||||
document.body.removeChild(element);
|
}
|
||||||
}
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
document.addEventListener("DOMContentLoaded", function () {
|
if (!hasLocalization()) {
|
||||||
if (!hasLocalization()) {
|
return;
|
||||||
return;
|
}
|
||||||
}
|
|
||||||
|
onUiUpdate(function(m) {
|
||||||
onUiUpdate(function (m) {
|
m.forEach(function(mutation) {
|
||||||
m.forEach(function (mutation) {
|
mutation.addedNodes.forEach(function(node) {
|
||||||
mutation.addedNodes.forEach(function (node) {
|
processNode(node);
|
||||||
processNode(node)
|
});
|
||||||
})
|
});
|
||||||
});
|
});
|
||||||
})
|
|
||||||
|
processNode(gradioApp());
|
||||||
processNode(gradioApp())
|
|
||||||
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
mutations.forEach(mutation => {
|
||||||
mutations.forEach(mutation => {
|
mutation.addedNodes.forEach(node => {
|
||||||
mutation.addedNodes.forEach(node => {
|
if (node.tagName === 'STYLE') {
|
||||||
if (node.tagName === 'STYLE') {
|
observer.disconnect();
|
||||||
observer.disconnect();
|
|
||||||
|
for (const x of node.sheet.rules) { // find all rtl media rules
|
||||||
for (const x of node.sheet.rules) { // find all rtl media rules
|
if (Array.from(x.media || []).includes('rtl')) {
|
||||||
if (Array.from(x.media || []).includes('rtl')) {
|
x.media.appendMedium('all'); // enable them
|
||||||
x.media.appendMedium('all'); // enable them
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
})
|
});
|
||||||
});
|
})).observe(gradioApp(), {childList: true});
|
||||||
})).observe(gradioApp(), { childList: true });
|
}
|
||||||
}
|
});
|
||||||
})
|
|
||||||
|
@ -4,14 +4,14 @@ let lastHeadImg = null;
|
|||||||
|
|
||||||
let notificationButton = null;
|
let notificationButton = null;
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function() {
|
||||||
if(notificationButton == null){
|
if (notificationButton == null) {
|
||||||
notificationButton = gradioApp().getElementById('request_notifications')
|
notificationButton = gradioApp().getElementById('request_notifications');
|
||||||
|
|
||||||
if(notificationButton != null){
|
if (notificationButton != null) {
|
||||||
notificationButton.addEventListener('click', () => {
|
notificationButton.addEventListener('click', () => {
|
||||||
void Notification.requestPermission();
|
void Notification.requestPermission();
|
||||||
},true);
|
}, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ onUiUpdate(function(){
|
|||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
notification.onclick = function(_){
|
notification.onclick = function(_) {
|
||||||
parent.focus();
|
parent.focus();
|
||||||
this.close();
|
this.close();
|
||||||
};
|
};
|
||||||
|
@ -1,29 +1,29 @@
|
|||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
|
|
||||||
function rememberGallerySelection(){
|
function rememberGallerySelection() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGallerySelectedIndex(){
|
function getGallerySelectedIndex() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function request(url, data, handler, errorHandler){
|
function request(url, data, handler, errorHandler) {
|
||||||
var xhr = new XMLHttpRequest();
|
var xhr = new XMLHttpRequest();
|
||||||
xhr.open("POST", url, true);
|
xhr.open("POST", url, true);
|
||||||
xhr.setRequestHeader("Content-Type", "application/json");
|
xhr.setRequestHeader("Content-Type", "application/json");
|
||||||
xhr.onreadystatechange = function () {
|
xhr.onreadystatechange = function() {
|
||||||
if (xhr.readyState === 4) {
|
if (xhr.readyState === 4) {
|
||||||
if (xhr.status === 200) {
|
if (xhr.status === 200) {
|
||||||
try {
|
try {
|
||||||
var js = JSON.parse(xhr.responseText);
|
var js = JSON.parse(xhr.responseText);
|
||||||
handler(js)
|
handler(js);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
errorHandler()
|
errorHandler();
|
||||||
}
|
}
|
||||||
} else{
|
} else {
|
||||||
errorHandler()
|
errorHandler();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -31,147 +31,147 @@ function request(url, data, handler, errorHandler){
|
|||||||
xhr.send(js);
|
xhr.send(js);
|
||||||
}
|
}
|
||||||
|
|
||||||
function pad2(x){
|
function pad2(x) {
|
||||||
return x<10 ? '0'+x : x
|
return x < 10 ? '0' + x : x;
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatTime(secs){
|
function formatTime(secs) {
|
||||||
if(secs > 3600){
|
if (secs > 3600) {
|
||||||
return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
|
return pad2(Math.floor(secs / 60 / 60)) + ":" + pad2(Math.floor(secs / 60) % 60) + ":" + pad2(Math.floor(secs) % 60);
|
||||||
} else if(secs > 60){
|
} else if (secs > 60) {
|
||||||
return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
|
return pad2(Math.floor(secs / 60)) + ":" + pad2(Math.floor(secs) % 60);
|
||||||
} else{
|
} else {
|
||||||
return Math.floor(secs) + "s"
|
return Math.floor(secs) + "s";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function setTitle(progress){
|
function setTitle(progress) {
|
||||||
var title = 'Stable Diffusion'
|
var title = 'Stable Diffusion';
|
||||||
|
|
||||||
if(opts.show_progress_in_title && progress){
|
if (opts.show_progress_in_title && progress) {
|
||||||
title = '[' + progress.trim() + '] ' + title;
|
title = '[' + progress.trim() + '] ' + title;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(document.title != title){
|
if (document.title != title) {
|
||||||
document.title = title;
|
document.title = title;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function randomId(){
|
function randomId() {
|
||||||
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
|
return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + ")";
|
||||||
}
|
}
|
||||||
|
|
||||||
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
||||||
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
||||||
// calls onProgress every time there is a progress update
|
// calls onProgress every time there is a progress update
|
||||||
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
|
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {
|
||||||
var dateStart = new Date()
|
var dateStart = new Date();
|
||||||
var wasEverActive = false
|
var wasEverActive = false;
|
||||||
var parentProgressbar = progressbarContainer.parentNode
|
var parentProgressbar = progressbarContainer.parentNode;
|
||||||
var parentGallery = gallery ? gallery.parentNode : null
|
var parentGallery = gallery ? gallery.parentNode : null;
|
||||||
|
|
||||||
var divProgress = document.createElement('div')
|
var divProgress = document.createElement('div');
|
||||||
divProgress.className='progressDiv'
|
divProgress.className = 'progressDiv';
|
||||||
divProgress.style.display = opts.show_progressbar ? "block" : "none"
|
divProgress.style.display = opts.show_progressbar ? "block" : "none";
|
||||||
var divInner = document.createElement('div')
|
var divInner = document.createElement('div');
|
||||||
divInner.className='progress'
|
divInner.className = 'progress';
|
||||||
|
|
||||||
divProgress.appendChild(divInner)
|
divProgress.appendChild(divInner);
|
||||||
parentProgressbar.insertBefore(divProgress, progressbarContainer)
|
parentProgressbar.insertBefore(divProgress, progressbarContainer);
|
||||||
|
|
||||||
if(parentGallery){
|
if (parentGallery) {
|
||||||
var livePreview = document.createElement('div')
|
var livePreview = document.createElement('div');
|
||||||
livePreview.className='livePreview'
|
livePreview.className = 'livePreview';
|
||||||
parentGallery.insertBefore(livePreview, gallery)
|
parentGallery.insertBefore(livePreview, gallery);
|
||||||
}
|
}
|
||||||
|
|
||||||
var removeProgressBar = function(){
|
var removeProgressBar = function() {
|
||||||
setTitle("")
|
setTitle("");
|
||||||
parentProgressbar.removeChild(divProgress)
|
parentProgressbar.removeChild(divProgress);
|
||||||
if(parentGallery) parentGallery.removeChild(livePreview)
|
if (parentGallery) parentGallery.removeChild(livePreview);
|
||||||
atEnd()
|
atEnd();
|
||||||
}
|
};
|
||||||
|
|
||||||
var fun = function(id_task, id_live_preview){
|
var fun = function(id_task, id_live_preview) {
|
||||||
request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
|
request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {
|
||||||
if(res.completed){
|
if (res.completed) {
|
||||||
removeProgressBar()
|
removeProgressBar();
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var rect = progressbarContainer.getBoundingClientRect()
|
var rect = progressbarContainer.getBoundingClientRect();
|
||||||
|
|
||||||
if(rect.width){
|
if (rect.width) {
|
||||||
divProgress.style.width = rect.width + "px";
|
divProgress.style.width = rect.width + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
let progressText = ""
|
let progressText = "";
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%';
|
||||||
divInner.style.background = res.progress ? "" : "transparent"
|
divInner.style.background = res.progress ? "" : "transparent";
|
||||||
|
|
||||||
if(res.progress > 0){
|
if (res.progress > 0) {
|
||||||
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
|
progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';
|
||||||
}
|
}
|
||||||
|
|
||||||
if(res.eta){
|
if (res.eta) {
|
||||||
progressText += " ETA: " + formatTime(res.eta)
|
progressText += " ETA: " + formatTime(res.eta);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
setTitle(progressText)
|
setTitle(progressText);
|
||||||
|
|
||||||
if(res.textinfo && res.textinfo.indexOf("\n") == -1){
|
if (res.textinfo && res.textinfo.indexOf("\n") == -1) {
|
||||||
progressText = res.textinfo + " " + progressText
|
progressText = res.textinfo + " " + progressText;
|
||||||
}
|
}
|
||||||
|
|
||||||
divInner.textContent = progressText
|
divInner.textContent = progressText;
|
||||||
|
|
||||||
var elapsedFromStart = (new Date() - dateStart) / 1000
|
var elapsedFromStart = (new Date() - dateStart) / 1000;
|
||||||
|
|
||||||
if(res.active) wasEverActive = true;
|
if (res.active) wasEverActive = true;
|
||||||
|
|
||||||
if(! res.active && wasEverActive){
|
if (!res.active && wasEverActive) {
|
||||||
removeProgressBar()
|
removeProgressBar();
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
|
if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {
|
||||||
removeProgressBar()
|
removeProgressBar();
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if(res.live_preview && gallery){
|
if (res.live_preview && gallery) {
|
||||||
var rect = gallery.getBoundingClientRect()
|
rect = gallery.getBoundingClientRect();
|
||||||
if(rect.width){
|
if (rect.width) {
|
||||||
livePreview.style.width = rect.width + "px"
|
livePreview.style.width = rect.width + "px";
|
||||||
livePreview.style.height = rect.height + "px"
|
livePreview.style.height = rect.height + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
var img = new Image();
|
var img = new Image();
|
||||||
img.onload = function() {
|
img.onload = function() {
|
||||||
livePreview.appendChild(img)
|
livePreview.appendChild(img);
|
||||||
if(livePreview.childElementCount > 2){
|
if (livePreview.childElementCount > 2) {
|
||||||
livePreview.removeChild(livePreview.firstElementChild)
|
livePreview.removeChild(livePreview.firstElementChild);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
img.src = res.live_preview;
|
img.src = res.live_preview;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if(onProgress){
|
if (onProgress) {
|
||||||
onProgress(res)
|
onProgress(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
fun(id_task, res.id_live_preview);
|
fun(id_task, res.id_live_preview);
|
||||||
}, opts.live_preview_refresh_period || 500)
|
}, opts.live_preview_refresh_period || 500);
|
||||||
}, function(){
|
}, function() {
|
||||||
removeProgressBar()
|
removeProgressBar();
|
||||||
})
|
});
|
||||||
}
|
};
|
||||||
|
|
||||||
fun(id_task, 0)
|
fun(id_task, 0);
|
||||||
}
|
}
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
function start_training_textual_inversion(){
|
function start_training_textual_inversion() {
|
||||||
gradioApp().querySelector('#ti_error').innerHTML=''
|
gradioApp().querySelector('#ti_error').innerHTML = '';
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId();
|
||||||
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
|
requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {
|
||||||
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
|
gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;
|
||||||
})
|
});
|
||||||
|
|
||||||
var res = args_to_array(arguments)
|
var res = args_to_array(arguments);
|
||||||
|
|
||||||
res[0] = id
|
res[0] = id;
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
447
javascript/ui.js
447
javascript/ui.js
@ -1,9 +1,9 @@
|
|||||||
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
function set_theme(theme){
|
function set_theme(theme) {
|
||||||
var gradioURL = window.location.href
|
var gradioURL = window.location.href;
|
||||||
if (!gradioURL.includes('?__theme=')) {
|
if (!gradioURL.includes('?__theme=')) {
|
||||||
window.location.replace(gradioURL + '?__theme=' + theme);
|
window.location.replace(gradioURL + '?__theme=' + theme);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ function all_gallery_buttons() {
|
|||||||
if (elem.parentElement.offsetParent) {
|
if (elem.parentElement.offsetParent) {
|
||||||
visibleGalleryButtons.push(elem);
|
visibleGalleryButtons.push(elem);
|
||||||
}
|
}
|
||||||
})
|
});
|
||||||
return visibleGalleryButtons;
|
return visibleGalleryButtons;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -25,31 +25,35 @@ function selected_gallery_button() {
|
|||||||
if (elem.parentElement.offsetParent) {
|
if (elem.parentElement.offsetParent) {
|
||||||
visibleCurrentButton = elem;
|
visibleCurrentButton = elem;
|
||||||
}
|
}
|
||||||
})
|
});
|
||||||
return visibleCurrentButton;
|
return visibleCurrentButton;
|
||||||
}
|
}
|
||||||
|
|
||||||
function selected_gallery_index(){
|
function selected_gallery_index() {
|
||||||
var buttons = all_gallery_buttons();
|
var buttons = all_gallery_buttons();
|
||||||
var button = selected_gallery_button();
|
var button = selected_gallery_button();
|
||||||
|
|
||||||
var result = -1
|
var result = -1;
|
||||||
buttons.forEach(function(v, i){ if(v==button) { result = i } })
|
buttons.forEach(function(v, i) {
|
||||||
|
if (v == button) {
|
||||||
|
result = i;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
return result
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
function extract_image_from_gallery(gallery){
|
function extract_image_from_gallery(gallery) {
|
||||||
if (gallery.length == 0){
|
if (gallery.length == 0) {
|
||||||
return [null];
|
return [null];
|
||||||
}
|
}
|
||||||
if (gallery.length == 1){
|
if (gallery.length == 1) {
|
||||||
return [gallery[0]];
|
return [gallery[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
var index = selected_gallery_index()
|
var index = selected_gallery_index();
|
||||||
|
|
||||||
if (index < 0 || index >= gallery.length){
|
if (index < 0 || index >= gallery.length) {
|
||||||
// Use the first image in the gallery as the default
|
// Use the first image in the gallery as the default
|
||||||
index = 0;
|
index = 0;
|
||||||
}
|
}
|
||||||
@ -57,249 +61,242 @@ function extract_image_from_gallery(gallery){
|
|||||||
return [gallery[index]];
|
return [gallery[index]];
|
||||||
}
|
}
|
||||||
|
|
||||||
function args_to_array(args){
|
function args_to_array(args) {
|
||||||
var res = []
|
var res = [];
|
||||||
for(var i=0;i<args.length;i++){
|
for (var i = 0; i < args.length; i++) {
|
||||||
res.push(args[i])
|
res.push(args[i]);
|
||||||
}
|
}
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_txt2img(){
|
function switch_to_txt2img() {
|
||||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
|
gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
|
||||||
|
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_img2img_tab(no){
|
function switch_to_img2img_tab(no) {
|
||||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
||||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
|
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
|
||||||
}
|
}
|
||||||
function switch_to_img2img(){
|
function switch_to_img2img() {
|
||||||
switch_to_img2img_tab(0);
|
switch_to_img2img_tab(0);
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_sketch(){
|
function switch_to_sketch() {
|
||||||
switch_to_img2img_tab(1);
|
switch_to_img2img_tab(1);
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_inpaint(){
|
function switch_to_inpaint() {
|
||||||
switch_to_img2img_tab(2);
|
switch_to_img2img_tab(2);
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_inpaint_sketch(){
|
function switch_to_inpaint_sketch() {
|
||||||
switch_to_img2img_tab(3);
|
switch_to_img2img_tab(3);
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function switch_to_inpaint(){
|
function switch_to_extras() {
|
||||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
|
|
||||||
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click();
|
|
||||||
|
|
||||||
return args_to_array(arguments);
|
|
||||||
}
|
|
||||||
|
|
||||||
function switch_to_extras(){
|
|
||||||
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
|
gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
|
||||||
|
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_tab_index(tabId){
|
function get_tab_index(tabId) {
|
||||||
var res = 0
|
var res = 0;
|
||||||
|
|
||||||
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
|
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i) {
|
||||||
if(button.className.indexOf('selected') != -1)
|
if (button.className.indexOf('selected') != -1) {
|
||||||
res = i
|
res = i;
|
||||||
})
|
}
|
||||||
|
});
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function create_tab_index_args(tabId, args){
|
function create_tab_index_args(tabId, args) {
|
||||||
var res = []
|
var res = [];
|
||||||
for(var i=0; i<args.length; i++){
|
for (var i = 0; i < args.length; i++) {
|
||||||
res.push(args[i])
|
res.push(args[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
res[0] = get_tab_index(tabId)
|
res[0] = get_tab_index(tabId);
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_img2img_tab_index() {
|
function get_img2img_tab_index() {
|
||||||
let res = args_to_array(arguments)
|
let res = args_to_array(arguments);
|
||||||
res.splice(-2)
|
res.splice(-2);
|
||||||
res[0] = get_tab_index('mode_img2img')
|
res[0] = get_tab_index('mode_img2img');
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function create_submit_args(args){
|
function create_submit_args(args) {
|
||||||
var res = []
|
var res = [];
|
||||||
for(var i=0;i<args.length;i++){
|
for (var i = 0; i < args.length; i++) {
|
||||||
res.push(args[i])
|
res.push(args[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
|
||||||
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
|
||||||
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
|
||||||
// If gradio at some point stops sending outputs, this may break something
|
// If gradio at some point stops sending outputs, this may break something
|
||||||
if(Array.isArray(res[res.length - 3])){
|
if (Array.isArray(res[res.length - 3])) {
|
||||||
res[res.length - 3] = null
|
res[res.length - 3] = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function showSubmitButtons(tabname, show){
|
function showSubmitButtons(tabname, show) {
|
||||||
gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
|
gradioApp().getElementById(tabname + '_interrupt').style.display = show ? "none" : "block";
|
||||||
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
gradioApp().getElementById(tabname + '_skip').style.display = show ? "none" : "block";
|
||||||
}
|
}
|
||||||
|
|
||||||
function showRestoreProgressButton(tabname, show){
|
function showRestoreProgressButton(tabname, show) {
|
||||||
var button = gradioApp().getElementById(tabname + "_restore_progress")
|
var button = gradioApp().getElementById(tabname + "_restore_progress");
|
||||||
if(! button) return
|
if (!button) return;
|
||||||
|
|
||||||
button.style.display = show ? "flex" : "none"
|
button.style.display = show ? "flex" : "none";
|
||||||
}
|
}
|
||||||
|
|
||||||
function submit(){
|
function submit() {
|
||||||
rememberGallerySelection('txt2img_gallery')
|
showSubmitButtons('txt2img', false);
|
||||||
showSubmitButtons('txt2img', false)
|
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId();
|
||||||
localStorage.setItem("txt2img_task_id", id);
|
localStorage.setItem("txt2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||||
showSubmitButtons('txt2img', true)
|
showSubmitButtons('txt2img', true);
|
||||||
localStorage.removeItem("txt2img_task_id")
|
localStorage.removeItem("txt2img_task_id");
|
||||||
showRestoreProgressButton('txt2img', false)
|
showRestoreProgressButton('txt2img', false);
|
||||||
})
|
});
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments);
|
||||||
|
|
||||||
res[0] = id
|
res[0] = id;
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function submit_img2img(){
|
function submit_img2img() {
|
||||||
rememberGallerySelection('img2img_gallery')
|
showSubmitButtons('img2img', false);
|
||||||
showSubmitButtons('img2img', false)
|
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId();
|
||||||
localStorage.setItem("img2img_task_id", id);
|
localStorage.setItem("img2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||||
showSubmitButtons('img2img', true)
|
showSubmitButtons('img2img', true);
|
||||||
localStorage.removeItem("img2img_task_id")
|
localStorage.removeItem("img2img_task_id");
|
||||||
showRestoreProgressButton('img2img', false)
|
showRestoreProgressButton('img2img', false);
|
||||||
})
|
});
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments);
|
||||||
|
|
||||||
res[0] = id
|
res[0] = id;
|
||||||
res[1] = get_tab_index('mode_img2img')
|
res[1] = get_tab_index('mode_img2img');
|
||||||
|
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
function restoreProgressTxt2img(){
|
function restoreProgressTxt2img() {
|
||||||
showRestoreProgressButton("txt2img", false)
|
showRestoreProgressButton("txt2img", false);
|
||||||
var id = localStorage.getItem("txt2img_task_id")
|
var id = localStorage.getItem("txt2img_task_id");
|
||||||
|
|
||||||
id = localStorage.getItem("txt2img_task_id")
|
id = localStorage.getItem("txt2img_task_id");
|
||||||
|
|
||||||
if(id) {
|
if (id) {
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
|
||||||
showSubmitButtons('txt2img', true)
|
showSubmitButtons('txt2img', true);
|
||||||
}, null, 0)
|
}, null, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
return id
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
function restoreProgressImg2img(){
|
function restoreProgressImg2img() {
|
||||||
showRestoreProgressButton("img2img", false)
|
showRestoreProgressButton("img2img", false);
|
||||||
|
|
||||||
var id = localStorage.getItem("img2img_task_id")
|
|
||||||
|
|
||||||
if(id) {
|
var id = localStorage.getItem("img2img_task_id");
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
|
||||||
showSubmitButtons('img2img', true)
|
if (id) {
|
||||||
}, null, 0)
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
|
||||||
|
showSubmitButtons('img2img', true);
|
||||||
|
}, null, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
return id
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
onUiLoaded(function () {
|
onUiLoaded(function() {
|
||||||
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
|
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"));
|
||||||
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
|
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
function modelmerger(){
|
function modelmerger() {
|
||||||
var id = randomId()
|
var id = randomId();
|
||||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function() {});
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments);
|
||||||
res[0] = id
|
res[0] = id;
|
||||||
return res
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
var name_ = prompt('Style name:')
|
var name_ = prompt('Style name:');
|
||||||
return [name_, prompt_text, negative_prompt_text]
|
return [name_, prompt_text, negative_prompt_text];
|
||||||
}
|
}
|
||||||
|
|
||||||
function confirm_clear_prompt(prompt, negative_prompt) {
|
function confirm_clear_prompt(prompt, negative_prompt) {
|
||||||
if(confirm("Delete prompt?")) {
|
if (confirm("Delete prompt?")) {
|
||||||
prompt = ""
|
prompt = "";
|
||||||
negative_prompt = ""
|
negative_prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
return [prompt, negative_prompt]
|
return [prompt, negative_prompt];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
promptTokecountUpdateFuncs = {}
|
var promptTokecountUpdateFuncs = {};
|
||||||
|
|
||||||
function recalculatePromptTokens(name){
|
function recalculatePromptTokens(name) {
|
||||||
if(promptTokecountUpdateFuncs[name]){
|
if (promptTokecountUpdateFuncs[name]) {
|
||||||
promptTokecountUpdateFuncs[name]()
|
promptTokecountUpdateFuncs[name]();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function recalculate_prompts_txt2img(){
|
function recalculate_prompts_txt2img() {
|
||||||
recalculatePromptTokens('txt2img_prompt')
|
recalculatePromptTokens('txt2img_prompt');
|
||||||
recalculatePromptTokens('txt2img_neg_prompt')
|
recalculatePromptTokens('txt2img_neg_prompt');
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
function recalculate_prompts_img2img(){
|
function recalculate_prompts_img2img() {
|
||||||
recalculatePromptTokens('img2img_prompt')
|
recalculatePromptTokens('img2img_prompt');
|
||||||
recalculatePromptTokens('img2img_neg_prompt')
|
recalculatePromptTokens('img2img_neg_prompt');
|
||||||
return args_to_array(arguments);
|
return args_to_array(arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var opts = {}
|
var opts = {};
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function() {
|
||||||
if(Object.keys(opts).length != 0) return;
|
if (Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
var json_elem = gradioApp().getElementById('settings_json')
|
var json_elem = gradioApp().getElementById('settings_json');
|
||||||
if(json_elem == null) return;
|
if (json_elem == null) return;
|
||||||
|
|
||||||
var textarea = json_elem.querySelector('textarea')
|
var textarea = json_elem.querySelector('textarea');
|
||||||
var jsdata = textarea.value
|
var jsdata = textarea.value;
|
||||||
opts = JSON.parse(jsdata)
|
opts = JSON.parse(jsdata);
|
||||||
executeCallbacks(optionsChangedCallbacks);
|
|
||||||
|
executeCallbacks(optionsChangedCallbacks); /*global optionsChangedCallbacks*/
|
||||||
|
|
||||||
Object.defineProperty(textarea, 'value', {
|
Object.defineProperty(textarea, 'value', {
|
||||||
set: function(newValue) {
|
set: function(newValue) {
|
||||||
@ -308,7 +305,7 @@ onUiUpdate(function(){
|
|||||||
valueProp.set.call(textarea, newValue);
|
valueProp.set.call(textarea, newValue);
|
||||||
|
|
||||||
if (oldValue != newValue) {
|
if (oldValue != newValue) {
|
||||||
opts = JSON.parse(textarea.value)
|
opts = JSON.parse(textarea.value);
|
||||||
}
|
}
|
||||||
|
|
||||||
executeCallbacks(optionsChangedCallbacks);
|
executeCallbacks(optionsChangedCallbacks);
|
||||||
@ -319,123 +316,157 @@ onUiUpdate(function(){
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
json_elem.parentElement.style.display="none"
|
json_elem.parentElement.style.display = "none";
|
||||||
|
|
||||||
function registerTextarea(id, id_counter, id_button){
|
function registerTextarea(id, id_counter, id_button) {
|
||||||
var prompt = gradioApp().getElementById(id)
|
var prompt = gradioApp().getElementById(id);
|
||||||
var counter = gradioApp().getElementById(id_counter)
|
var counter = gradioApp().getElementById(id_counter);
|
||||||
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||||
|
|
||||||
if(counter.parentElement == prompt.parentElement){
|
if (counter.parentElement == prompt.parentElement) {
|
||||||
return
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt.parentElement.insertBefore(counter, prompt)
|
prompt.parentElement.insertBefore(counter, prompt);
|
||||||
prompt.parentElement.style.position = "relative"
|
prompt.parentElement.style.position = "relative";
|
||||||
|
|
||||||
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }
|
promptTokecountUpdateFuncs[id] = function() {
|
||||||
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
update_token_counter(id_button);
|
||||||
|
};
|
||||||
|
textarea.addEventListener("input", promptTokecountUpdateFuncs[id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button')
|
registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');
|
||||||
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button')
|
registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');
|
||||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');
|
||||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');
|
||||||
|
|
||||||
var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
|
||||||
var settings_tabs = gradioApp().querySelector('#settings div')
|
var settings_tabs = gradioApp().querySelector('#settings div');
|
||||||
if(show_all_pages && settings_tabs){
|
if (show_all_pages && settings_tabs) {
|
||||||
settings_tabs.appendChild(show_all_pages)
|
settings_tabs.appendChild(show_all_pages);
|
||||||
show_all_pages.onclick = function(){
|
show_all_pages.onclick = function() {
|
||||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
|
||||||
if(elem.id == "settings_tab_licenses")
|
if (elem.id == "settings_tab_licenses") {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
elem.style.display = "block";
|
elem.style.display = "block";
|
||||||
})
|
});
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
})
|
});
|
||||||
|
|
||||||
onOptionsChanged(function(){
|
onOptionsChanged(function() {
|
||||||
var elem = gradioApp().getElementById('sd_checkpoint_hash')
|
var elem = gradioApp().getElementById('sd_checkpoint_hash');
|
||||||
var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
var sd_checkpoint_hash = opts.sd_checkpoint_hash || "";
|
||||||
var shorthash = sd_checkpoint_hash.substring(0,10)
|
var shorthash = sd_checkpoint_hash.substring(0, 10);
|
||||||
|
|
||||||
if(elem && elem.textContent != shorthash){
|
if (elem && elem.textContent != shorthash) {
|
||||||
elem.textContent = shorthash
|
elem.textContent = shorthash;
|
||||||
elem.title = sd_checkpoint_hash
|
elem.title = sd_checkpoint_hash;
|
||||||
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash
|
elem.href = "https://google.com/search?q=" + sd_checkpoint_hash;
|
||||||
}
|
}
|
||||||
})
|
});
|
||||||
|
|
||||||
let txt2img_textarea, img2img_textarea = undefined;
|
let txt2img_textarea, img2img_textarea = undefined;
|
||||||
let wait_time = 800
|
let wait_time = 800;
|
||||||
let token_timeouts = {};
|
let token_timeouts = {};
|
||||||
|
|
||||||
function update_txt2img_tokens(...args) {
|
function update_txt2img_tokens(...args) {
|
||||||
update_token_counter("txt2img_token_button")
|
update_token_counter("txt2img_token_button");
|
||||||
if (args.length == 2)
|
if (args.length == 2) {
|
||||||
return args[0]
|
return args[0];
|
||||||
return args;
|
}
|
||||||
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
function update_img2img_tokens(...args) {
|
function update_img2img_tokens(...args) {
|
||||||
update_token_counter("img2img_token_button")
|
update_token_counter(
|
||||||
if (args.length == 2)
|
"img2img_token_button"
|
||||||
return args[0]
|
);
|
||||||
return args;
|
if (args.length == 2) {
|
||||||
|
return args[0];
|
||||||
|
}
|
||||||
|
return args;
|
||||||
}
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function update_token_counter(button_id) {
|
||||||
if (token_timeouts[button_id])
|
if (token_timeouts[button_id]) {
|
||||||
clearTimeout(token_timeouts[button_id]);
|
clearTimeout(token_timeouts[button_id]);
|
||||||
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
}
|
||||||
|
token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
function restart_reload(){
|
function restart_reload() {
|
||||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML = '<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
|
|
||||||
var requestPing = function(){
|
var requestPing = function() {
|
||||||
requestGet("./internal/ping", {}, function(data){
|
requestGet("./internal/ping", {}, function(data) {
|
||||||
location.reload();
|
location.reload();
|
||||||
}, function(){
|
}, function() {
|
||||||
setTimeout(requestPing, 500);
|
setTimeout(requestPing, 500);
|
||||||
})
|
});
|
||||||
}
|
};
|
||||||
|
|
||||||
setTimeout(requestPing, 2000);
|
setTimeout(requestPing, 2000);
|
||||||
|
|
||||||
return []
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
|
// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits
|
||||||
// will only visible on web page and not sent to python.
|
// will only visible on web page and not sent to python.
|
||||||
function updateInput(target){
|
function updateInput(target) {
|
||||||
let e = new Event("input", { bubbles: true })
|
let e = new Event("input", {bubbles: true});
|
||||||
Object.defineProperty(e, "target", {value: target})
|
Object.defineProperty(e, "target", {value: target});
|
||||||
target.dispatchEvent(e);
|
target.dispatchEvent(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var desiredCheckpointName = null;
|
var desiredCheckpointName = null;
|
||||||
function selectCheckpoint(name){
|
function selectCheckpoint(name) {
|
||||||
desiredCheckpointName = name;
|
desiredCheckpointName = name;
|
||||||
gradioApp().getElementById('change_checkpoint').click()
|
gradioApp().getElementById('change_checkpoint').click();
|
||||||
}
|
}
|
||||||
|
|
||||||
function currentImg2imgSourceResolution(_, _, scaleBy){
|
function currentImg2imgSourceResolution(w, h, scaleBy) {
|
||||||
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
|
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img');
|
||||||
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
|
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy];
|
||||||
}
|
}
|
||||||
|
|
||||||
function updateImg2imgResizeToTextAfterChangingImage(){
|
function updateImg2imgResizeToTextAfterChangingImage() {
|
||||||
// At the time this is called from gradio, the image has no yet been replaced.
|
// At the time this is called from gradio, the image has no yet been replaced.
|
||||||
// There may be a better solution, but this is simple and straightforward so I'm going with it.
|
// There may be a better solution, but this is simple and straightforward so I'm going with it.
|
||||||
|
|
||||||
setTimeout(function() {
|
setTimeout(function() {
|
||||||
gradioApp().getElementById('img2img_update_resize_to').click()
|
gradioApp().getElementById('img2img_update_resize_to').click();
|
||||||
}, 500);
|
}, 500);
|
||||||
|
|
||||||
return []
|
return [];
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
function setRandomSeed(elem_id) {
|
||||||
|
var input = gradioApp().querySelector("#" + elem_id + " input");
|
||||||
|
if (!input) return [];
|
||||||
|
|
||||||
|
input.value = "-1";
|
||||||
|
updateInput(input);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
function switchWidthHeight(tabname) {
|
||||||
|
var width = gradioApp().querySelector("#" + tabname + "_width input[type=number]");
|
||||||
|
var height = gradioApp().querySelector("#" + tabname + "_height input[type=number]");
|
||||||
|
if (!width || !height) return [];
|
||||||
|
|
||||||
|
var tmp = width.value;
|
||||||
|
width.value = height.value;
|
||||||
|
height.value = tmp;
|
||||||
|
|
||||||
|
updateInput(width);
|
||||||
|
updateInput(height);
|
||||||
|
return [];
|
||||||
}
|
}
|
||||||
|
@ -1,41 +1,62 @@
|
|||||||
// various hints and extra info for the settings tab
|
// various hints and extra info for the settings tab
|
||||||
|
|
||||||
onUiLoaded(function(){
|
var settingsHintsSetup = false;
|
||||||
createLink = function(elem_id, text, href){
|
|
||||||
var a = document.createElement('A')
|
onOptionsChanged(function() {
|
||||||
a.textContent = text
|
if (settingsHintsSetup) return;
|
||||||
a.target = '_blank';
|
settingsHintsSetup = true;
|
||||||
|
|
||||||
elem = gradioApp().querySelector('#'+elem_id)
|
gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {
|
||||||
elem.insertBefore(a, elem.querySelector('label'))
|
var name = div.id.substr(8);
|
||||||
|
var commentBefore = opts._comments_before[name];
|
||||||
return a
|
var commentAfter = opts._comments_after[name];
|
||||||
}
|
|
||||||
|
if (!commentBefore && !commentAfter) return;
|
||||||
createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
|
|
||||||
createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
|
var span = null;
|
||||||
|
if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');
|
||||||
createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
|
else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;
|
||||||
requestGet("./internal/quicksettings-hint", {}, function(data){
|
else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;
|
||||||
var table = document.createElement('table')
|
else span = div.querySelector('label span').firstChild;
|
||||||
table.className = 'settings-value-table'
|
|
||||||
|
if (!span) return;
|
||||||
data.forEach(function(obj){
|
|
||||||
var tr = document.createElement('tr')
|
if (commentBefore) {
|
||||||
var td = document.createElement('td')
|
var comment = document.createElement('DIV');
|
||||||
td.textContent = obj.name
|
comment.className = 'settings-comment';
|
||||||
tr.appendChild(td)
|
comment.innerHTML = commentBefore;
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
|
||||||
var td = document.createElement('td')
|
span.parentElement.insertBefore(comment, span);
|
||||||
td.textContent = obj.label
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span);
|
||||||
tr.appendChild(td)
|
}
|
||||||
|
if (commentAfter) {
|
||||||
table.appendChild(tr)
|
comment = document.createElement('DIV');
|
||||||
})
|
comment.className = 'settings-comment';
|
||||||
|
comment.innerHTML = commentAfter;
|
||||||
popup(table);
|
span.parentElement.insertBefore(comment, span.nextSibling);
|
||||||
})
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling);
|
||||||
});
|
}
|
||||||
})
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
function settingsHintsShowQuicksettings() {
|
||||||
|
requestGet("./internal/quicksettings-hint", {}, function(data) {
|
||||||
|
var table = document.createElement('table');
|
||||||
|
table.className = 'settings-value-table';
|
||||||
|
|
||||||
|
data.forEach(function(obj) {
|
||||||
|
var tr = document.createElement('tr');
|
||||||
|
var td = document.createElement('td');
|
||||||
|
td.textContent = obj.name;
|
||||||
|
tr.appendChild(td);
|
||||||
|
|
||||||
|
td = document.createElement('td');
|
||||||
|
td.textContent = obj.label;
|
||||||
|
tr.appendChild(td);
|
||||||
|
|
||||||
|
table.appendChild(tr);
|
||||||
|
});
|
||||||
|
|
||||||
|
popup(table);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
108
launch.py
108
launch.py
@ -3,25 +3,23 @@ import subprocess
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import shlex
|
|
||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from modules import cmd_args
|
from modules import cmd_args
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
|
||||||
sys.argv += shlex.split(commandline_args)
|
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
index_url = os.environ.get('INDEX_URL', "")
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
stored_commit_hash = None
|
|
||||||
stored_git_tag = None
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
|
|
||||||
|
# Whether to default to printing command output
|
||||||
|
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||||
|
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
@ -57,65 +55,52 @@ Use --skip-python-version-check to suppress this warning.
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def commit_hash():
|
def commit_hash():
|
||||||
global stored_commit_hash
|
|
||||||
|
|
||||||
if stored_commit_hash is not None:
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
|
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
stored_commit_hash = "<none>"
|
return "<none>"
|
||||||
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def git_tag():
|
def git_tag():
|
||||||
global stored_git_tag
|
|
||||||
|
|
||||||
if stored_git_tag is not None:
|
|
||||||
return stored_git_tag
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stored_git_tag = run(f"{git} describe --tags").strip()
|
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
stored_git_tag = "<none>"
|
return "<none>"
|
||||||
|
|
||||||
return stored_git_tag
|
|
||||||
|
|
||||||
|
|
||||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
print(desc)
|
print(desc)
|
||||||
|
|
||||||
if live:
|
run_kwargs = {
|
||||||
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
"args": command,
|
||||||
if result.returncode != 0:
|
"shell": True,
|
||||||
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
"env": os.environ if custom_env is None else custom_env,
|
||||||
Command: {command}
|
"encoding": 'utf8',
|
||||||
Error code: {result.returncode}""")
|
"errors": 'ignore',
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
if not live:
|
||||||
|
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||||
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
result = subprocess.run(**run_kwargs)
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
|
error_bits = [
|
||||||
|
f"{errdesc or 'Error running command'}.",
|
||||||
|
f"Command: {command}",
|
||||||
|
f"Error code: {result.returncode}",
|
||||||
|
]
|
||||||
|
if result.stdout:
|
||||||
|
error_bits.append(f"stdout: {result.stdout}")
|
||||||
|
if result.stderr:
|
||||||
|
error_bits.append(f"stderr: {result.stderr}")
|
||||||
|
raise RuntimeError("\n".join(error_bits))
|
||||||
|
|
||||||
message = f"""{errdesc or 'Error running command'}.
|
return (result.stdout or "")
|
||||||
Command: {command}
|
|
||||||
Error code: {result.returncode}
|
|
||||||
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
|
||||||
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
|
||||||
"""
|
|
||||||
raise RuntimeError(message)
|
|
||||||
|
|
||||||
return result.stdout.decode(encoding="utf8", errors="ignore")
|
|
||||||
|
|
||||||
|
|
||||||
def check_run(command):
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
|
||||||
return result.returncode == 0
|
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
@ -131,11 +116,7 @@ def repo_dir(name):
|
|||||||
return os.path.join(script_path, dir_repos, name)
|
return os.path.join(script_path, dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
def run_python(code, desc=None, errdesc=None):
|
def run_pip(command, desc=None, live=default_command_live):
|
||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
|
||||||
|
|
||||||
|
|
||||||
def run_pip(command, desc=None, live=False):
|
|
||||||
if args.skip_install:
|
if args.skip_install:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -143,8 +124,9 @@ def run_pip(command, desc=None, live=False):
|
|||||||
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code: str) -> bool:
|
||||||
return check_run(f'"{python}" -c "{code}"')
|
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||||
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
@ -237,13 +219,14 @@ def run_extensions_installers(settings_file):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||||
@ -270,8 +253,11 @@ def prepare_environment():
|
|||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test:
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
raise RuntimeError(
|
||||||
|
'Torch is not able to use GPU; '
|
||||||
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
|
)
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
@ -319,7 +305,7 @@ def prepare_environment():
|
|||||||
|
|
||||||
if args.update_all_extensions:
|
if args.update_all_extensions:
|
||||||
git_pull_recursive(extensions_dir)
|
git_pull_recursive(extensions_dir)
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
BIN
modules/Roboto-Regular.ttf
Normal file
BIN
modules/Roboto-Regular.ttf
Normal file
Binary file not shown.
@ -15,7 +15,8 @@ from secrets import compare_digest
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api.models import *
|
from modules.api import models
|
||||||
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
@ -25,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
|||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import Dict, List, Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||||
|
|
||||||
|
|
||||||
def validate_sampler_name(name):
|
def validate_sampler_name(name):
|
||||||
config = sd_samplers.all_samplers_map.get(name, None)
|
config = sd_samplers.all_samplers_map.get(name, None)
|
||||||
@ -48,20 +52,23 @@ def validate_sampler_name(name):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||||
return image
|
return image
|
||||||
except Exception as err:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
@ -92,6 +99,7 @@ def encode_pil_to_base64(image):
|
|||||||
|
|
||||||
return base64.b64encode(bytes_data)
|
return base64.b64encode(bytes_data)
|
||||||
|
|
||||||
|
|
||||||
def api_middleware(app: FastAPI):
|
def api_middleware(app: FastAPI):
|
||||||
rich_available = True
|
rich_available = True
|
||||||
try:
|
try:
|
||||||
@ -99,7 +107,7 @@ def api_middleware(app: FastAPI):
|
|||||||
import starlette # importing just so it can be placed on silent list
|
import starlette # importing just so it can be placed on silent list
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@ -157,7 +165,7 @@ def api_middleware(app: FastAPI):
|
|||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
self.credentials = dict()
|
self.credentials = {}
|
||||||
for auth in shared.cmd_opts.api_auth.split(","):
|
for auth in shared.cmd_opts.api_auth.split(","):
|
||||||
user, password = auth.split(":")
|
user, password = auth.split(":")
|
||||||
self.credentials[user] = password
|
self.credentials[user] = password
|
||||||
@ -166,36 +174,37 @@ class Api:
|
|||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
api_middleware(self.app)
|
api_middleware(self.app)
|
||||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
@ -219,17 +228,25 @@ class Api:
|
|||||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||||
script = script_runner.selectable_scripts[script_idx]
|
script = script_runner.selectable_scripts[script_idx]
|
||||||
return script, script_idx
|
return script, script_idx
|
||||||
|
|
||||||
def get_scripts_list(self):
|
|
||||||
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
|
||||||
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
|
||||||
|
|
||||||
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
def get_scripts_list(self):
|
||||||
|
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
||||||
|
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
||||||
|
|
||||||
|
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||||
|
|
||||||
|
def get_script_info(self):
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
||||||
|
res += [script.api_info for script in script_list if script.api_info is not None]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def get_script(self, script_name, script_runner):
|
def get_script(self, script_name, script_runner):
|
||||||
if script_name is None or script_name == "":
|
if script_name is None or script_name == "":
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||||
return script_runner.scripts[script_idx]
|
return script_runner.scripts[script_idx]
|
||||||
|
|
||||||
@ -264,11 +281,11 @@ class Api:
|
|||||||
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
||||||
for alwayson_script_name in request.alwayson_scripts.keys():
|
for alwayson_script_name in request.alwayson_scripts.keys():
|
||||||
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
||||||
if alwayson_script == None:
|
if alwayson_script is None:
|
||||||
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
||||||
# Selectable script in always on script param check
|
# Selectable script in always on script param check
|
||||||
if alwayson_script.alwayson == False:
|
if alwayson_script.alwayson is False:
|
||||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
||||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||||
# min between arg length in scriptrunner and arg length in the request
|
# min between arg length in scriptrunner and arg length in the request
|
||||||
@ -276,7 +293,7 @@ class Api:
|
|||||||
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||||
script_runner = scripts.scripts_txt2img
|
script_runner = scripts.scripts_txt2img
|
||||||
if not script_runner.scripts:
|
if not script_runner.scripts:
|
||||||
script_runner.initialize_scripts(False)
|
script_runner.initialize_scripts(False)
|
||||||
@ -310,7 +327,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -320,9 +337,9 @@ class Api:
|
|||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
@ -367,7 +384,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -381,9 +398,9 @@ class Api:
|
|||||||
img2imgreq.init_images = None
|
img2imgreq.init_images = None
|
||||||
img2imgreq.mask = None
|
img2imgreq.mask = None
|
||||||
|
|
||||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||||
|
|
||||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||||
@ -391,9 +408,9 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||||
|
|
||||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
image_list = reqDict.pop('imageList', [])
|
image_list = reqDict.pop('imageList', [])
|
||||||
@ -402,15 +419,15 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
if(not req.image.strip()):
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
geninfo, items = images.read_info_from_image(image)
|
geninfo, items = images.read_info_from_image(image)
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
@ -418,13 +435,13 @@ class Api:
|
|||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
items = {**{'parameters': geninfo}, **items}
|
||||||
|
|
||||||
return PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||||
|
|
||||||
def progressapi(self, req: ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
# avoid dividing zero
|
# avoid dividing zero
|
||||||
progress = 0.01
|
progress = 0.01
|
||||||
@ -446,9 +463,9 @@ class Api:
|
|||||||
if shared.state.current_image and not req.skip_current_image:
|
if shared.state.current_image and not req.skip_current_image:
|
||||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
if image_b64 is None:
|
if image_b64 is None:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
@ -465,7 +482,7 @@ class Api:
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
|
||||||
return InterrogateResponse(caption=processed)
|
return models.InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
@ -570,36 +587,36 @@ class Api:
|
|||||||
filename = create_embedding(**args) # create empty embedding
|
filename = create_embedding(**args) # create empty embedding
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info=f"create embedding filename: {filename}")
|
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"create embedding error: {e}")
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info=f"create hypernetwork filename: {filename}")
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"create hypernetwork error: {e}")
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info = 'preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f"preprocess error: {e}")
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info=f'preprocess error: {e}')
|
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -617,10 +634,10 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError as msg:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding error: {msg}")
|
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -641,14 +658,15 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info=f"train embedding error: {error}")
|
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
import os, psutil
|
import os
|
||||||
|
import psutil
|
||||||
process = psutil.Process(os.getpid())
|
process = psutil.Process(os.getpid())
|
||||||
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
||||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||||
@ -675,10 +693,10 @@ class Api:
|
|||||||
'events': warnings,
|
'events': warnings,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
cuda = { 'error': 'unavailable' }
|
cuda = {'error': 'unavailable'}
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
cuda = { 'error': f'{err}' }
|
cuda = {'error': f'{err}'}
|
||||||
return MemoryResponse(ram = ram, cuda = cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
|
@ -223,8 +223,9 @@ for key in _options:
|
|||||||
if(_options[key].dest != 'help'):
|
if(_options[key].dest != 'help'):
|
||||||
flag = _options[key]
|
flag = _options[key]
|
||||||
_type = str
|
_type = str
|
||||||
if _options[key].default is not None: _type = type(_options[key].default)
|
if _options[key].default is not None:
|
||||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
_type = type(_options[key].default)
|
||||||
|
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||||
|
|
||||||
FlagsModel = create_model("Flags", **flags)
|
FlagsModel = create_model("Flags", **flags)
|
||||||
|
|
||||||
@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
|
|||||||
ram: dict = Field(title="RAM", description="System memory stats")
|
ram: dict = Field(title="RAM", description="System memory stats")
|
||||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||||
|
|
||||||
|
|
||||||
class ScriptsList(BaseModel):
|
class ScriptsList(BaseModel):
|
||||||
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
|
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||||
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
|
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptArg(BaseModel):
|
||||||
|
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
|
||||||
|
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
|
||||||
|
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||||
|
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||||
|
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||||
|
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptInfo(BaseModel):
|
||||||
|
name: str = Field(default=None, title="Name", description="Script name")
|
||||||
|
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||||
|
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||||
|
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@ -103,4 +103,5 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
|
|||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
|
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
from modules.codeformer.vqgan_arch import *
|
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
def calc_mean_std(feat, eps=1e-5):
|
||||||
@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
|
|||||||
tgt_mask: Optional[Tensor] = None,
|
tgt_mask: Optional[Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
query_pos: Optional[Tensor] = None):
|
query_pos: Optional[Tensor] = None):
|
||||||
|
|
||||||
# self attention
|
# self attention
|
||||||
tgt2 = self.norm1(tgt)
|
tgt2 = self.norm1(tgt)
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||||
@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
|
|||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class CodeFormer(VQAutoEncoder):
|
class CodeFormer(VQAutoEncoder):
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||||
codebook_size=1024, latent_size=256,
|
codebook_size=1024, latent_size=256,
|
||||||
connect_list=['32', '64', '128', '256'],
|
connect_list=('32', '64', '128', '256'),
|
||||||
fix_modules=['quantize','generator']):
|
fix_modules=('quantize', 'generator')):
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||||
|
|
||||||
if fix_modules is not None:
|
if fix_modules is not None:
|
||||||
@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||||
for _ in range(self.n_layers)])
|
for _ in range(self.n_layers)])
|
||||||
|
|
||||||
# logits_predict head
|
# logits_predict head
|
||||||
self.idx_pred_layer = nn.Sequential(
|
self.idx_pred_layer = nn.Sequential(
|
||||||
nn.LayerNorm(dim_embd),
|
nn.LayerNorm(dim_embd),
|
||||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||||
|
|
||||||
self.channels = {
|
self.channels = {
|
||||||
'16': 512,
|
'16': 512,
|
||||||
'32': 256,
|
'32': 256,
|
||||||
@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
enc_feat_dict = {}
|
enc_feat_dict = {}
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
for i, block in enumerate(self.encoder.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in out_list:
|
if i in out_list:
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||||
|
|
||||||
@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
for i, block in enumerate(self.generator.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in fuse_list: # fuse after i-th block
|
if i in fuse_list: # fuse after i-th block
|
||||||
f_size = str(x.shape[-1])
|
f_size = str(x.shape[-1])
|
||||||
if w>0:
|
if w>0:
|
||||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||||
out = x
|
out = x
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
# logits doesn't need softmax before cross_entropy loss
|
||||||
return out, logits, lq_feat
|
return out, logits, lq_feat
|
||||||
|
@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
|
|||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||||
|
|
||||||
'''
|
'''
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import copy
|
|
||||||
from basicsr.utils import get_root_logger
|
from basicsr.utils import get_root_logger
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def normalize(in_channels):
|
def normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def swish(x):
|
def swish(x):
|
||||||
@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q = q.reshape(b, c, h*w)
|
q = q.reshape(b, c, h*w)
|
||||||
q = q.permute(0, 2, 1)
|
q = q.permute(0, 2, 1)
|
||||||
k = k.reshape(b, c, h*w)
|
k = k.reshape(b, c, h*w)
|
||||||
w_ = torch.bmm(q, k)
|
w_ = torch.bmm(q, k)
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
w_ = F.softmax(w_, dim=2)
|
w_ = F.softmax(w_, dim=2)
|
||||||
|
|
||||||
# attend to values
|
# attend to values
|
||||||
v = v.reshape(b, c, h*w)
|
v = v.reshape(b, c, h*w)
|
||||||
w_ = w_.permute(0, 2, 1)
|
w_ = w_.permute(0, 2, 1)
|
||||||
h_ = torch.bmm(v, w_)
|
h_ = torch.bmm(v, w_)
|
||||||
h_ = h_.reshape(b, c, h, w)
|
h_ = h_.reshape(b, c, h, w)
|
||||||
|
|
||||||
@ -272,18 +270,18 @@ class Encoder(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.num_resolutions = len(self.ch_mult)
|
self.num_resolutions = len(self.ch_mult)
|
||||||
self.num_res_blocks = res_blocks
|
self.num_res_blocks = res_blocks
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions
|
||||||
self.in_channels = emb_dim
|
self.in_channels = emb_dim
|
||||||
self.out_channels = 3
|
self.out_channels = 3
|
||||||
@ -317,29 +315,29 @@ class Generator(nn.Module):
|
|||||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class VQAutoEncoder(nn.Module):
|
class VQAutoEncoder(nn.Module):
|
||||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
self.in_channels = 3
|
self.in_channels = 3
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.n_blocks = res_blocks
|
self.n_blocks = res_blocks
|
||||||
self.codebook_size = codebook_size
|
self.codebook_size = codebook_size
|
||||||
self.embed_dim = emb_dim
|
self.embed_dim = emb_dim
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions or [16]
|
||||||
self.quantizer_type = quantizer
|
self.quantizer_type = quantizer
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
|
|||||||
self.kl_weight
|
self.kl_weight
|
||||||
)
|
)
|
||||||
self.generator = Generator(
|
self.generator = Generator(
|
||||||
self.nf,
|
self.nf,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
self.ch_mult,
|
self.ch_mult,
|
||||||
self.n_blocks,
|
self.n_blocks,
|
||||||
self.resolution,
|
self.resolution,
|
||||||
self.attn_resolutions
|
self.attn_resolutions
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
|
|||||||
raise ValueError('Wrong params!')
|
raise ValueError('Wrong params!')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.main(x)
|
return self.main(x)
|
||||||
|
@ -33,11 +33,9 @@ def setup_model(dirname):
|
|||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
from modules.codeformer.codeformer_arch import CodeFormer
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils import img2tensor, tensor2img
|
||||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
|
||||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
from facelib.detection.retinaface import retinaface
|
from facelib.detection.retinaface import retinaface
|
||||||
from modules.shared import cmd_opts
|
|
||||||
|
|
||||||
net_class = CodeFormer
|
net_class = CodeFormer
|
||||||
|
|
||||||
@ -96,7 +94,7 @@ def setup_model(dirname):
|
|||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||||
self.face_helper.align_warp_face()
|
self.face_helper.align_warp_face()
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
for cropped_face in self.face_helper.cropped_faces:
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||||
|
@ -14,7 +14,7 @@ from collections import OrderedDict
|
|||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared, extensions
|
from modules import shared, extensions
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
all_config_states = OrderedDict()
|
all_config_states = OrderedDict()
|
||||||
@ -35,7 +35,7 @@ def list_config_states():
|
|||||||
j["filepath"] = path
|
j["filepath"] = path
|
||||||
config_states.append(j)
|
config_states.append(j)
|
||||||
|
|
||||||
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
for cs in config_states:
|
for cs in config_states:
|
||||||
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||||
@ -83,6 +83,8 @@ def get_extension_config():
|
|||||||
ext_config = {}
|
ext_config = {}
|
||||||
|
|
||||||
for ext in extensions.extensions:
|
for ext in extensions.extensions:
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
|
||||||
entry = {
|
entry = {
|
||||||
"name": ext.name,
|
"name": ext.name,
|
||||||
"path": ext.path,
|
"path": ext.path,
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||||
@ -79,7 +78,7 @@ class DeepDanbooru:
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||||
|
|
||||||
for tag in [x for x in tags if x not in filtertags]:
|
for tag in [x for x in tags if x not in filtertags]:
|
||||||
probability = probability_dict[tag]
|
probability = probability_dict[tag]
|
||||||
|
@ -65,7 +65,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -6,7 +6,7 @@ from PIL import Image
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import shared, modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@ -16,9 +16,7 @@ def mod2normal(state_dict):
|
|||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
if 'conv_first.weight' in state_dict:
|
if 'conv_first.weight' in state_dict:
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
|
|||||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||||
re8x = 0
|
re8x = 0
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import math
|
import math
|
||||||
import functools
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
Modified options that can be used:
|
Modified options that can be used:
|
||||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||||
- "Spectral normalization" arXiv:1802.05957
|
- "Spectral normalization" arXiv:1802.05957
|
||||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||||
{Rakotonirina} and A. {Rasoanaivo}
|
{Rakotonirina} and A. {Rasoanaivo}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
|
|||||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||||
x = x + sampled_noise
|
x = x + sampled_noise
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
@ -438,9 +437,11 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
|||||||
padding = padding if pad_type == 'zero' else 0
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
if convtype=='PartialConv2D':
|
if convtype=='PartialConv2D':
|
||||||
|
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='DeformConv2D':
|
elif convtype=='DeformConv2D':
|
||||||
|
from torchvision.ops import DeformConv2d # not tested
|
||||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='Conv3D':
|
elif convtype=='Conv3D':
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
|
|
||||||
@ -25,6 +24,8 @@ def active():
|
|||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.path = path
|
self.path = path
|
||||||
@ -43,8 +44,13 @@ class Extension:
|
|||||||
if self.is_builtin or self.have_info_from_repo:
|
if self.is_builtin or self.have_info_from_repo:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.have_info_from_repo = True
|
with self.lock:
|
||||||
|
if self.have_info_from_repo:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.do_read_info_from_repo()
|
||||||
|
|
||||||
|
def do_read_info_from_repo(self):
|
||||||
repo = None
|
repo = None
|
||||||
try:
|
try:
|
||||||
if os.path.exists(os.path.join(self.path, ".git")):
|
if os.path.exists(os.path.join(self.path, ".git")):
|
||||||
@ -59,18 +65,18 @@ class Extension:
|
|||||||
try:
|
try:
|
||||||
self.status = 'unknown'
|
self.status = 'unknown'
|
||||||
self.remote = next(repo.remote().urls, None)
|
self.remote = next(repo.remote().urls, None)
|
||||||
head = repo.head.commit
|
|
||||||
self.commit_date = repo.head.commit.committed_date
|
self.commit_date = repo.head.commit.committed_date
|
||||||
ts = time.asctime(time.gmtime(self.commit_date))
|
|
||||||
if repo.active_branch:
|
if repo.active_branch:
|
||||||
self.branch = repo.active_branch.name
|
self.branch = repo.active_branch.name
|
||||||
self.commit_hash = head.hexsha
|
self.commit_hash = repo.head.commit.hexsha
|
||||||
self.version = f'{self.commit_hash[:8]} ({ts})'
|
self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
||||||
self.remote = None
|
self.remote = None
|
||||||
|
|
||||||
|
self.have_info_from_repo = True
|
||||||
|
|
||||||
def list_files(self, subdir, extension):
|
def list_files(self, subdir, extension):
|
||||||
from modules import scripts
|
from modules import scripts
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
|
|||||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
deactivate for all remaining registered networks"""
|
deactivate for all remaining registered networks"""
|
||||||
|
|
||||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
for extra_network_name in extra_network_data:
|
||||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
if extra_network is None:
|
if extra_network is None:
|
||||||
continue
|
continue
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from modules import extra_networks, shared, extra_networks
|
from modules import extra_networks, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_instruct_pix2pix_model = False
|
result_is_instruct_pix2pix_model = False
|
||||||
|
|
||||||
if theta_func2:
|
if theta_func2:
|
||||||
shared.state.textinfo = f"Loading B"
|
shared.state.textinfo = "Loading B"
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
else:
|
else:
|
||||||
theta_1 = None
|
theta_1 = None
|
||||||
|
|
||||||
if theta_func1:
|
if theta_func1:
|
||||||
shared.state.textinfo = f"Loading C"
|
shared.state.textinfo = "Loading C"
|
||||||
print(f"Loading {tertiary_model_info.filename}...")
|
print(f"Loading {tertiary_model_info.filename}...")
|
||||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
|
|
||||||
shared.state.sampling_step += 1
|
shared.state.sampling_step += 1
|
||||||
@ -242,9 +242,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
shared.state.textinfo = "Saving"
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
|
metadata = None
|
||||||
|
|
||||||
if save_metadata:
|
if save_metadata:
|
||||||
|
metadata = {"format": "pt"}
|
||||||
|
|
||||||
merge_recipe = {
|
merge_recipe = {
|
||||||
"type": "webui", # indicate this model was merged with webui's built-in merger
|
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||||
"primary_model_hash": primary_model_info.sha256,
|
"primary_model_hash": primary_model_info.sha256,
|
||||||
@ -262,15 +264,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
}
|
}
|
||||||
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||||
|
|
||||||
|
sd_merge_models = {}
|
||||||
|
|
||||||
def add_model_metadata(checkpoint_info):
|
def add_model_metadata(checkpoint_info):
|
||||||
checkpoint_info.calculate_shorthash()
|
checkpoint_info.calculate_shorthash()
|
||||||
metadata["sd_merge_models"][checkpoint_info.sha256] = {
|
sd_merge_models[checkpoint_info.sha256] = {
|
||||||
"name": checkpoint_info.name,
|
"name": checkpoint_info.name,
|
||||||
"legacy_hash": checkpoint_info.hash,
|
"legacy_hash": checkpoint_info.hash,
|
||||||
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
||||||
|
|
||||||
add_model_metadata(primary_model_info)
|
add_model_metadata(primary_model_info)
|
||||||
if secondary_model_info:
|
if secondary_model_info:
|
||||||
@ -278,7 +282,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
if tertiary_model_info:
|
if tertiary_model_info:
|
||||||
add_model_metadata(tertiary_model_info)
|
add_model_metadata(tertiary_model_info)
|
||||||
|
|
||||||
metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
|
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
import base64
|
import base64
|
||||||
import html
|
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks
|
||||||
import tempfile
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||||
@ -23,14 +19,14 @@ registered_param_bindings = []
|
|||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||||
self.paste_button = paste_button
|
self.paste_button = paste_button
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
self.source_text_component = source_text_component
|
self.source_text_component = source_text_component
|
||||||
self.source_image_component = source_image_component
|
self.source_image_component = source_image_component
|
||||||
self.source_tabname = source_tabname
|
self.source_tabname = source_tabname
|
||||||
self.override_settings_component = override_settings_component
|
self.override_settings_component = override_settings_component
|
||||||
self.paste_field_names = paste_field_names
|
self.paste_field_names = paste_field_names or []
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
@ -251,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
lines.append(lastline)
|
lines.append(lastline)
|
||||||
lastline = ''
|
lastline = ''
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line.startswith("Negative prompt:"):
|
if line.startswith("Negative prompt:"):
|
||||||
done_with_prompt = True
|
done_with_prompt = True
|
||||||
@ -312,6 +308,8 @@ infotext_to_setting_name_mapping = [
|
|||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
('UniPC order', 'uni_pc_order'),
|
('UniPC order', 'uni_pc_order'),
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
('Token merging ratio', 'token_merging_ratio'),
|
||||||
|
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||||
('RNG', 'randn_source'),
|
('RNG', 'randn_source'),
|
||||||
('NGMS', 's_min_uncond'),
|
('NGMS', 's_min_uncond'),
|
||||||
]
|
]
|
||||||
|
@ -78,7 +78,7 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
from facexlib import detection, parsing
|
from facexlib import detection, parsing # noqa: F401
|
||||||
global user_path
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import csv
|
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|||||||
from torch import einsum
|
from torch import einsum
|
||||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
|
|
||||||
@ -178,34 +177,34 @@ class Hypernetwork:
|
|||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
res += layer.parameters()
|
res += layer.parameters()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train(mode=mode)
|
layer.train(mode=mode)
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = mode
|
param.requires_grad = mode
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.to(device)
|
layer.to(device)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.multiplier = multiplier
|
layer.multiplier = multiplier
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.eval()
|
layer.eval()
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||||||
k = self.to_k(context_k)
|
k = self.to_k(context_k)
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
gradient_step = ds.gradient_step
|
gradient_step = ds.gradient_step
|
||||||
# n steps = batch_size * gradient_step * n image processed
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
try:
|
try:
|
||||||
sd_hijack_checkpoint.add()
|
sd_hijack_checkpoint.add()
|
||||||
|
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for _ in range((steps-initial_step) * gradient_step):
|
||||||
if scheduler.finished:
|
if scheduler.finished:
|
||||||
break
|
break
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched.step(hypernetwork.step)
|
clip_grad_sched.step(hypernetwork.step)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if use_weight:
|
if use_weight:
|
||||||
@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
loss_logging.append(_loss_step)
|
loss_logging.append(_loss_step)
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
hypernetwork.step += 1
|
hypernetwork.step += 1
|
||||||
@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
_loss_step = 0
|
_loss_step = 0
|
||||||
|
|
||||||
steps_done = hypernetwork.step + 1
|
steps_done = hypernetwork.step + 1
|
||||||
|
|
||||||
epoch_num = hypernetwork.step // steps_per_epoch
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
epoch_step = hypernetwork.step % steps_per_epoch
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
|
@ -1,19 +1,17 @@
|
|||||||
import html
|
import html
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
|
|
||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||||
|
|
||||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
@ -13,17 +13,24 @@ import numpy as np
|
|||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||||
from fonts.ttf import Roboto
|
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from modules import sd_samplers, shared, script_callbacks, errors
|
from modules import sd_samplers, shared, script_callbacks, errors
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.paths_internal import roboto_ttf_file
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_font(fontsize: int):
|
||||||
|
try:
|
||||||
|
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
||||||
|
except Exception:
|
||||||
|
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size=1, rows=None):
|
def image_grid(imgs, batch_size=1, rows=None):
|
||||||
if rows is None:
|
if rows is None:
|
||||||
if opts.n_rows > 0:
|
if opts.n_rows > 0:
|
||||||
@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
lines.append(word)
|
lines.append(word)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
def get_font(fontsize):
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
|
||||||
except Exception:
|
|
||||||
return ImageFont.truetype(Roboto, fontsize)
|
|
||||||
|
|
||||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
fnt = initial_fnt
|
fnt = initial_fnt
|
||||||
fontsize = initial_fontsize
|
fontsize = initial_fontsize
|
||||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||||
@ -366,7 +367,7 @@ class FilenameGenerator:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
def hasprompt(self, *args):
|
def hasprompt(self, *args):
|
||||||
lower = self.prompt.lower()
|
lower = self.prompt.lower()
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
@ -409,13 +410,13 @@ class FilenameGenerator:
|
|||||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||||
try:
|
try:
|
||||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
except pytz.exceptions.UnknownTimeZoneError:
|
||||||
time_zone = None
|
time_zone = None
|
||||||
|
|
||||||
time_zone_time = time_datetime.astimezone(time_zone)
|
time_zone_time = time_datetime.astimezone(time_zone)
|
||||||
try:
|
try:
|
||||||
formatted_time = time_zone_time.strftime(time_format)
|
formatted_time = time_zone_time.strftime(time_format)
|
||||||
except (ValueError, TypeError) as _:
|
except (ValueError, TypeError):
|
||||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||||
|
|
||||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||||
@ -472,15 +473,52 @@ def get_next_sequence_number(path, basename):
|
|||||||
prefix_length = len(basename)
|
prefix_length = len(basename)
|
||||||
for p in os.listdir(path):
|
for p in os.listdir(path):
|
||||||
if p.startswith(basename):
|
if p.startswith(basename):
|
||||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||||
try:
|
try:
|
||||||
result = max(int(l[0]), result)
|
result = max(int(parts[0]), result)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return result + 1
|
return result + 1
|
||||||
|
|
||||||
|
|
||||||
|
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
|
||||||
|
if extension is None:
|
||||||
|
extension = os.path.splitext(filename)[1]
|
||||||
|
|
||||||
|
image_format = Image.registered_extensions()[extension]
|
||||||
|
|
||||||
|
existing_pnginfo = existing_pnginfo or {}
|
||||||
|
if opts.enable_pnginfo:
|
||||||
|
existing_pnginfo['parameters'] = geninfo
|
||||||
|
|
||||||
|
if extension.lower() == '.png':
|
||||||
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
|
for k, v in (existing_pnginfo or {}).items():
|
||||||
|
pnginfo_data.add_text(k, str(v))
|
||||||
|
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||||
|
|
||||||
|
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
image = image.convert("RGB")
|
||||||
|
elif image.mode == 'I;16':
|
||||||
|
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||||
|
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||||
|
|
||||||
|
if opts.enable_pnginfo and geninfo is not None:
|
||||||
|
exif_bytes = piexif.dump({
|
||||||
|
"Exif": {
|
||||||
|
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
piexif.insert(exif_bytes, filename)
|
||||||
|
else:
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
||||||
|
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||||
"""Save an image.
|
"""Save an image.
|
||||||
|
|
||||||
@ -565,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||||
|
|
||||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||||
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
"""
|
||||||
|
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||||
|
"""
|
||||||
temp_file_path = f"{filename_without_extension}.tmp"
|
temp_file_path = f"{filename_without_extension}.tmp"
|
||||||
image_format = Image.registered_extensions()[extension]
|
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
|
||||||
if opts.enable_pnginfo:
|
|
||||||
for k, v in params.pnginfo.items():
|
|
||||||
pnginfo_data.add_text(k, str(v))
|
|
||||||
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
|
||||||
|
|
||||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
|
||||||
if image_to_save.mode == 'RGBA':
|
|
||||||
image_to_save = image_to_save.convert("RGB")
|
|
||||||
elif image_to_save.mode == 'I;16':
|
|
||||||
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
|
||||||
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
|
||||||
|
|
||||||
if opts.enable_pnginfo and info is not None:
|
|
||||||
exif_bytes = piexif.dump({
|
|
||||||
"Exif": {
|
|
||||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
piexif.insert(exif_bytes, temp_file_path)
|
|
||||||
else:
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
|
||||||
|
|
||||||
# atomically rename the file with correct extension
|
|
||||||
os.replace(temp_file_path, filename_without_extension + extension)
|
os.replace(temp_file_path, filename_without_extension + extension)
|
||||||
|
|
||||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||||
|
@ -1,19 +1,15 @@
|
|||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.images as images
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
|
|
||||||
@ -59,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
# try to find corresponding mask for an image using simple filename matching
|
# try to find corresponding mask for an image using simple filename matching
|
||||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||||
# if not found use first one ("same mask for all images" use-case)
|
# if not found use first one ("same mask for all images" use-case)
|
||||||
if not mask_image_path in inpaint_masks:
|
if mask_image_path not in inpaint_masks:
|
||||||
mask_image_path = inpaint_masks[0]
|
mask_image_path = inpaint_masks[0]
|
||||||
mask_image = Image.open(mask_image_path)
|
mask_image = Image.open(mask_image_path)
|
||||||
p.image_mask = mask_image
|
p.image_mask = mask_image
|
||||||
|
@ -11,7 +11,6 @@ import torch.hub
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
@ -160,7 +159,7 @@ class InterrogateModels:
|
|||||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||||
|
|
||||||
top_count = min(top_count, len(text_array))
|
top_count = min(top_count, len(text_array))
|
||||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
||||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
@ -208,8 +207,8 @@ class InterrogateModels:
|
|||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
for name, topn, items in self.categories():
|
for cat in self.categories():
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
if shared.opts.interrogate_return_ranks:
|
if shared.opts.interrogate_return_ranks:
|
||||||
res += f", ({match}:{score/100:.3f})"
|
res += f", ({match}:{score/100:.3f})"
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import platform
|
import platform
|
||||||
from modules import paths
|
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@ -43,7 +42,7 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||||
@ -61,4 +60,4 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||||
if platform.processor() == 'i386':
|
if platform.processor() == 'i386':
|
||||||
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||||
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||||
|
@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
def get_crop_region(mask, pad=0):
|
def get_crop_region(mask, pad=0):
|
||||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||||
|
|
||||||
h, w = mask.shape
|
h, w = mask.shape
|
||||||
|
|
||||||
crop_left = 0
|
crop_left = 0
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
@ -40,7 +39,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||||
print(f"Skipping broken symlink: {full_path}")
|
print(f"Skipping broken symlink: {full_path}")
|
||||||
continue
|
continue
|
||||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||||
continue
|
continue
|
||||||
if full_path not in output:
|
if full_path not in output:
|
||||||
output.append(full_path)
|
output.append(full_path)
|
||||||
@ -108,12 +107,12 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
try:
|
try:
|
||||||
shutil.move(fullpath, dest_path)
|
shutil.move(fullpath, dest_path)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if len(os.listdir(src_path)) == 0:
|
if len(os.listdir(src_path)) == 0:
|
||||||
print(f"Removing empty folder: {src_path}")
|
print(f"Removing empty folder: {src_path}")
|
||||||
shutil.rmtree(src_path, True)
|
shutil.rmtree(src_path, True)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -127,7 +126,7 @@ def load_upscalers():
|
|||||||
full_model = f"modules.{model_name}_model"
|
full_model = f"modules.{model_name}_model"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(full_model)
|
importlib.import_module(full_model)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
|
@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||||
if self.use_ema and not load_ema:
|
if self.use_ema and not load_ema:
|
||||||
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
|
ignore_keys = ignore_keys or []
|
||||||
|
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
use_ddim = False
|
use_ddim = False
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .sampler import UniPCSampler
|
from .sampler import UniPCSampler # noqa: F401
|
||||||
|
@ -54,7 +54,8 @@ class UniPCSampler(object):
|
|||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
cbs = ctmp.shape[0]
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
import math
|
||||||
from tqdm.auto import trange
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -179,13 +178,13 @@ def model_wrapper(
|
|||||||
model,
|
model,
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs=None,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
#condition=None,
|
#condition=None,
|
||||||
#unconditional_condition=None,
|
#unconditional_condition=None,
|
||||||
guidance_scale=1.,
|
guidance_scale=1.,
|
||||||
classifier_fn=None,
|
classifier_fn=None,
|
||||||
classifier_kwargs={},
|
classifier_kwargs=None,
|
||||||
):
|
):
|
||||||
"""Create a wrapper function for the noise prediction model.
|
"""Create a wrapper function for the noise prediction model.
|
||||||
|
|
||||||
@ -276,6 +275,9 @@ def model_wrapper(
|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or {}
|
||||||
|
classifier_kwargs = classifier_kwargs or {}
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||||
@ -342,7 +344,7 @@ def model_wrapper(
|
|||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
assert isinstance(unconditional_condition, dict)
|
assert isinstance(unconditional_condition, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in condition:
|
for k in condition:
|
||||||
if isinstance(condition[k], list):
|
if isinstance(condition[k], list):
|
||||||
c_in[k] = [torch.cat([
|
c_in[k] = [torch.cat([
|
||||||
@ -353,7 +355,7 @@ def model_wrapper(
|
|||||||
unconditional_condition[k],
|
unconditional_condition[k],
|
||||||
condition[k]])
|
condition[k]])
|
||||||
elif isinstance(condition, list):
|
elif isinstance(condition, list):
|
||||||
c_in = list()
|
c_in = []
|
||||||
assert isinstance(unconditional_condition, list)
|
assert isinstance(unconditional_condition, list)
|
||||||
for i in range(len(condition)):
|
for i in range(len(condition)):
|
||||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||||
@ -757,40 +759,44 @@ class UniPC:
|
|||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
t_prev_list = [vec_t]
|
t_prev_list = [vec_t]
|
||||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
with tqdm.tqdm(total=steps) as pbar:
|
||||||
for init_order in range(1, order):
|
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
for init_order in range(1, order):
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||||
if model_x is None:
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
model_x = self.model_fn(x, vec_t)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
model_prev_list.append(model_x)
|
|
||||||
t_prev_list.append(vec_t)
|
|
||||||
for step in trange(order, steps + 1):
|
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
|
||||||
if lower_order_final:
|
|
||||||
step_order = min(order, steps + 1 - step)
|
|
||||||
else:
|
|
||||||
step_order = order
|
|
||||||
#print('this step order:', step_order)
|
|
||||||
if step == steps:
|
|
||||||
#print('do not run corrector at the last step')
|
|
||||||
use_corrector = False
|
|
||||||
else:
|
|
||||||
use_corrector = True
|
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
for i in range(order - 1):
|
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
|
||||||
t_prev_list[-1] = vec_t
|
|
||||||
# We do not need to evaluate the final model value.
|
|
||||||
if step < steps:
|
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
model_prev_list.append(model_x)
|
||||||
|
t_prev_list.append(vec_t)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
for step in range(order, steps + 1):
|
||||||
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
|
if lower_order_final:
|
||||||
|
step_order = min(order, steps + 1 - step)
|
||||||
|
else:
|
||||||
|
step_order = order
|
||||||
|
#print('this step order:', step_order)
|
||||||
|
if step == steps:
|
||||||
|
#print('do not run corrector at the last step')
|
||||||
|
use_corrector = False
|
||||||
|
else:
|
||||||
|
use_corrector = True
|
||||||
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
for i in range(order - 1):
|
||||||
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
t_prev_list[-1] = vec_t
|
||||||
|
# We do not need to evaluate the final model value.
|
||||||
|
if step < steps:
|
||||||
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
|
pbar.update()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
@ -19,6 +19,7 @@ def connect(token, port, options):
|
|||||||
if not options.get('session_metadata'):
|
if not options.get('session_metadata'):
|
||||||
options['session_metadata'] = 'stable-diffusion-webui'
|
options['session_metadata'] = 'stable-diffusion-webui'
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
|
public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
|
||||||
import modules.safe
|
import modules.safe # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
# data_path = cmd_opts_pre.data
|
# data_path = cmd_opts_pre.data
|
||||||
|
@ -2,8 +2,14 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import shlex
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
|
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
script_path = os.path.dirname(modules_path)
|
||||||
|
|
||||||
sd_configs_path = os.path.join(script_path, "configs")
|
sd_configs_path = os.path.join(script_path, "configs")
|
||||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||||
@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
|
|||||||
|
|
||||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||||
|
|
||||||
data_path = cmd_opts_pre.data_dir
|
data_path = cmd_opts_pre.data_dir
|
||||||
@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
|
|||||||
extensions_dir = os.path.join(data_path, "extensions")
|
extensions_dir = os.path.join(data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||||
config_states_dir = os.path.join(script_path, "config_states")
|
config_states_dir = os.path.join(script_path, "config_states")
|
||||||
|
|
||||||
|
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||||
|
@ -2,7 +2,6 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -31,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
|||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
from blendmodes.blend import blendLayers, BlendType
|
from blendmodes.blend import blendLayers, BlendType
|
||||||
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
@ -150,6 +150,8 @@ class StableDiffusionProcessing:
|
|||||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||||
self.is_using_inpainting_conditioning = False
|
self.is_using_inpainting_conditioning = False
|
||||||
self.disable_extra_networks = False
|
self.disable_extra_networks = False
|
||||||
|
self.token_merging_ratio = 0
|
||||||
|
self.token_merging_ratio_hr = 0
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if not seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
@ -165,7 +167,8 @@ class StableDiffusionProcessing:
|
|||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
self.sampler = None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
@ -274,6 +277,12 @@ class StableDiffusionProcessing:
|
|||||||
def close(self):
|
def close(self):
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
if for_hr:
|
||||||
|
return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
|
||||||
|
|
||||||
|
return self.token_merging_ratio or opts.token_merging_ratio
|
||||||
|
|
||||||
|
|
||||||
class Processed:
|
class Processed:
|
||||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||||
@ -303,6 +312,8 @@ class Processed:
|
|||||||
self.styles = p.styles
|
self.styles = p.styles
|
||||||
self.job_timestamp = state.job_timestamp
|
self.job_timestamp = state.job_timestamp
|
||||||
self.clip_skip = opts.CLIP_stop_at_last_layers
|
self.clip_skip = opts.CLIP_stop_at_last_layers
|
||||||
|
self.token_merging_ratio = p.token_merging_ratio
|
||||||
|
self.token_merging_ratio_hr = p.token_merging_ratio_hr
|
||||||
|
|
||||||
self.eta = p.eta
|
self.eta = p.eta
|
||||||
self.ddim_discretize = p.ddim_discretize
|
self.ddim_discretize = p.ddim_discretize
|
||||||
@ -310,6 +321,7 @@ class Processed:
|
|||||||
self.s_tmin = p.s_tmin
|
self.s_tmin = p.s_tmin
|
||||||
self.s_tmax = p.s_tmax
|
self.s_tmax = p.s_tmax
|
||||||
self.s_noise = p.s_noise
|
self.s_noise = p.s_noise
|
||||||
|
self.s_min_uncond = p.s_min_uncond
|
||||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||||
@ -360,6 +372,9 @@ class Processed:
|
|||||||
def infotext(self, p: StableDiffusionProcessing, index):
|
def infotext(self, p: StableDiffusionProcessing, index):
|
||||||
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
||||||
|
|
||||||
|
|
||||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||||
def slerp(val, low, high):
|
def slerp(val, low, high):
|
||||||
@ -472,6 +487,13 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
enable_hr = getattr(p, 'enable_hr', False)
|
||||||
|
token_merging_ratio = p.get_token_merging_ratio()
|
||||||
|
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
|
||||||
|
|
||||||
|
uses_ensd = opts.eta_noise_seed_delta != 0
|
||||||
|
if uses_ensd:
|
||||||
|
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
@ -489,15 +511,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||||
|
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||||
|
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
|
**p.extra_generation_params,
|
||||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
@ -523,9 +546,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
sd_models.apply_token_merging(p.sd_model, 0)
|
||||||
|
|
||||||
# restore opts to original state
|
# restore opts to original state
|
||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
@ -660,12 +687,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
step_multiplier = 1
|
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||||
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||||
try:
|
|
||||||
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
|
|
||||||
@ -978,8 +1001,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
@ -1141,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
|
||||||
|
@ -95,9 +95,20 @@ def progressapi(req: ProgressRequest):
|
|||||||
image = shared.state.current_image
|
image = shared.state.current_image
|
||||||
if image is not None:
|
if image is not None:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="png")
|
|
||||||
|
if opts.live_previews_image_format == "png":
|
||||||
|
# using optimize for large images takes an enormous amount of time
|
||||||
|
if max(*image.size) <= 256:
|
||||||
|
save_kwargs = {"optimize": True}
|
||||||
|
else:
|
||||||
|
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||||
|
|
||||||
|
else:
|
||||||
|
save_kwargs = {}
|
||||||
|
|
||||||
|
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||||
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||||
live_preview = f"data:image/png;base64,{base64_image}"
|
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||||
id_live_preview = shared.state.id_live_preview
|
id_live_preview = shared.state.id_live_preview
|
||||||
else:
|
else:
|
||||||
live_preview = None
|
live_preview = None
|
||||||
|
@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
l = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
tree.children[-1] = float(tree.children[-1])
|
||||||
if tree.children[-1] < 1:
|
if tree.children[-1] < 1:
|
||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
res.append(tree.children[-1])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
l.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
|
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(res))
|
||||||
|
|
||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def get_schedule(prompt):
|
def get_schedule(prompt):
|
||||||
try:
|
try:
|
||||||
tree = schedule_parser.parse(prompt)
|
tree = schedule_parser.parse(prompt)
|
||||||
except lark.exceptions.LarkError as e:
|
except lark.exceptions.LarkError:
|
||||||
if 0:
|
if 0:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
|
|||||||
conds = model.get_learned_conditioning(texts)
|
conds = model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
cond_schedule = []
|
cond_schedule = []
|
||||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||||
|
|
||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
|
|||||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
for i, cond_schedule in enumerate(c):
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
for current, entry in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
res[i] = cond_schedule[target_index].cond
|
res[i] = cond_schedule[target_index].cond
|
||||||
@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
tensors = []
|
tensors = []
|
||||||
conds_list = []
|
conds_list = []
|
||||||
|
|
||||||
for batch_no, composable_prompts in enumerate(c.batch):
|
for composable_prompts in c.batch:
|
||||||
conds_for_batch = []
|
conds_for_batch = []
|
||||||
|
|
||||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
for composable_prompt in composable_prompts:
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
for current, entry in enumerate(composable_prompt.schedules):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.user_path = path
|
self.user_path = path
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer # noqa: F401
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||||
self.enable = True
|
self.enable = True
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
scalers = self.load_models(path)
|
scalers = self.load_models(path)
|
||||||
@ -134,6 +134,6 @@ def get_realesrgan_models(scaler):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
@ -95,16 +95,16 @@ def check_pt(filename, extra_handler):
|
|||||||
|
|
||||||
except zipfile.BadZipfile:
|
except zipfile.BadZipfile:
|
||||||
|
|
||||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
unpickler.extra_handler = extra_handler
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
|
@ -32,27 +32,42 @@ class CFGDenoiserParams:
|
|||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
self.image_cond = image_cond
|
self.image_cond = image_cond
|
||||||
"""Conditioning image"""
|
"""Conditioning image"""
|
||||||
|
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
"""Current sigma noise step value"""
|
"""Current sigma noise step value"""
|
||||||
|
|
||||||
self.sampling_step = sampling_step
|
self.sampling_step = sampling_step
|
||||||
"""Current Sampling step number"""
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
self.total_sampling_steps = total_sampling_steps
|
self.total_sampling_steps = total_sampling_steps
|
||||||
"""Total number of sampling steps planned"""
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
self.text_cond = text_cond
|
self.text_cond = text_cond
|
||||||
""" Encoder hidden states of text conditioning from prompt"""
|
""" Encoder hidden states of text conditioning from prompt"""
|
||||||
|
|
||||||
self.text_uncond = text_uncond
|
self.text_uncond = text_uncond
|
||||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoisedParams:
|
class CFGDenoisedParams:
|
||||||
|
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||||
|
self.x = x
|
||||||
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
|
self.sampling_step = sampling_step
|
||||||
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
|
self.total_sampling_steps = total_sampling_steps
|
||||||
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
|
self.inner_model = inner_model
|
||||||
|
"""Inner model reference used for denoising"""
|
||||||
|
|
||||||
|
|
||||||
|
class AfterCFGCallbackParams:
|
||||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
@ -87,6 +102,7 @@ callback_map = dict(
|
|||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
callbacks_cfg_denoiser=[],
|
callbacks_cfg_denoiser=[],
|
||||||
callbacks_cfg_denoised=[],
|
callbacks_cfg_denoised=[],
|
||||||
|
callbacks_cfg_after_cfg=[],
|
||||||
callbacks_before_component=[],
|
callbacks_before_component=[],
|
||||||
callbacks_after_component=[],
|
callbacks_after_component=[],
|
||||||
callbacks_image_grid=[],
|
callbacks_image_grid=[],
|
||||||
@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
|
|||||||
report_exception(c, 'cfg_denoised_callback')
|
report_exception(c, 'cfg_denoised_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||||
|
for c in callback_map['callbacks_cfg_after_cfg']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'cfg_after_cfg_callback')
|
||||||
|
|
||||||
|
|
||||||
def before_component_callback(component, **kwargs):
|
def before_component_callback(component, **kwargs):
|
||||||
for c in callback_map['callbacks_before_component']:
|
for c in callback_map['callbacks_before_component']:
|
||||||
try:
|
try:
|
||||||
@ -240,7 +264,7 @@ def add_callback(callbacks, fun):
|
|||||||
|
|
||||||
callbacks.append(ScriptCallback(filename, fun))
|
callbacks.append(ScriptCallback(filename, fun))
|
||||||
|
|
||||||
|
|
||||||
def remove_current_script_callbacks():
|
def remove_current_script_callbacks():
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||||
@ -332,6 +356,14 @@ def on_cfg_denoised(callback):
|
|||||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_cfg_after_cfg(callback):
|
||||||
|
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||||
|
The callback is called with one argument:
|
||||||
|
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_before_component(callback):
|
def on_before_component(callback):
|
||||||
"""register a function to be called before a component is created.
|
"""register a function to be called before a component is created.
|
||||||
The callback is called with arguments:
|
The callback is called with arguments:
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
|
|
||||||
def load_module(path):
|
def load_module(path):
|
||||||
|
@ -17,6 +17,9 @@ class PostprocessImageArgs:
|
|||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
|
name = None
|
||||||
|
"""script's internal name derived from title"""
|
||||||
|
|
||||||
filename = None
|
filename = None
|
||||||
args_from = None
|
args_from = None
|
||||||
args_to = None
|
args_to = None
|
||||||
@ -25,8 +28,8 @@ class Script:
|
|||||||
is_txt2img = False
|
is_txt2img = False
|
||||||
is_img2img = False
|
is_img2img = False
|
||||||
|
|
||||||
"""A gr.Group component that has all script's UI inside it"""
|
|
||||||
group = None
|
group = None
|
||||||
|
"""A gr.Group component that has all script's UI inside it"""
|
||||||
|
|
||||||
infotext_fields = None
|
infotext_fields = None
|
||||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
@ -38,6 +41,9 @@ class Script:
|
|||||||
various "Send to <X>" buttons when clicked
|
various "Send to <X>" buttons when clicked
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
api_info = None
|
||||||
|
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@ -231,7 +237,7 @@ def load_scripts():
|
|||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
def register_scripts_from_module(module):
|
def register_scripts_from_module(module):
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) != type:
|
if type(script_class) != type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -295,9 +301,9 @@ class ScriptRunner:
|
|||||||
|
|
||||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
for script_data in auto_processing_scripts + scripts_data:
|
||||||
script = script_class()
|
script = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
|
||||||
@ -313,6 +319,8 @@ class ScriptRunner:
|
|||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
def setup_ui(self):
|
def setup_ui(self):
|
||||||
|
import modules.api.models as api_models
|
||||||
|
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
inputs = [None]
|
inputs = [None]
|
||||||
@ -327,9 +335,28 @@ class ScriptRunner:
|
|||||||
if controls is None:
|
if controls is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||||
|
api_args = []
|
||||||
|
|
||||||
for control in controls:
|
for control in controls:
|
||||||
control.custom_script_source = os.path.basename(script.filename)
|
control.custom_script_source = os.path.basename(script.filename)
|
||||||
|
|
||||||
|
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||||
|
|
||||||
|
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||||
|
v = getattr(control, field, None)
|
||||||
|
if v is not None:
|
||||||
|
setattr(arg_info, field, v)
|
||||||
|
|
||||||
|
api_args.append(arg_info)
|
||||||
|
|
||||||
|
script.api_info = api_models.ScriptInfo(
|
||||||
|
name=script.name,
|
||||||
|
is_img2img=script.is_img2img,
|
||||||
|
is_alwayson=script.alwayson,
|
||||||
|
args=api_args,
|
||||||
|
)
|
||||||
|
|
||||||
if script.infotext_fields is not None:
|
if script.infotext_fields is not None:
|
||||||
self.infotext_fields += script.infotext_fields
|
self.infotext_fields += script.infotext_fields
|
||||||
|
|
||||||
@ -492,7 +519,7 @@ class ScriptRunner:
|
|||||||
module = script_loading.load_module(script.filename)
|
module = script_loading.load_module(script.filename)
|
||||||
cache[filename] = module
|
cache[filename] = module
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
self.scripts[si] = script_class()
|
self.scripts[si] = script_class()
|
||||||
self.scripts[si].filename = filename
|
self.scripts[si].filename = filename
|
||||||
|
@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
|||||||
return self.postprocessing_controls.values()
|
return self.postprocessing_controls.values()
|
||||||
|
|
||||||
def postprocess_image(self, p, script_pp, *args):
|
def postprocess_image(self, p, script_pp, *args):
|
||||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
args_dict = dict(zip(self.postprocessing_controls, args))
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||||
pp.info = {}
|
pp.info = {}
|
||||||
|
@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
|
|||||||
def initialize_scripts(self, scripts_data):
|
def initialize_scripts(self, scripts_data):
|
||||||
self.scripts = []
|
self.scripts = []
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in scripts_data:
|
for script_data in scripts_data:
|
||||||
script: ScriptPostprocessing = script_class()
|
script: ScriptPostprocessing = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
|
|
||||||
if script.name == "Simple Upscale":
|
if script.name == "Simple Upscale":
|
||||||
continue
|
continue
|
||||||
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
|||||||
script_args = args[script.args_from:script.args_to]
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
process_args = {}
|
process_args = {}
|
||||||
for (name, component), value in zip(script.controls.items(), script_args):
|
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||||
process_args[name] = value
|
process_args[name] = value
|
||||||
|
|
||||||
script.process(pp, **process_args)
|
script.process(pp, **process_args)
|
||||||
|
@ -61,7 +61,7 @@ class DisableInitialization:
|
|||||||
if res is None:
|
if res is None:
|
||||||
res = original(url, *args, local_files_only=False, **kwargs)
|
res = original(url, *args, local_files_only=False, **kwargs)
|
||||||
return res
|
return res
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return original(url, *args, local_files_only=False, **kwargs)
|
return original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
|
||||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||||
|
@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
from modules import devices, sd_hijack_optimizations, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
@ -34,10 +34,10 @@ def apply_optimizations():
|
|||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
|
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
@ -92,12 +92,12 @@ def fix_checkpoint():
|
|||||||
def weighted_loss(sd_model, pred, target, mean=True):
|
def weighted_loss(sd_model, pred, target, mean=True):
|
||||||
#Calculate the weight normally, but ignore the mean
|
#Calculate the weight normally, but ignore the mean
|
||||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||||
|
|
||||||
#Check if we have weights available
|
#Check if we have weights available
|
||||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
loss *= weight
|
loss *= weight
|
||||||
|
|
||||||
#Return the loss, as mean if specified
|
#Return the loss, as mean if specified
|
||||||
return loss.mean() if mean else loss
|
return loss.mean() if mean else loss
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Temporarily append weights to a place accessible during loss calc
|
#Temporarily append weights to a place accessible during loss calc
|
||||||
sd_model._custom_loss_weight = w
|
sd_model._custom_loss_weight = w
|
||||||
|
|
||||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||||
if not hasattr(sd_model, '_old_get_loss'):
|
if not hasattr(sd_model, '_old_get_loss'):
|
||||||
@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Delete temporary weights if appended
|
#Delete temporary weights if appended
|
||||||
del sd_model._custom_loss_weight
|
del sd_model._custom_loss_weight
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#If we have an old loss function, reset the loss function to the original one
|
#If we have an old loss function, reset the loss function to the original one
|
||||||
if hasattr(sd_model, '_old_get_loss'):
|
if hasattr(sd_model, '_old_get_loss'):
|
||||||
sd_model.get_loss = sd_model._old_get_loss
|
sd_model.get_loss = sd_model._old_get_loss
|
||||||
@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
|
|||||||
def undo_weighted_forward(sd_model):
|
def undo_weighted_forward(sd_model):
|
||||||
try:
|
try:
|
||||||
del sd_model.weighted_forward
|
del sd_model.weighted_forward
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
@ -216,6 +216,9 @@ class StableDiffusionModelHijack:
|
|||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
def get_prompt_lengths(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
|
if self.clip is None:
|
||||||
|
return "-", "-"
|
||||||
|
|
||||||
_, token_count = self.clip.process_texts([text])
|
_, token_count = self.clip.process_texts([text])
|
||||||
|
|
||||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||||
|
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
|
|
||||||
for fixes in self.hijack.fixes:
|
for fixes in self.hijack.fixes:
|
||||||
for position, embedding in fixes:
|
for _position, embedding in fixes:
|
||||||
used_embeddings[embedding.name] = embedding
|
used_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
|
@ -1,16 +1,10 @@
|
|||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
from omegaconf import ListConfig
|
|
||||||
|
|
||||||
import ldm.models.diffusion.ddpm
|
import ldm.models.diffusion.ddpm
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from ldm.models.diffusion.ddim import noise_like
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
if isinstance(c, dict):
|
if isinstance(c, dict):
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in c:
|
for k in c:
|
||||||
if isinstance(c[k], list):
|
if isinstance(c[k], list):
|
||||||
c_in[k] = [
|
c_in[k] = [
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import collections
|
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
|
|
||||||
def should_hijack_ip2p(checkpoint_info):
|
def should_hijack_ip2p(checkpoint_info):
|
||||||
from modules import sd_models_config
|
from modules import sd_models_config
|
||||||
@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
|
|||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||||
|
|
||||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
|
||||||
|
@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
end = i + 2
|
end = i + 2
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
s1 *= self.scale
|
s1 *= self.scale
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
s2 = s1.softmax(dim=-1)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
del s2
|
del s2
|
||||||
del q, k, v
|
del q, k, v
|
||||||
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k_in = k_in * self.scale
|
k_in = k_in * self.scale
|
||||||
|
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
mem_free_total = get_available_vram()
|
mem_free_total = get_available_vram()
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
mem_required = tensor_size * modifier
|
mem_required = tensor_size * modifier
|
||||||
steps = 1
|
steps = 1
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
if mem_required > mem_free_total:
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
if steps > 64:
|
if steps > 64:
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
del s2
|
del s2
|
||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
r1 = r1.to(dtype)
|
r1 = r1.to(dtype)
|
||||||
@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k = k * self.scale
|
k = k * self.scale
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
r = einsum_op(q, k, v)
|
r = einsum_op(q, k, v)
|
||||||
r = r.to(dtype)
|
r = r.to(dtype)
|
||||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||||
@ -296,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||||
# i.e. send it down the unchunked fast-path
|
# i.e. send it down the unchunked fast-path
|
||||||
query_chunk_size = q_tokens
|
|
||||||
kv_chunk_size = k_tokens
|
kv_chunk_size = k_tokens
|
||||||
|
|
||||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||||
@ -335,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -370,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
|||||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -452,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
h3 += x
|
h3 += x
|
||||||
|
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
def xformers_attnblock_forward(self, x):
|
def xformers_attnblock_forward(self, x):
|
||||||
try:
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -461,7 +460,7 @@ def xformers_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -483,7 +482,7 @@ def sdp_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -507,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import open_clip.tokenizer
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import sd_hijack_clip, devices
|
from modules import sd_hijack_clip, devices
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||||
|
@ -15,9 +15,9 @@ import ldm.modules.midas as midas
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
import tomesd
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -87,8 +87,7 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
from transformers import logging, CLIPModel # noqa: F401
|
||||||
from transformers import logging, CLIPModel
|
|
||||||
|
|
||||||
logging.set_verbosity_error()
|
logging.set_verbosity_error()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -167,7 +166,7 @@ def model_hash(filename):
|
|||||||
|
|
||||||
def select_checkpoint():
|
def select_checkpoint():
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
@ -239,7 +238,7 @@ def read_metadata_from_safetensors(filename):
|
|||||||
if isinstance(v, str) and v[0:1] == '{':
|
if isinstance(v, str) and v[0:1] == '{':
|
||||||
try:
|
try:
|
||||||
res[k] = json.loads(v)
|
res[k] = json.loads(v)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@ -374,7 +373,7 @@ def enable_midas_autodownload():
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
if not os.path.exists(midas_path):
|
if not os.path.exists(midas_path):
|
||||||
mkdir(midas_path)
|
mkdir(midas_path)
|
||||||
|
|
||||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||||
request.urlretrieve(midas_urls[model_type], path)
|
request.urlretrieve(midas_urls[model_type], path)
|
||||||
print(f"{model_type} downloaded")
|
print(f"{model_type} downloaded")
|
||||||
@ -415,6 +414,9 @@ class SdModelData:
|
|||||||
def get_sd_model(self):
|
def get_sd_model(self):
|
||||||
if self.sd_model is None:
|
if self.sd_model is None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
|
if self.sd_model is not None:
|
||||||
|
return self.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model()
|
load_model()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -467,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sd_model is None:
|
if sd_model is None:
|
||||||
@ -538,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return model_data.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Failed to load checkpoint, restoring previous")
|
print("Failed to load checkpoint, restoring previous")
|
||||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||||
raise
|
raise
|
||||||
@ -565,7 +566,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import devices, sd_hijack
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if model_data.sd_model:
|
if model_data.sd_model:
|
||||||
@ -580,3 +581,29 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_token_merging(sd_model, token_merging_ratio):
|
||||||
|
"""
|
||||||
|
Applies speed and memory optimizations from tomesd.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
|
||||||
|
|
||||||
|
if current_token_merging_ratio == token_merging_ratio:
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(sd_model)
|
||||||
|
|
||||||
|
if token_merging_ratio > 0:
|
||||||
|
tomesd.apply_patch(
|
||||||
|
sd_model,
|
||||||
|
ratio=token_merging_ratio,
|
||||||
|
use_rand=False, # can cause issues with some samplers
|
||||||
|
merge_attn=True,
|
||||||
|
merge_crossattn=False,
|
||||||
|
merge_mlp=False
|
||||||
|
)
|
||||||
|
|
||||||
|
sd_model.applied_token_merged_ratio = token_merging_ratio
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import re
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
@ -14,12 +14,18 @@ samplers_for_img2img = []
|
|||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
|
|
||||||
|
|
||||||
def create_sampler(name, model):
|
def find_sampler_config(name):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
config = all_samplers_map.get(name, None)
|
config = all_samplers_map.get(name, None)
|
||||||
else:
|
else:
|
||||||
config = all_samplers[0]
|
config = all_samplers[0]
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_sampler(name, model):
|
||||||
|
config = find_sampler_config(name)
|
||||||
|
|
||||||
assert config is not None, f'bad sampler name: {name}'
|
assert config is not None, f'bad sampler name: {name}'
|
||||||
|
|
||||||
sampler = config.constructor(model)
|
sampler = config.constructor(model)
|
||||||
|
@ -2,7 +2,7 @@ from collections import namedtuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from modules import devices, processing, images, sd_vae_approx
|
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
|
||||||
|
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -22,7 +22,7 @@ def setup_img2img_steps(p, steps=None):
|
|||||||
return steps, t_enc
|
return steps, t_enc
|
||||||
|
|
||||||
|
|
||||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
|
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||||
|
|
||||||
|
|
||||||
def single_sample_to_image(sample, approximation=None):
|
def single_sample_to_image(sample, approximation=None):
|
||||||
@ -30,15 +30,19 @@ def single_sample_to_image(sample, approximation=None):
|
|||||||
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
approximation = approximation_indexes.get(opts.show_progress_type, 0)
|
||||||
|
|
||||||
if approximation == 2:
|
if approximation == 2:
|
||||||
x_sample = sd_vae_approx.cheap_approximation(sample)
|
x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5
|
||||||
elif approximation == 1:
|
elif approximation == 1:
|
||||||
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5
|
||||||
|
elif approximation == 3:
|
||||||
|
x_sample = sample * 1.5
|
||||||
|
x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
|
||||||
else:
|
else:
|
||||||
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
|
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
|
||||||
|
|
||||||
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
return Image.fromarray(x_sample)
|
return Image.fromarray(x_sample)
|
||||||
|
|
||||||
|
|
||||||
@ -58,6 +62,25 @@ def store_latent(decoded):
|
|||||||
shared.state.assign_current_image(sample_to_image(decoded))
|
shared.state.assign_current_image(sample_to_image(decoded))
|
||||||
|
|
||||||
|
|
||||||
|
def is_sampler_using_eta_noise_seed_delta(p):
|
||||||
|
"""returns whether sampler from config will use eta noise seed delta for image creation"""
|
||||||
|
|
||||||
|
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||||
|
|
||||||
|
eta = p.eta
|
||||||
|
|
||||||
|
if eta is None and p.sampler is not None:
|
||||||
|
eta = p.sampler.eta
|
||||||
|
|
||||||
|
if eta is None and sampler_config is not None:
|
||||||
|
eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0
|
||||||
|
|
||||||
|
if eta == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return sampler_config.options.get("uses_ensd", False)
|
||||||
|
|
||||||
|
|
||||||
class InterruptedException(BaseException):
|
class InterruptedException(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import modules.models.diffusion.uni_pc
|
|||||||
|
|
||||||
|
|
||||||
samplers_data_compvis = [
|
samplers_data_compvis = [
|
||||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
|
||||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||||
]
|
]
|
||||||
@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
|
||||||
|
|
||||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
cond = tensor
|
cond = tensor
|
||||||
|
|
||||||
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
# for DDIM, shapes must match, we can't just process cond and uncond independently;
|
||||||
@ -134,7 +134,11 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.update_step(x)
|
self.update_step(x)
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
if self.is_ddim:
|
||||||
|
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||||
|
else:
|
||||||
|
self.eta = 0.0
|
||||||
|
|
||||||
if self.eta != 0.0:
|
if self.eta != 0.0:
|
||||||
p.extra_generation_params["Eta DDIM"] = self.eta
|
p.extra_generation_params["Eta DDIM"] = self.eta
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
import einops
|
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
from modules import prompt_parser, devices, sd_samplers_common
|
from modules import prompt_parser, devices, sd_samplers_common
|
||||||
|
|
||||||
@ -9,25 +8,26 @@ from modules.shared import opts, state
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||||
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
|
||||||
|
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
|
||||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}),
|
||||||
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
|
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
|
||||||
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
|
||||||
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
|
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True}),
|
||||||
]
|
]
|
||||||
|
|
||||||
samplers_data_k_diffusion = [
|
samplers_data_k_diffusion = [
|
||||||
@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
batch_size = len(conds_list)
|
||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||||
image_uncond = torch.zeros_like(image_cond)
|
image_uncond = torch.zeros_like(image_cond)
|
||||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||||
else:
|
else:
|
||||||
image_uncond = image_cond
|
image_uncond = image_cond
|
||||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||||
|
|
||||||
if not is_edit_model:
|
if not is_edit_model:
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
|
||||||
cfg_denoised_callback(denoised_params)
|
cfg_denoised_callback(denoised_params)
|
||||||
|
|
||||||
devices.test_for_nans(x_out, "unet")
|
devices.test_for_nans(x_out, "unet")
|
||||||
@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
|
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||||
|
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||||
|
denoised = after_cfg_callback_params.x
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
@ -317,7 +321,7 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
xi = x + noise * sigma_sched[0]
|
xi = x + noise * sigma_sched[0]
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
|
||||||
@ -340,9 +344,9 @@ class KDiffusionSampler:
|
|||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
extra_args={
|
extra_args={
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
@ -375,9 +379,9 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
import os
|
import os
|
||||||
import collections
|
import collections
|
||||||
from collections import namedtuple
|
|
||||||
from modules import paths, shared, devices, script_callbacks, sd_models
|
from modules import paths, shared, devices, script_callbacks, sd_models
|
||||||
import glob
|
import glob
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -88,10 +85,10 @@ def refresh_vae_list():
|
|||||||
|
|
||||||
|
|
||||||
def find_vae_near_checkpoint(checkpoint_file):
|
def find_vae_near_checkpoint(checkpoint_file):
|
||||||
checkpoint_path = os.path.splitext(checkpoint_file)[0]
|
checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
|
||||||
for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
|
for vae_file in vae_dict.values():
|
||||||
if os.path.isfile(vae_location):
|
if os.path.basename(vae_file).startswith(checkpoint_path):
|
||||||
return vae_location
|
return vae_file
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
88
modules/sd_vae_taesd.py
Normal file
88
modules/sd_vae_taesd.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Tiny AutoEncoder for Stable Diffusion
|
||||||
|
(DNN for encoding / decoding SD's latent space)
|
||||||
|
|
||||||
|
https://github.com/madebyollin/taesd
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from modules import devices, paths_internal
|
||||||
|
|
||||||
|
sd_vae_taesd = None
|
||||||
|
|
||||||
|
|
||||||
|
def conv(n_in, n_out, **kwargs):
|
||||||
|
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Clamp(nn.Module):
|
||||||
|
@staticmethod
|
||||||
|
def forward(x):
|
||||||
|
return torch.tanh(x / 3) * 3
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, n_in, n_out):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||||
|
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||||
|
self.fuse = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
|
||||||
|
|
||||||
|
def decoder():
|
||||||
|
return nn.Sequential(
|
||||||
|
Clamp(), conv(4, 64), nn.ReLU(),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||||
|
Block(64, 64), conv(64, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TAESD(nn.Module):
|
||||||
|
latent_magnitude = 3
|
||||||
|
latent_shift = 0.5
|
||||||
|
|
||||||
|
def __init__(self, decoder_path="taesd_decoder.pth"):
|
||||||
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder()
|
||||||
|
self.decoder.load_state_dict(
|
||||||
|
torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def unscale_latents(x):
|
||||||
|
"""[0, 1] -> raw latents"""
|
||||||
|
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(model_path):
|
||||||
|
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||||
|
|
||||||
|
print(f'Downloading TAESD decoder to: {model_path}')
|
||||||
|
torch.hub.download_url_to_file(model_url, model_path)
|
||||||
|
|
||||||
|
|
||||||
|
def model():
|
||||||
|
global sd_vae_taesd
|
||||||
|
|
||||||
|
if sd_vae_taesd is None:
|
||||||
|
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
|
||||||
|
download_model(model_path)
|
||||||
|
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
sd_vae_taesd = TAESD(model_path)
|
||||||
|
sd_vae_taesd.eval()
|
||||||
|
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError('TAESD model not found')
|
||||||
|
|
||||||
|
return sd_vae_taesd.decoder
|
@ -1,12 +1,10 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import requests
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
@ -15,7 +13,7 @@ import modules.memmon
|
|||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
@ -113,8 +111,47 @@ class State:
|
|||||||
id_live_preview = 0
|
id_live_preview = 0
|
||||||
textinfo = None
|
textinfo = None
|
||||||
time_start = None
|
time_start = None
|
||||||
need_restart = False
|
|
||||||
server_start = None
|
server_start = None
|
||||||
|
_server_command_signal = threading.Event()
|
||||||
|
_server_command: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_restart(self) -> bool:
|
||||||
|
# Compatibility getter for need_restart.
|
||||||
|
return self.server_command == "restart"
|
||||||
|
|
||||||
|
@need_restart.setter
|
||||||
|
def need_restart(self, value: bool) -> None:
|
||||||
|
# Compatibility setter for need_restart.
|
||||||
|
if value:
|
||||||
|
self.server_command = "restart"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server_command(self):
|
||||||
|
return self._server_command
|
||||||
|
|
||||||
|
@server_command.setter
|
||||||
|
def server_command(self, value: str | None) -> None:
|
||||||
|
"""
|
||||||
|
Set the server command to `value` and signal that it's been set.
|
||||||
|
"""
|
||||||
|
self._server_command = value
|
||||||
|
self._server_command_signal.set()
|
||||||
|
|
||||||
|
def wait_for_server_command(self, timeout: float | None = None) -> str | None:
|
||||||
|
"""
|
||||||
|
Wait for server command to get set; return and clear the value and signal.
|
||||||
|
"""
|
||||||
|
if self._server_command_signal.wait(timeout):
|
||||||
|
self._server_command_signal.clear()
|
||||||
|
req = self._server_command
|
||||||
|
self._server_command = None
|
||||||
|
return req
|
||||||
|
return None
|
||||||
|
|
||||||
|
def request_restart(self) -> None:
|
||||||
|
self.interrupt()
|
||||||
|
self.server_command = "restart"
|
||||||
|
|
||||||
def skip(self):
|
def skip(self):
|
||||||
self.skipped = True
|
self.skipped = True
|
||||||
@ -202,8 +239,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
@ -212,9 +250,33 @@ class OptionInfo:
|
|||||||
self.section = section
|
self.section = section
|
||||||
self.refresh = refresh
|
self.refresh = refresh
|
||||||
|
|
||||||
|
self.comment_before = comment_before
|
||||||
|
"""HTML text that will be added after label in UI"""
|
||||||
|
|
||||||
|
self.comment_after = comment_after
|
||||||
|
"""HTML text that will be added before label in UI"""
|
||||||
|
|
||||||
|
def link(self, label, url):
|
||||||
|
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def js(self, label, js_func):
|
||||||
|
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def info(self, info):
|
||||||
|
self.comment_after += f"<span class='info'>({info})</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
def needs_restart(self):
|
||||||
|
self.comment_after += " <span class='info'>(requires restart)</span>"
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def options_section(section_identifier, options_dict):
|
def options_section(section_identifier, options_dict):
|
||||||
for k, v in options_dict.items():
|
for v in options_dict.values():
|
||||||
v.section = section_identifier
|
v.section = section_identifier
|
||||||
|
|
||||||
return options_dict
|
return options_dict
|
||||||
@ -243,7 +305,7 @@ options_templates = {}
|
|||||||
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
|
||||||
"samples_save": OptionInfo(True, "Always save all generated images"),
|
"samples_save": OptionInfo(True, "Always save all generated images"),
|
||||||
"samples_format": OptionInfo('png', 'File format for images'),
|
"samples_format": OptionInfo('png', 'File format for images'),
|
||||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs),
|
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||||
|
|
||||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||||
@ -262,10 +324,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
|
||||||
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
|
||||||
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
|
||||||
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
|
"export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
|
||||||
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
|
||||||
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
|
||||||
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number),
|
"img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
|
||||||
|
|
||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
@ -293,31 +355,30 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
|
||||||
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs),
|
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
|
||||||
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
|
||||||
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
|
||||||
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System"), {
|
||||||
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
"show_warnings": OptionInfo(False, "Show warnings in console."),
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
@ -339,20 +400,27 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
|
||||||
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||||
|
"s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||||
|
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
@ -364,30 +432,35 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||||
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
|
"interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
|
||||||
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
|
"interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
|
||||||
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
"interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
"interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
|
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
|
||||||
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
|
||||||
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
|
||||||
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
|
"deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
|
||||||
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
|
"deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
|
||||||
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
|
"deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
|
||||||
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
|
"deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||||
|
"extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."),
|
||||||
|
"extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'),
|
||||||
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
|
||||||
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
|
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
|
||||||
|
"img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
@ -400,17 +473,16 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
||||||
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
|
||||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('infotext', "Infotext"), {
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
@ -423,27 +495,27 @@ options_templates.update(options_section(('infotext', "Infotext"), {
|
|||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
|
||||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
"show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"),
|
||||||
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
|
||||||
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
|
"live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
|
||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
|
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
|
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||||
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
|
||||||
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
|
||||||
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}),
|
'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
|
||||||
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -460,6 +532,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
options_templates.update()
|
options_templates.update()
|
||||||
|
|
||||||
|
|
||||||
@ -571,7 +644,9 @@ class Options:
|
|||||||
func()
|
func()
|
||||||
|
|
||||||
def dumpjson(self):
|
def dumpjson(self):
|
||||||
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
|
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
|
||||||
|
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
|
||||||
|
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
|
||||||
return json.dumps(d)
|
return json.dumps(d)
|
||||||
|
|
||||||
def add_option(self, key, info):
|
def add_option(self, key, info):
|
||||||
@ -582,11 +657,11 @@ class Options:
|
|||||||
|
|
||||||
section_ids = {}
|
section_ids = {}
|
||||||
settings_items = self.data_labels.items()
|
settings_items = self.data_labels.items()
|
||||||
for k, item in settings_items:
|
for _, item in settings_items:
|
||||||
if item.section not in section_ids:
|
if item.section not in section_ids:
|
||||||
section_ids[item.section] = len(section_ids)
|
section_ids[item.section] = len(section_ids)
|
||||||
|
|
||||||
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
|
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
|
||||||
|
|
||||||
def cast_value(self, key, value):
|
def cast_value(self, key, value):
|
||||||
"""casts an arbitrary to the same type as this setting's value with key
|
"""casts an arbitrary to the same type as this setting's value with key
|
||||||
@ -748,11 +823,14 @@ def walk_files(path, allowed_extensions=None):
|
|||||||
if allowed_extensions is not None:
|
if allowed_extensions is not None:
|
||||||
allowed_extensions = set(allowed_extensions)
|
allowed_extensions = set(allowed_extensions)
|
||||||
|
|
||||||
for root, dirs, files in os.walk(path):
|
for root, _, files in os.walk(path, followlinks=True):
|
||||||
for filename in files:
|
for filename in files:
|
||||||
if allowed_extensions is not None:
|
if allowed_extensions is not None:
|
||||||
_, ext = os.path.splitext(filename)
|
_, ext = os.path.splitext(filename)
|
||||||
if ext not in allowed_extensions:
|
if ext not in allowed_extensions:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if not opts.list_hidden_files and ("/." in root or "\\." in root):
|
||||||
|
continue
|
||||||
|
|
||||||
yield os.path.join(root, filename)
|
yield os.path.join(root, filename)
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user