Merge branch 'dev' into find_vae

This commit is contained in:
AUTOMATIC1111 2023-05-14 11:46:27 +03:00 committed by GitHub
commit 80adb6979d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
132 changed files with 2215 additions and 1468 deletions

View File

@ -18,22 +18,29 @@ jobs:
steps: steps:
- name: Checkout Code - name: Checkout Code
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Set up Python 3.10 - uses: actions/setup-python@v4
uses: actions/setup-python@v4
with: with:
python-version: 3.10.6 python-version: 3.11
cache: pip # NB: there's no cache: pip here since we're not installing anything
cache-dependency-path: | # from the requirements.txt file(s) in the repository; it's faster
**/requirements*txt # not to have GHA download an (at the time of writing) 4 GB cache
- name: Install PyLint # of PyTorch and other dependencies.
run: | - name: Install Ruff
python -m pip install --upgrade pip run: pip install ruff==0.0.265
pip install pylint - name: Run Ruff
# This lets PyLint check to see if it can resolve imports run: ruff .
- name: Install dependencies
run: | # The rest are currently disabled pending fixing of e.g. installing the torch dependency.
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
python launch.py # - name: Install PyLint
- name: Analysing the code with pylint # run: |
run: | # python -m pip install --upgrade pip
pylint $(git ls-files '*.py') # pip install pylint
# # This lets PyLint check to see if it can resolve imports
# - name: Install dependencies
# run: |
# export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
# python launch.py
# - name: Analysing the code with pylint
# run: |
# pylint $(git ls-files '*.py')

View File

@ -17,8 +17,14 @@ jobs:
cache: pip cache: pip
cache-dependency-path: | cache-dependency-path: |
**/requirements*txt **/requirements*txt
launch.py
- name: Run tests - name: Run tests
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
env:
PIP_DISABLE_PIP_VERSION_CHECK: "1"
PIP_PROGRESS_BAR: "off"
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
- name: Upload main app stdout-stderr - name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3 uses: actions/upload-artifact@v3
if: always() if: always()

View File

@ -1,3 +1,59 @@
## Upcoming 1.2.1
### Features:
* add an option to always refer to lora by filenames
### Bug Fixes:
* never refer to lora by an alias if multiple loras have same alias or the alias is called none
* fix upscalers disappearing after the user reloads UI
* allow bf16 in safe unpickler (resolves problems with loading some loras)
* allow web UI to be ran fully offline
## 1.2.0
### Features:
* do not wait for stable diffusion model to load at startup
* add filename patterns: [denoising]
* directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
* Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
* Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
* Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
* Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
* add version to infotext, footer and console output when starting
* add links to wiki for filename pattern settings
* add extended info for quicksettings setting and use multiselect input instead of a text field
### Minor:
* gradio bumped to 3.29.0
* torch bumped to 2.0.1
* --subpath option for gradio for use with reverse proxy
* linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
* possible frontend optimization: do not apply localizations if there are none
* Add extra `None` option for VAE in XYZ plot
* print error to console when batch processing in img2img fails
* create HTML for extra network pages only on demand
* allow directories starting with . to still list their models for lora, checkpoints, etc
* put infotext options into their own category in settings tab
* do not show licenses page when user selects Show all pages in settings
### Extensions:
* Tooltip localization support
* Add api method to get LoRA models with prompt
### Bug Fixes:
* re-add /docs endpoint
* fix gamepad navigation
* make the lightbox fullscreen image function properly
* fix squished thumbnails in extras tab
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
* fix webui showing the same image if you configure the generation to always save results into same file
* fix bug with upscalers not working properly
* Fix MPS on PyTorch 2.0.1, Intel Macs
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
* prevent Reload UI button/link from reloading the page when it's not yet ready
* fix prompts from file script failing to read contents from a drag/drop file
## 1.1.1 ## 1.1.1
### Bug Fixes: ### Bug Fixes:
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle

View File

@ -88,7 +88,7 @@ class LDSR:
x_t = None x_t = None
logs = None logs = None
for n in range(n_runs): for _ in range(n_runs):
if custom_shape is not None: if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@ -110,7 +110,6 @@ class LDSR:
diffusion_steps = int(steps) diffusion_steps = int(steps)
eta = 1.0 eta = 1.0
down_sample_method = 'Lanczos'
gc.collect() gc.collect()
if torch.cuda.is_available: if torch.cuda.is_available:
@ -131,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else: else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta) logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"] sample = logs["sample"]
@ -158,7 +157,7 @@ class LDSR:
def get_cond(selected_path): def get_cond(selected_path):
example = dict() example = {}
up_f = 4 up_f = 4
c = selected_path.convert('RGB') c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
@torch.no_grad() @torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict() log = {}
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True, return_first_stage_outputs=True,
@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except: except Exception:
pass pass
log["sample"] = x_sample log["sample"] = x_sample

View File

@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared, script_callbacks from modules import shared, script_callbacks
import sd_hijack_autoencoder, sd_hijack_ddpm_v1 import sd_hijack_autoencoder # noqa: F401
import sd_hijack_ddpm_v1 # noqa: F401
class UpscalerLDSR(Upscaler): class UpscalerLDSR(Upscaler):

View File

@ -1,16 +1,21 @@
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
import numpy as np
import torch import torch
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.ema import LitEma
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
import ldm.models.autoencoder import ldm.models.autoencoder
from packaging import version
class VQModel(pl.LightningModule): class VQModel(pl.LightningModule):
def __init__(self, def __init__(self,
@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
n_embed, n_embed,
embed_dim, embed_dim,
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=None,
image_key="image", image_key="image",
colorize_nlabels=None, colorize_nlabels=None,
monitor=None, monitor=None,
@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor self.lr_g_factor = lr_g_factor
@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()): def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"] sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys or []:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx) log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope(): with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict return log_dict
def _validation_step(self, batch, batch_idx, suffix=""): def _validation_step(self, batch, batch_idx, suffix=""):
@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
return self.decoder.conv_out.weight return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict() log = {}
x = self.get_input(batch, self.image_key) x = self.get_input(batch, self.image_key)
x = x.to(self.device) x = x.to(self.device)
if only_inputs: if only_inputs:
@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
if plot_ema: if plot_ema:
with self.ema_scope(): with self.ema_scope():
xrec_ema, _ = self(x) xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema log["reconstructions_ema"] = xrec_ema
return log return log
@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
class VQModelInterface(VQModel): class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs): def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs) super().__init__(*args, embed_dim=embed_dim, **kwargs)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def encode(self, x): def encode(self, x):
@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
dec = self.decoder(quant) dec = self.decoder(quant)
return dec return dec
setattr(ldm.models.autoencoder, "VQModel", VQModel) ldm.models.autoencoder.VQModel = VQModel
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) ldm.models.autoencoder.VQModelInterface = VQModelInterface

View File

@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
beta_schedule="linear", beta_schedule="linear",
loss_type="l2", loss_type="l2",
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=None,
load_only_unet=False, load_only_unet=False,
monitor="val/loss", monitor="val/loss",
use_ema=True, use_ema=True,
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
sd = torch.load(path, map_location="cpu") sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()): if "state_dict" in list(sd.keys()):
sd = sd["state_dict"] sd = sd["state_dict"]
keys = list(sd.keys()) keys = list(sd.keys())
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys or []:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print("Deleting key {} from state_dict.".format(k))
del sd[k] del sd[k]
@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict() log = {}
x = self.get_input(batch, self.first_stage_key) x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N) N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row) n_row = min(x.shape[0], n_row)
@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
log["inputs"] = x log["inputs"] = x
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
x_start = x[:n_row] x_start = x[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
conditioning_key = None conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None) ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", []) ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs) super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key self.cond_stage_key = cond_stage_key
try: try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except: except Exception:
self.num_downs = 0 self.num_downs = 0
if not scale_by_std: if not scale_by_std:
self.scale_factor = scale_factor self.scale_factor = scale_factor
@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config) self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False self.clip_denoised = False
self.bbox_tokenizer = None self.bbox_tokenizer = None
self.restarted_from_ckpt = False self.restarted_from_ckpt = False
if ckpt_path is not None: if ckpt_path is not None:
@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim # 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface): if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i], output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize) force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])] for i in range(z.shape[-1])]
@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs) return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False): def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict): if isinstance(cond, dict):
@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"): if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128) ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64) stride = self.split_input_params["stride"] # eg. (64, 64)
@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial) intermediates.append(x0_partial)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img) intermediates.append(img)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates: if return_intermediates:
return img, intermediates return img, intermediates
@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond, return self.p_sample_loop(cond,
@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
use_ddim = ddim_steps is not None use_ddim = ddim_steps is not None
log = dict() log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True, return_first_stage_outputs=True,
force_c_encode=True, force_c_encode=True,
@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
if plot_diffusion_rows: if plot_diffusion_rows:
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
z_start = z[:n_row] z_start = z[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1: if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
if inpaint: if inpaint:
# make a simple center square # make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3] h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device) mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in # zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
# TODO: move all layout-specific hacks to this class # TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs): def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs): def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs) logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation' key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key] dset = self.trainer.datamodule.datasets[key]
@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
logs['bbox_image'] = cond_img logs['bbox_image'] = cond_img
return logs return logs
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1

View File

@ -1,6 +1,7 @@
from modules import extra_networks, shared from modules import extra_networks, shared
import lora import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork): class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self): def __init__(self):
super().__init__('lora') super().__init__('lora')

View File

@ -1,10 +1,9 @@
import glob
import os import os
import re import re
import torch import torch
from typing import Union from typing import Union
from modules import shared, devices, sd_models, errors from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
@ -93,6 +92,7 @@ class LoraOnDisk:
self.metadata = m self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule: class LoraModule:
@ -165,12 +165,14 @@ def load_lora(name, filename):
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention: elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d: elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else: else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad(): with torch.no_grad():
module.weight.copy_(weight) module.weight.copy_(weight)
@ -182,7 +184,7 @@ def load_lora(name, filename):
elif lora_key == "lora_down.weight": elif lora_key == "lora_down.weight":
lora_module.down = module lora_module.down = module
else: else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha' raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0: if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@ -199,11 +201,11 @@ def load_loras(names, multipliers=None):
loaded_loras.clear() loaded_loras.clear()
loras_on_disk = [available_loras.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]): if any(x is None for x in loras_on_disk):
list_available_loras() list_available_loras()
loras_on_disk = [available_loras.get(name, None) for name in names] loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names): for i, name in enumerate(names):
lora = already_loaded.get(name, None) lora = already_loaded.get(name, None)
@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target):
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else: else:
updown = up @ down updown = up @ down
@ -240,6 +244,19 @@ def lora_calc_updown(lora, module, target):
return updown return updown
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
return
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
""" """
Applies the currently selected set of Loras to the weights of torch layer self. Applies the currently selected set of Loras to the weights of torch layer self.
@ -264,12 +281,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
self.lora_weights_backup = weights_backup self.lora_weights_backup = weights_backup
if current_names != wanted_names: if current_names != wanted_names:
if weights_backup is not None: lora_restore_weights_from_backup(self)
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
for lora in loaded_loras: for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None) module = lora.modules.get(lora_layer_name, None)
@ -297,15 +309,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
print(f'failed to calculate lora weights for layer {lora_layer_name}') print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names) self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_loras) == 0:
return original_forward(module, input)
input = devices.cond_cast_unet(input)
lora_restore_weights_from_backup(module)
lora_reset_cached_weight(module)
res = original_forward(module, input)
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is None:
continue
module.up.to(device=devices.device)
module.down.to(device=devices.device)
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ()) self.lora_current_names = ()
setattr(self, "lora_weights_backup", None) self.lora_weights_backup = None
def lora_Linear_forward(self, input): def lora_Linear_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
lora_apply_weights(self) lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input) return torch.nn.Linear_forward_before_lora(self, input)
@ -318,6 +363,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
def lora_Conv2d_forward(self, input): def lora_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
lora_apply_weights(self) lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input) return torch.nn.Conv2d_forward_before_lora(self, input)
@ -343,24 +391,65 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def list_available_loras(): def list_available_loras():
available_loras.clear() available_loras.clear()
available_lora_aliases.clear()
forbidden_lora_aliases.clear()
forbidden_lora_aliases.update({"none": 1})
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = \ candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
for filename in sorted(candidates, key=str.lower): for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename): if os.path.isdir(filename):
continue continue
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
entry = LoraOnDisk(name, filename)
available_loras[name] = LoraOnDisk(name, filename) available_loras[name] = entry
if entry.alias in available_lora_aliases:
forbidden_lora_aliases[entry.alias.lower()] = 1
available_lora_aliases[name] = entry
available_lora_aliases[entry.alias] = entry
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k in params:
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_lora_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
available_loras = {} available_loras = {}
available_lora_aliases = {}
forbidden_lora_aliases = {}
loaded_loras = [] loaded_loras = []
list_available_loras() list_available_loras()

View File

@ -1,12 +1,12 @@
import torch import torch
import gradio as gr import gradio as gr
from fastapi import FastAPI
import lora import lora
import extra_networks_lora import extra_networks_lora
import ui_extra_networks_lora import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload(): def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@ -49,8 +49,34 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload) script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui) script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras), "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
})) }))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: lora.LoraOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_loras(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
script_callbacks.on_app_started(api_loras)

View File

@ -15,13 +15,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
def list_items(self): def list_items(self):
for name, lora_on_disk in lora.available_loras.items(): for name, lora_on_disk in lora.available_loras.items():
path, ext = os.path.splitext(lora_on_disk.filename) path, ext = os.path.splitext(lora_on_disk.filename)
if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
alias = name
else:
alias = lora_on_disk.alias
yield { yield {
"name": name, "name": name,
"filename": path, "filename": path,
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename), "search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"), "prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
} }

View File

@ -10,10 +10,9 @@ from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader from modules import devices, modelloader, script_callbacks
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet as net
from modules.shared import opts from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for k, v in model.named_parameters(): for _, v in model.named_parameters():
v.requires_grad = False v.requires_grad = False
model = model.to(device) model = model.to(device)
return model return model
def on_ui_settings():
import gradio as gr
from modules import shared
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
script_callbacks.on_ui_settings(on_ui_settings)

View File

@ -61,7 +61,9 @@ class WMSA(nn.Module):
Returns: Returns:
output: tensor shape [b h w c] output: tensor shape [b h w c]
""" """
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) if self.type != 'W':
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1) h_windows = x.size(1)
w_windows = x.size(2) w_windows = x.size(2)
@ -85,8 +87,9 @@ class WMSA(nn.Module):
output = self.linear(output) output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), if self.type != 'W':
dims=(1, 2)) output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
return output return output
def relative_embedding(self): def relative_embedding(self):
@ -262,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)

View File

@ -1,4 +1,3 @@
import contextlib
import os import os
import numpy as np import numpy as np
@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
img = upscale(img, model) img = upscale(img, model)
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
except: except Exception:
pass pass
return img return img
@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list: for w_idx in w_idx_list:
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
break break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)

View File

@ -644,7 +644,7 @@ class SwinIR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
x = self.check_image_size(x) x = self.check_image_size(x)
self.mean = self.mean.type_as(x) self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range x = (x - self.mean) * self.img_range
@ -844,7 +844,7 @@ class SwinIR(nn.Module):
H, W = self.patches_resolution H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9 flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops() flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers): for layer in self.layers:
flops += layer.flops() flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() flops += self.upsample.flops()

View File

@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
""" """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]): pretrained_window_size=(0, 0)):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
self.register_buffer("attn_mask", attn_mask) self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size): def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
H, W = x_size H, W = x_size
@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask return attn_mask
def forward(self, x, x_size): def forward(self, x, x_size):
H, W = x_size H, W = x_size
@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else: else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows # merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2 flops += H * W * self.dim // 2
return flops return flops
class BasicLayer(nn.Module): class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage. """ A basic Swin Transformer layer for one stage.
@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0) nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
r""" Image to Patch Embedding r""" Image to Patch Embedding
Args: Args:
@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None: if self.norm is not None:
flops += Ho * Wo * self.embed_dim flops += Ho * Wo * self.embed_dim
return flops return flops
class RSTB(nn.Module): class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB). """Residual Swin Transformer Block (RSTB).
@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads, num_heads=num_heads,
window_size=window_size, window_size=window_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop=drop, attn_drop=attn_drop,
drop_path=drop_path, drop_path=drop_path,
norm_layer=norm_layer, norm_layer=norm_layer,
@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m) super(Upsample, self).__init__(*m)
class Upsample_hf(nn.Sequential): class Upsample_hf(nn.Sequential):
"""Upsample module. """Upsample module.
@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3)) m.append(nn.PixelShuffle(3))
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m) super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential): class UpsampleOneStep(nn.Sequential):
@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9 flops = H * W * self.num_feat * 3 * 9
return flops return flops
class Swin2SR(nn.Module): class Swin2SR(nn.Module):
r""" Swin2SR r""" Swin2SR
@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
) )
self.layers.append(layer) self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf': if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList() self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers): for i_layer in range(self.num_layers):
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
) )
self.layers_hf.append(layer) self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features) self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction # build the last conv layer in deep feature extraction
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential( self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1), nn.Conv2d(3, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat) self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf': elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect': elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters) # for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward_features_hf(self, x): def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3]) x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x) x = self.patch_embed(x)
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x) x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before)) x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before) x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf) x_hf = self.conv_before_upsample_hf(x_hf)
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x) x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res) x = x + self.conv_last(res)
x = x / self.img_range + self.mean x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux": if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux return x[:, :, :H*self.upscale, :W*self.upscale], aux
elif self.upsampler == "pixelshuffle_hf": elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
else: else:
return x[:, :, :H*self.upscale, :W*self.upscale] return x[:, :, :H*self.upscale, :W*self.upscale]
@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
H, W = self.patches_resolution H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9 flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops() flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers): for layer in self.layers:
flops += layer.flops() flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops() flops += self.upsample.flops()
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width)) x = torch.randn((1, 3, height, width))
x = model(x) x = model(x)
print(x.shape) print(x.shape)

View File

@ -6,7 +6,7 @@
<ul> <ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a> <a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul> </ul>
<span style="display:none" class='search_term'>{search_term}</span> <span style="display:none" class='search_term{serach_only}'>{search_term}</span>
</div> </div>
<span class='name'>{name}</span> <span class='name'>{name}</span>
<span class='description'>{description}</span> <span class='description'>{description}</span>

View File

@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
var viewportOffset = targetElement.getBoundingClientRect(); var viewportOffset = targetElement.getBoundingClientRect();
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight ) var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
scaledx = targetElement.naturalWidth*viewportscale var scaledx = targetElement.naturalWidth*viewportscale
scaledy = targetElement.naturalHeight*viewportscale var scaledy = targetElement.naturalHeight*viewportscale
cleintRectTop = (viewportOffset.top+window.scrollY) var cleintRectTop = (viewportOffset.top+window.scrollY)
cleintRectLeft = (viewportOffset.left+window.scrollX) var cleintRectLeft = (viewportOffset.left+window.scrollX)
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2) var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2) var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
viewRectTop = cleintRectCentreY-(scaledy/2) var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
viewRectLeft = cleintRectCentreX-(scaledx/2) var arscaledx = currentWidth*arscale
arRectWidth = scaledx var arscaledy = currentHeight*arscale
arRectHeight = scaledy
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight ) var arRectTop = cleintRectCentreY-(arscaledy/2)
arscaledx = currentWidth*arscale var arRectLeft = cleintRectCentreX-(arscaledx/2)
arscaledy = currentHeight*arscale var arRectWidth = arscaledx
var arRectHeight = arscaledy
arRectTop = cleintRectCentreY-(arscaledy/2)
arRectLeft = cleintRectCentreX-(arscaledx/2)
arRectWidth = arscaledx
arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px'; arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px'; arPreviewRect.style.left = arRectLeft+'px';

View File

@ -4,7 +4,7 @@ contextMenuInit = function(){
let menuSpecs = new Map(); let menuSpecs = new Map();
const uid = function(){ const uid = function(){
return Date.now().toString(36) + Math.random().toString(36).substr(2); return Date.now().toString(36) + Math.random().toString(36).substring(2);
} }
function showContextMenu(event,element,menuEntries){ function showContextMenu(event,element,menuEntries){
@ -16,8 +16,7 @@ contextMenuInit = function(){
oldMenu.remove() oldMenu.remove()
} }
let tabButton = uiCurrentTab let baseStyle = window.getComputedStyle(uiCurrentTab)
let baseStyle = window.getComputedStyle(tabButton)
const contextMenu = document.createElement('nav') const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu" contextMenu.id = "context-menu"
@ -36,7 +35,7 @@ contextMenuInit = function(){
menuEntries.forEach(function(entry){ menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a') let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name'] contextMenuEntry.innerHTML = entry['name']
contextMenuEntry.addEventListener("click", function(e) { contextMenuEntry.addEventListener("click", function() {
entry['func'](); entry['func']();
}) })
contextMenuList.append(contextMenuEntry); contextMenuList.append(contextMenuEntry);
@ -63,7 +62,7 @@ contextMenuInit = function(){
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){ function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetElementSelector) var currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){ if(!currentItems){
currentItems = [] currentItems = []
@ -79,7 +78,7 @@ contextMenuInit = function(){
} }
function removeContextMenuOption(uid){ function removeContextMenuOption(uid){
menuSpecs.forEach(function(v,k) { menuSpecs.forEach(function(v) {
let index = -1 let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}}) v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){ if(index>=0){
@ -93,8 +92,7 @@ contextMenuInit = function(){
return; return;
} }
gradioApp().addEventListener("click", function(e) { gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0] if(! e.isTrusted){
if(source.id && source.id.indexOf('check_progress')>-1){
return return
} }
@ -112,7 +110,6 @@ contextMenuInit = function(){
if(e.composedPath()[0].matches(k)){ if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v) showContextMenu(e,e.composedPath()[0],v)
e.preventDefault() e.preventDefault()
return
} }
}) })
}); });

View File

@ -69,8 +69,8 @@ function keyupEditAttention(event){
event.preventDefault(); event.preventDefault();
closeCharacter = ')' var closeCharacter = ')'
delta = opts.keyedit_precision_attention var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){ if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>' closeCharacter = '>'
@ -91,8 +91,8 @@ function keyupEditAttention(event){
selectionEnd += 1; selectionEnd += 1;
} }
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return; if (isNaN(weight)) return;
weight += isPlus ? delta : -delta; weight += isPlus ? delta : -delta;

View File

@ -1,14 +1,14 @@
function extensions_apply(_, _, disable_all){ function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = [] var disable = []
var update = [] var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked) if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7)) disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked) if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7)) update.push(x.name.substring(7))
}) })
restart_reload() restart_reload()
@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
return [JSON.stringify(disable), JSON.stringify(update), disable_all] return [JSON.stringify(disable), JSON.stringify(update), disable_all]
} }
function extensions_check(_, _){ function extensions_check(){
var disable = [] var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked) if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7)) disable.push(x.name.substring(7))
}) })
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
@ -41,7 +41,7 @@ function install_extension_from_index(button, url){
button.disabled = "disabled" button.disabled = "disabled"
button.value = "Installing..." button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea') var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url textarea.value = url
updateInput(textarea) updateInput(textarea)

View File

@ -1,4 +1,3 @@
function setupExtraNetworksForTab(tabname){ function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
tabs.appendChild(search) tabs.appendChild(search)
tabs.appendChild(refresh) tabs.appendChild(refresh)
search.addEventListener("input", function(evt){ var applyFilter = function(){
searchTerm = search.value.toLowerCase() var searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase() var searchOnly = elem.querySelector('.search_only')
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
var visible = text.indexOf(searchTerm) != -1
if(searchOnly && searchTerm.length < 4){
visible = false
}
elem.style.display = visible ? "" : "none"
}) })
}); }
search.addEventListener("input", applyFilter);
applyFilter();
extraNetworksApplyFilter[tabname] = applyFilter;
} }
function applyExtraNetworkFilter(tabname){
setTimeout(extraNetworksApplyFilter[tabname], 1);
}
var extraNetworksApplyFilter = {}
var activePromptTextarea = {}; var activePromptTextarea = {};
function setupExtraNetworks(){ function setupExtraNetworks(){
@ -55,7 +72,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
var partToSearch = m[1] var partToSearch = m[1]
var replaced = false var replaced = false
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){ var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
m = found.match(re_extranet); m = found.match(re_extranet);
if(m[1] == partToSearch){ if(m[1] == partToSearch){
replaced = true; replaced = true;
@ -96,9 +113,9 @@ function saveCardPreview(event, tabname, filename){
} }
function extraNetworksSearchButton(tabs_id, event){ function extraNetworksSearchButton(tabs_id, event){
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea') var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
button = event.target var button = event.target
text = button.classList.contains("search-all") ? "" : button.textContent.trim() var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text searchTextarea.value = text
updateInput(searchTextarea) updateInput(searchTextarea)
@ -133,7 +150,7 @@ function popup(contents){
} }
function extraNetworksShowMetadata(text){ function extraNetworksShowMetadata(text){
elem = document.createElement('pre') var elem = document.createElement('pre')
elem.classList.add('popup-metadata'); elem.classList.add('popup-metadata');
elem.textContent = text; elem.textContent = text;
@ -165,7 +182,7 @@ function requestGet(url, data, handler, errorHandler){
} }
function extraNetworksRequestMetadata(event, extraPage, cardName){ function extraNetworksRequestMetadata(event, extraPage, cardName){
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); } var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){ requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){ if(data && data.metadata){

View File

@ -23,7 +23,7 @@ let modalObserver = new MutationObserver(function(mutations) {
}); });
function attachGalleryListeners(tab_name) { function attachGalleryListeners(tab_name) {
gallery = gradioApp().querySelector('#'+tab_name+'_gallery') var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click()); gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => { gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow

View File

@ -66,8 +66,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.", "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.", "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@ -118,16 +118,18 @@ titles = {
onUiUpdate(function(){ onUiUpdate(function(){
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){ gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
tooltip = titles[span.textContent]; if (span.title) return; // already has a title
if(!tooltip){ let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
tooltip = titles[span.value];
if(!tooltip){
tooltip = localization[titles[span.value]] || titles[span.value];
} }
if(!tooltip){ if(!tooltip){
for (const c of span.classList) { for (const c of span.classList) {
if (c in titles) { if (c in titles) {
tooltip = titles[c]; tooltip = localization[titles[c]] || titles[c];
break; break;
} }
} }
@ -142,7 +144,7 @@ onUiUpdate(function(){
if (select.onchange != null) return; if (select.onchange != null) return;
select.onchange = function(){ select.onchange = function(){
select.title = titles[select.value] || ""; select.title = localization[titles[select.value]] || titles[select.value] || "";
} }
}) })
}) })

View File

@ -1,16 +1,12 @@
function setInactive(elem, inactive){
if(inactive){
elem.classList.add('inactive')
} else{
elem.classList.remove('inactive')
}
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') function setInactive(elem, inactive){
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') elem.classList.toggle('inactive', !!inactive)
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') }
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""

View File

@ -2,11 +2,10 @@
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
* @see https://github.com/gradio-app/gradio/issues/1721 * @see https://github.com/gradio-app/gradio/issues/1721
*/ */
window.addEventListener( 'resize', () => imageMaskResize());
function imageMaskResize() { function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
if ( ! canvases.length ) { if ( ! canvases.length ) {
canvases_fixed = false; canvases_fixed = false; // TODO: this is unused..?
window.removeEventListener( 'resize', imageMaskResize ); window.removeEventListener( 'resize', imageMaskResize );
return; return;
} }
@ -15,7 +14,7 @@ function imageMaskResize() {
const previewImage = wrapper.previousElementSibling; const previewImage = wrapper.previousElementSibling;
if ( ! previewImage.complete ) { if ( ! previewImage.complete ) {
previewImage.addEventListener( 'load', () => imageMaskResize()); previewImage.addEventListener( 'load', imageMaskResize);
return; return;
} }
@ -24,7 +23,6 @@ function imageMaskResize() {
const nw = previewImage.naturalWidth; const nw = previewImage.naturalWidth;
const nh = previewImage.naturalHeight; const nh = previewImage.naturalHeight;
const portrait = nh > nw; const portrait = nh > nw;
const factor = portrait;
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw); const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh); const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
@ -40,6 +38,7 @@ function imageMaskResize() {
c.style.maxHeight = '100%'; c.style.maxHeight = '100%';
c.style.objectFit = 'contain'; c.style.objectFit = 'contain';
}); });
} }
onUiUpdate(() => imageMaskResize()); onUiUpdate(imageMaskResize);
window.addEventListener( 'resize', imageMaskResize);

View File

@ -1,7 +1,6 @@
window.onload = (function(){ window.onload = (function(){
window.addEventListener('drop', e => { window.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const idx = selected_gallery_index();
if (target.placeholder.indexOf("Prompt") == -1) return; if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";

View File

@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
}) })
if (result != -1) { if (result != -1) {
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)] var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
nextButton.click() nextButton.click()
const modalImage = gradioApp().getElementById("modalImage"); const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal"); const modal = gradioApp().getElementById("lightboxModal");
@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
} }
function modalZoomSet(modalImage, enable) { function modalZoomSet(modalImage, enable) {
if (enable) { if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
modalImage.classList.add('modalImageFullscreen');
} else {
modalImage.classList.remove('modalImageFullscreen');
}
} }
function modalZoomToggle(event) { function modalZoomToggle(event) {
modalImage = gradioApp().getElementById("modalImage"); var modalImage = gradioApp().getElementById("modalImage");
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen')) modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
event.stopPropagation() event.stopPropagation()
} }
@ -179,7 +175,7 @@ function galleryImageHandler(e) {
} }
onUiUpdate(function() { onUiUpdate(function() {
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img') var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
if (fullImg_preview != null) { if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox); fullImg_preview.forEach(setupImageForLightbox);
} }

View File

@ -1,36 +1,57 @@
let delay = 350//ms window.addEventListener('gamepadconnected', (e) => {
window.addEventListener('gamepadconnected', (e) => { const index = e.gamepad.index;
console.log("Gamepad connected!") let isWaiting = false;
const gamepad = e.gamepad; setInterval(async () => {
setInterval(() => { if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const xValue = gamepad.axes[0].toFixed(2); const gamepad = navigator.getGamepads()[index];
if (xValue < -0.3) { const xValue = gamepad.axes[0];
modalPrevImage(e); if (xValue <= -0.3) {
} else if (xValue > 0.3) {
modalNextImage(e);
}
}, delay);
});
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
modalPrevImage(e); modalPrevImage(e);
} else if (e.deltaX >= 0.6) { isWaiting = true;
} else if (xValue >= 0.3) {
modalNextImage(e); modalNextImage(e);
isWaiting = true;
} }
if (isWaiting) {
await sleepUntil(() => {
const xValue = navigator.getGamepads()[index].axes[0]
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
}, opts.js_modal_lightbox_gamepad_repeat);
isWaiting = false;
}
}, 10);
});
setTimeout(() => { /*
isScrolling = false; Primarily for vr controller type pointer devices.
}, delay); I use the wheel event because there's currently no way to do it properly with web xr.
}); */
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
modalPrevImage(e);
} else if (e.deltaX >= 0.6) {
modalNextImage(e);
}
setTimeout(() => {
isScrolling = false;
}, opts.js_modal_lightbox_gamepad_repeat);
});
function sleepUntil(f, timeout) {
return new Promise((resolve) => {
const timeStart = new Date();
const wait = setInterval(function() {
if (f() || new Date() - timeStart > timeout) {
clearInterval(wait);
resolve();
}
}, 20);
});
}

View File

@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
original_lines = {} original_lines = {}
translated_lines = {} translated_lines = {}
function hasLocalization() {
return window.localization && Object.keys(window.localization).length > 0;
}
function textNodesUnder(el){ function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n); while(n=walk.nextNode()) a.push(n);
@ -35,11 +39,11 @@ function canBeTranslated(node, text){
if(! text) return false; if(! text) return false;
if(! node.parentElement) return false; if(! node.parentElement) return false;
parentType = node.parentElement.nodeName var parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){ if (parentType=='OPTION' || parentType=='SPAN'){
pnode = node var pnode = node
for(var level=0; level<4; level++){ for(var level=0; level<4; level++){
pnode = pnode.parentElement pnode = pnode.parentElement
if(! pnode) break; if(! pnode) break;
@ -69,7 +73,7 @@ function getTranslation(text){
} }
function processTextNode(node){ function processTextNode(node){
text = node.textContent.trim() var text = node.textContent.trim()
if(! canBeTranslated(node, text)) return if(! canBeTranslated(node, text)) return
@ -105,30 +109,52 @@ function processNode(node){
} }
function dumpTranslations(){ function dumpTranslations(){
dumped = {} if(!hasLocalization()) {
// If we don't have any localization,
// we will not have traversed the app to find
// original_lines, so do that now.
processNode(gradioApp());
}
var dumped = {}
if (localization.rtl) { if (localization.rtl) {
dumped.rtl = true dumped.rtl = true;
} }
Object.keys(original_lines).forEach(function(text){ for (const text in original_lines) {
if(dumped[text] !== undefined) return if(dumped[text] !== undefined) continue;
dumped[text] = localization[text] || text;
}
dumped[text] = localization[text] || text return dumped;
})
return dumped
} }
onUiUpdate(function(m){ function download_localization() {
m.forEach(function(mutation){ var text = JSON.stringify(dumpTranslations(), null, 4)
mutation.addedNodes.forEach(function(node){
processNode(node)
})
});
})
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
element.setAttribute('download', "localization.json");
element.style.display = 'none';
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}
document.addEventListener("DOMContentLoaded", function () {
if (!hasLocalization()) {
return;
}
onUiUpdate(function (m) {
m.forEach(function (mutation) {
mutation.addedNodes.forEach(function (node) {
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function() {
processNode(gradioApp()) processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left, if (localization.rtl) { // if the language is from right to left,
@ -149,17 +175,3 @@ document.addEventListener("DOMContentLoaded", function() {
})).observe(gradioApp(), { childList: true }); })).observe(gradioApp(), { childList: true });
} }
}) })
function download_localization() {
text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
element.setAttribute('download', "localization.json");
element.style.display = 'none';
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}

View File

@ -2,15 +2,15 @@
let lastHeadImg = null; let lastHeadImg = null;
notificationButton = null let notificationButton = null;
onUiUpdate(function(){ onUiUpdate(function(){
if(notificationButton == null){ if(notificationButton == null){
notificationButton = gradioApp().getElementById('request_notifications') notificationButton = gradioApp().getElementById('request_notifications')
if(notificationButton != null){ if(notificationButton != null){
notificationButton.addEventListener('click', function (evt) { notificationButton.addEventListener('click', () => {
Notification.requestPermission(); void Notification.requestPermission();
},true); },true);
} }
} }

View File

@ -1,16 +1,15 @@
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
function rememberGallerySelection(id_gallery){ function rememberGallerySelection(){
} }
function getGallerySelectedIndex(id_gallery){ function getGallerySelectedIndex(){
} }
function request(url, data, handler, errorHandler){ function request(url, data, handler, errorHandler){
var xhr = new XMLHttpRequest(); var xhr = new XMLHttpRequest();
var url = url;
xhr.open("POST", url, true); xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json"); xhr.setRequestHeader("Content-Type", "application/json");
xhr.onreadystatechange = function () { xhr.onreadystatechange = function () {
@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.style.width = rect.width + "px"; divProgress.style.width = rect.width + "px";
} }
progressText = "" let progressText = ""
divInner.style.width = ((res.progress || 0) * 100.0) + '%' divInner.style.width = ((res.progress || 0) * 100.0) + '%'
divInner.style.background = res.progress ? "" : "transparent" divInner.style.background = res.progress ? "" : "transparent"

View File

@ -1,7 +1,7 @@
// various functions for interaction with ui.py not large enough to warrant putting them in separate files // various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){ function set_theme(theme){
gradioURL = window.location.href var gradioURL = window.location.href
if (!gradioURL.includes('?__theme=')) { if (!gradioURL.includes('?__theme=')) {
window.location.replace(gradioURL + '?__theme=' + theme); window.location.replace(gradioURL + '?__theme=' + theme);
} }
@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
return [gallery[0]]; return [gallery[0]];
} }
index = selected_gallery_index() var index = selected_gallery_index()
if (index < 0 || index >= gallery.length){ if (index < 0 || index >= gallery.length){
// Use the first image in the gallery as the default // Use the first image in the gallery as the default
@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
} }
function args_to_array(args){ function args_to_array(args){
res = [] var res = []
for(var i=0;i<args.length;i++){ for(var i=0;i<args.length;i++){
res.push(args[i]) res.push(args[i])
} }
@ -138,7 +138,7 @@ function get_img2img_tab_index() {
} }
function create_submit_args(args){ function create_submit_args(args){
res = [] var res = []
for(var i=0;i<args.length;i++){ for(var i=0;i<args.length;i++){
res.push(args[i]) res.push(args[i])
} }
@ -160,7 +160,7 @@ function showSubmitButtons(tabname, show){
} }
function showRestoreProgressButton(tabname, show){ function showRestoreProgressButton(tabname, show){
button = gradioApp().getElementById(tabname + "_restore_progress") var button = gradioApp().getElementById(tabname + "_restore_progress")
if(! button) return if(! button) return
button.style.display = show ? "flex" : "none" button.style.display = show ? "flex" : "none"
@ -207,8 +207,9 @@ function submit_img2img(){
return res return res
} }
function restoreProgressTxt2img(x){ function restoreProgressTxt2img(){
showRestoreProgressButton("txt2img", false) showRestoreProgressButton("txt2img", false)
var id = localStorage.getItem("txt2img_task_id")
id = localStorage.getItem("txt2img_task_id") id = localStorage.getItem("txt2img_task_id")
@ -220,10 +221,11 @@ function restoreProgressTxt2img(x){
return id return id
} }
function restoreProgressImg2img(x){
showRestoreProgressButton("img2img", false)
id = localStorage.getItem("img2img_task_id") function restoreProgressImg2img(){
showRestoreProgressButton("img2img", false)
var id = localStorage.getItem("img2img_task_id")
if(id) { if(id) {
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
@ -252,7 +254,7 @@ function modelmerger(){
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') var name_ = prompt('Style name:')
return [name_, prompt_text, negative_prompt_text] return [name_, prompt_text, negative_prompt_text]
} }
@ -287,11 +289,11 @@ function recalculate_prompts_img2img(){
} }
opts = {} var opts = {}
onUiUpdate(function(){ onUiUpdate(function(){
if(Object.keys(opts).length != 0) return; if(Object.keys(opts).length != 0) return;
json_elem = gradioApp().getElementById('settings_json') var json_elem = gradioApp().getElementById('settings_json')
if(json_elem == null) return; if(json_elem == null) return;
var textarea = json_elem.querySelector('textarea') var textarea = json_elem.querySelector('textarea')
@ -340,12 +342,15 @@ onUiUpdate(function(){
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button') registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button') registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
show_all_pages = gradioApp().getElementById('settings_show_all_pages') var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div') var settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){ if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages) settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){ show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){ gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
if(elem.id == "settings_tab_licenses")
return;
elem.style.display = "block"; elem.style.display = "block";
}) })
} }
@ -353,9 +358,9 @@ onUiUpdate(function(){
}) })
onOptionsChanged(function(){ onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash') var elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || "" var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
shorthash = sd_checkpoint_hash.substr(0,10) var shorthash = sd_checkpoint_hash.substring(0,10)
if(elem && elem.textContent != shorthash){ if(elem && elem.textContent != shorthash){
elem.textContent = shorthash elem.textContent = shorthash
@ -390,7 +395,16 @@ function update_token_counter(button_id) {
function restart_reload(){ function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
var requestPing = function(){
requestGet("./internal/ping", {}, function(data){
location.reload();
}, function(){
setTimeout(requestPing, 500);
})
}
setTimeout(requestPing, 2000);
return [] return []
} }

View File

@ -0,0 +1,62 @@
// various hints and extra info for the settings tab
settingsHintsSetup = false
onOptionsChanged(function(){
if(settingsHintsSetup) return
settingsHintsSetup = true
gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div){
var name = div.id.substr(8)
var commentBefore = opts._comments_before[name]
var commentAfter = opts._comments_after[name]
if(! commentBefore && !commentAfter) return
var span = null
if(div.classList.contains('gradio-checkbox')) span = div.querySelector('label span')
else if(div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild
else if(div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild
else span = div.querySelector('label span').firstChild
if(!span) return
if(commentBefore){
var comment = document.createElement('DIV')
comment.className = 'settings-comment'
comment.innerHTML = commentBefore
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
span.parentElement.insertBefore(comment, span)
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
}
if(commentAfter){
var comment = document.createElement('DIV')
comment.className = 'settings-comment'
comment.innerHTML = commentAfter
span.parentElement.insertBefore(comment, span.nextSibling)
span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling)
}
})
})
function settingsHintsShowQuicksettings(){
requestGet("./internal/quicksettings-hint", {}, function(data){
var table = document.createElement('table')
table.className = 'settings-value-table'
data.forEach(function(obj){
var tr = document.createElement('tr')
var td = document.createElement('td')
td.textContent = obj.name
tr.appendChild(td)
var td = document.createElement('td')
td.textContent = obj.label
tr.appendChild(td)
table.appendChild(tr)
})
popup(table);
})
}

105
launch.py
View File

@ -3,24 +3,23 @@ import subprocess
import os import os
import sys import sys
import importlib.util import importlib.util
import shlex
import platform import platform
import json import json
from functools import lru_cache
from modules import cmd_args from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir from modules.paths_internal import script_path, extensions_dir
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
args, _ = cmd_args.parser.parse_known_args() args, _ = cmd_args.parser.parse_known_args()
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
dir_repos = "repositories" dir_repos = "repositories"
# Whether to default to printing command output
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ: if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@ -56,51 +55,52 @@ Use --skip-python-version-check to suppress this warning.
""") """)
@lru_cache()
def commit_hash(): def commit_hash():
global stored_commit_hash
if stored_commit_hash is not None:
return stored_commit_hash
try: try:
stored_commit_hash = run(f"{git} rev-parse HEAD").strip() return subprocess.check_output(f"{git} rev-parse HEAD", encoding='utf8').strip()
except Exception: except Exception:
stored_commit_hash = "<none>" return "<none>"
return stored_commit_hash
def run(command, desc=None, errdesc=None, custom_env=None, live=False): @lru_cache()
def git_tag():
try:
return subprocess.check_output(f"{git} describe --tags", encoding='utf8').strip()
except Exception:
return "<none>"
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
if desc is not None: if desc is not None:
print(desc) print(desc)
if live: run_kwargs = {
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env) "args": command,
if result.returncode != 0: "shell": True,
raise RuntimeError(f"""{errdesc or 'Error running command'}. "env": os.environ if custom_env is None else custom_env,
Command: {command} "encoding": 'utf8',
Error code: {result.returncode}""") "errors": 'ignore',
}
return "" if not live:
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) result = subprocess.run(**run_kwargs)
if result.returncode != 0: if result.returncode != 0:
error_bits = [
f"{errdesc or 'Error running command'}.",
f"Command: {command}",
f"Error code: {result.returncode}",
]
if result.stdout:
error_bits.append(f"stdout: {result.stdout}")
if result.stderr:
error_bits.append(f"stderr: {result.stderr}")
raise RuntimeError("\n".join(error_bits))
message = f"""{errdesc or 'Error running command'}. return (result.stdout or "")
Command: {command}
Error code: {result.returncode}
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
"""
raise RuntimeError(message)
return result.stdout.decode(encoding="utf8", errors="ignore")
def check_run(command):
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
return result.returncode == 0
def is_installed(package): def is_installed(package):
@ -116,11 +116,7 @@ def repo_dir(name):
return os.path.join(script_path, dir_repos, name) return os.path.join(script_path, dir_repos, name)
def run_python(code, desc=None, errdesc=None): def run_pip(command, desc=None, live=default_command_live):
return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(command, desc=None, live=False):
if args.skip_install: if args.skip_install:
return return
@ -128,8 +124,9 @@ def run_pip(command, desc=None, live=False):
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run_python(code): def check_run_python(code: str) -> bool:
return check_run(f'"{python}" -c "{code}"') result = subprocess.run([python, "-c", code], capture_output=True, shell=True)
return result.returncode == 0
def git_clone(url, dir, name, commithash=None): def git_clone(url, dir, name, commithash=None):
@ -222,13 +219,14 @@ def run_extensions_installers(settings_file):
def prepare_environment(): def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118") torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
@ -246,15 +244,20 @@ def prepare_environment():
check_python_version() check_python_version()
commit = commit_hash() commit = commit_hash()
tag = git_tag()
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Version: {tag}")
print(f"Commit hash: {commit}") print(f"Commit hash: {commit}")
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not args.skip_torch_cuda_test: if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") raise RuntimeError(
'Torch is not able to use GPU; '
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
)
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
@ -302,7 +305,7 @@ def prepare_environment():
if args.update_all_extensions: if args.update_all_extensions:
git_pull_recursive(extensions_dir) git_pull_recursive(extensions_dir)
if "--exit" in sys.argv: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)

BIN
modules/Roboto-Regular.ttf Normal file

Binary file not shown.

View File

@ -15,7 +15,8 @@ from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import * from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess from modules.textual_inversion.preprocess import preprocess
@ -25,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from modules import devices from modules import devices
from typing import List from typing import Dict, List, Any
import piexif import piexif
import piexif.helper import piexif.helper
def upscaler_to_index(name: str): def upscaler_to_index(name: str):
try: try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except: except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
def script_name_to_index(name, scripts): def script_name_to_index(name, scripts):
try: try:
return [script.title().lower() for script in scripts].index(name.lower()) return [script.title().lower() for script in scripts].index(name.lower())
except: except Exception as e:
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
def validate_sampler_name(name): def validate_sampler_name(name):
config = sd_samplers.all_samplers_map.get(name, None) config = sd_samplers.all_samplers_map.get(name, None)
@ -48,20 +52,23 @@ def validate_sampler_name(name):
return name return name
def setUpscalers(req: dict): def setUpscalers(req: dict):
reqDict = vars(req) reqDict = vars(req)
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict return reqDict
def decode_base64_to_image(encoding): def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"): if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1] encoding = encoding.split(";")[1].split(",")[1]
try: try:
image = Image.open(BytesIO(base64.b64decode(encoding))) image = Image.open(BytesIO(base64.b64decode(encoding)))
return image return image
except Exception as err: except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") raise HTTPException(status_code=500, detail="Invalid encoded image") from e
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
@ -92,6 +99,7 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data) return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI): def api_middleware(app: FastAPI):
rich_available = True rich_available = True
try: try:
@ -99,7 +107,7 @@ def api_middleware(app: FastAPI):
import starlette # importing just so it can be placed on silent list import starlette # importing just so it can be placed on silent list
from rich.console import Console from rich.console import Console
console = Console() console = Console()
except: except Exception:
import traceback import traceback
rich_available = False rich_available = False
@ -157,7 +165,7 @@ def api_middleware(app: FastAPI):
class Api: class Api:
def __init__(self, app: FastAPI, queue_lock: Lock): def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth: if shared.cmd_opts.api_auth:
self.credentials = dict() self.credentials = {}
for auth in shared.cmd_opts.api_auth.split(","): for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":") user, password = auth.split(":")
self.credentials[user] = password self.credentials[user] = password
@ -166,36 +174,36 @@ class Api:
self.app = app self.app = app
self.queue_lock = queue_lock self.queue_lock = queue_lock
api_middleware(self.app) api_middleware(self.app)
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel) self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel) self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.default_script_arg_txt2img = [] self.default_script_arg_txt2img = []
self.default_script_arg_img2img = [] self.default_script_arg_img2img = []
@ -219,17 +227,17 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx] script = script_runner.selectable_scripts[script_idx]
return script, script_idx return script, script_idx
def get_scripts_list(self): def get_scripts_list(self):
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
return ScriptsList(txt2img = t2ilist, img2img = i2ilist) return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
def get_script(self, script_name, script_runner): def get_script(self, script_name, script_runner):
if script_name is None or script_name == "": if script_name is None or script_name == "":
return None, None return None, None
script_idx = script_name_to_index(script_name, script_runner.scripts) script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx] return script_runner.scripts[script_idx]
@ -264,11 +272,11 @@ class Api:
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
for alwayson_script_name in request.alwayson_scripts.keys(): for alwayson_script_name in request.alwayson_scripts.keys():
alwayson_script = self.get_script(alwayson_script_name, script_runner) alwayson_script = self.get_script(alwayson_script_name, script_runner)
if alwayson_script == None: if alwayson_script is None:
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found") raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
# Selectable script in always on script param check # Selectable script in always on script param check
if alwayson_script.alwayson == False: if alwayson_script.alwayson is False:
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params") raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests # always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]: if "args" in request.alwayson_scripts[alwayson_script_name]:
# min between arg length in scriptrunner and arg length in the request # min between arg length in scriptrunner and arg length in the request
@ -276,7 +284,7 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx] script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
script_runner = scripts.scripts_txt2img script_runner = scripts.scripts_txt2img
if not script_runner.scripts: if not script_runner.scripts:
script_runner.initialize_scripts(False) script_runner.initialize_scripts(False)
@ -310,7 +318,7 @@ class Api:
p.outpath_samples = opts.outdir_txt2img_samples p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin() shared.state.begin()
if selectable_scripts != None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
else: else:
@ -320,9 +328,9 @@ class Api:
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
@ -367,7 +375,7 @@ class Api:
p.outpath_samples = opts.outdir_img2img_samples p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin() shared.state.begin()
if selectable_scripts != None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
else: else:
@ -381,9 +389,9 @@ class Api:
img2imgreq.init_images = None img2imgreq.init_images = None
img2imgreq.mask = None img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest): def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
reqDict = setUpscalers(req) reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image']) reqDict['image'] = decode_base64_to_image(reqDict['image'])
@ -391,9 +399,9 @@ class Api:
with self.queue_lock: with self.queue_lock:
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest): def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
reqDict = setUpscalers(req) reqDict = setUpscalers(req)
image_list = reqDict.pop('imageList', []) image_list = reqDict.pop('imageList', [])
@ -402,15 +410,15 @@ class Api:
with self.queue_lock: with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict) result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: PNGInfoRequest): def pnginfoapi(self, req: models.PNGInfoRequest):
if(not req.image.strip()): if(not req.image.strip()):
return PNGInfoResponse(info="") return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip()) image = decode_base64_to_image(req.image.strip())
if image is None: if image is None:
return PNGInfoResponse(info="") return models.PNGInfoResponse(info="")
geninfo, items = images.read_info_from_image(image) geninfo, items = images.read_info_from_image(image)
if geninfo is None: if geninfo is None:
@ -418,13 +426,13 @@ class Api:
items = {**{'parameters': geninfo}, **items} items = {**{'parameters': geninfo}, **items}
return PNGInfoResponse(info=geninfo, items=items) return models.PNGInfoResponse(info=geninfo, items=items)
def progressapi(self, req: ProgressRequest = Depends()): def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py # copy from check_progress_call of ui.py
if shared.state.job_count == 0: if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
# avoid dividing zero # avoid dividing zero
progress = 0.01 progress = 0.01
@ -446,9 +454,9 @@ class Api:
if shared.state.current_image and not req.skip_current_image: if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image) current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
def interrogateapi(self, interrogatereq: InterrogateRequest): def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image image_b64 = interrogatereq.image
if image_b64 is None: if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found") raise HTTPException(status_code=404, detail="Image not found")
@ -465,7 +473,7 @@ class Api:
else: else:
raise HTTPException(status_code=404, detail="Model not found") raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed) return models.InterrogateResponse(caption=processed)
def interruptapi(self): def interruptapi(self):
shared.state.interrupt() shared.state.interrupt()
@ -570,36 +578,36 @@ class Api:
filename = create_embedding(**args) # create empty embedding filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end() shared.state.end()
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename)) return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return TrainResponse(info = "create embedding error: {error}".format(error = e)) return models.TrainResponse(info=f"create embedding error: {e}")
def create_hypernetwork(self, args: dict): def create_hypernetwork(self, args: dict):
try: try:
shared.state.begin() shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding filename = create_hypernetwork(**args) # create empty embedding
shared.state.end() shared.state.end()
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename)) return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e)) return models.TrainResponse(info=f"create hypernetwork error: {e}")
def preprocess(self, args: dict): def preprocess(self, args: dict):
try: try:
shared.state.begin() shared.state.begin()
preprocess(**args) # quick operation unless blip/booru interrogation is enabled preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end() shared.state.end()
return PreprocessResponse(info = 'preprocess complete') return models.PreprocessResponse(info = 'preprocess complete')
except KeyError as e: except KeyError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e)) return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e: except AssertionError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = "preprocess error: {error}".format(error = e)) return models.PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e: except FileNotFoundError as e:
shared.state.end() shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e)) return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
@ -617,10 +625,10 @@ class Api:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except AssertionError as msg:
shared.state.end() shared.state.end()
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg)) return models.TrainResponse(info=f"train embedding error: {msg}")
def train_hypernetwork(self, args: dict): def train_hypernetwork(self, args: dict):
try: try:
@ -641,14 +649,15 @@ class Api:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except AssertionError:
shared.state.end() shared.state.end()
return TrainResponse(info="train embedding error: {error}".format(error=error)) return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self): def get_memory(self):
try: try:
import os, psutil import os
import psutil
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
@ -675,10 +684,10 @@ class Api:
'events': warnings, 'events': warnings,
} }
else: else:
cuda = { 'error': 'unavailable' } cuda = {'error': 'unavailable'}
except Exception as err: except Exception as err:
cuda = { 'error': f'{err}' } cuda = {'error': f'{err}'}
return MemoryResponse(ram = ram, cuda = cuda) return models.MemoryResponse(ram=ram, cuda=cuda)
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)

View File

@ -223,8 +223,9 @@ for key in _options:
if(_options[key].dest != 'help'): if(_options[key].dest != 'help'):
flag = _options[key] flag = _options[key]
_type = str _type = str
if _options[key].default is not None: _type = type(_options[key].default) if _options[key].default is not None:
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))}) _type = type(_options[key].default)
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags) FlagsModel = create_model("Flags", **flags)
@ -288,4 +289,4 @@ class MemoryResponse(BaseModel):
class ScriptsList(BaseModel): class ScriptsList(BaseModel):
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")

View File

@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
max_debug_str_len = 131072 # (1024*1024)/8 max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr) print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}" argStr = f"Arguments: {args} {kwargs}"
print(argStr[:max_debug_str_len], file=sys.stderr) print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len: if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
if extra_outputs_array is None: if extra_outputs_array is None:
extra_outputs_array = [None, ''] extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"] error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
shared.state.skipped = False shared.state.skipped = False
shared.state.interrupted = False shared.state.interrupted = False

View File

@ -1,6 +1,6 @@
import argparse import argparse
import os import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -102,3 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')

View File

@ -1,14 +1,12 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
import math import math
import numpy as np
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, List from typing import Optional
from modules.codeformer.vqgan_arch import * from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
def calc_mean_std(feat, eps=1e-5): def calc_mean_std(feat, eps=1e-5):
@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): query_pos: Optional[Tensor] = None):
# self attention # self attention
tgt2 = self.norm1(tgt) tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos) q = k = self.with_pos_embed(tgt2, query_pos)
@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder): class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9, def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256, codebook_size=1024, latent_size=256,
connect_list=['32', '64', '128', '256'], connect_list=('32', '64', '128', '256'),
fix_modules=['quantize','generator']): fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
if fix_modules is not None: if fix_modules is not None:
@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd) self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer # transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)]) for _ in range(self.n_layers)])
# logits_predict head # logits_predict head
self.idx_pred_layer = nn.Sequential( self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False)) nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = { self.channels = {
'16': 512, '16': 512,
'32': 256, '32': 256,
@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {} enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks): for i, block in enumerate(self.encoder.blocks):
x = block(x) x = block(x)
if i in out_list: if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone() enc_feat_dict[str(x.shape[-1])] = x.clone()
@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks): for i, block in enumerate(self.generator.blocks):
x = block(x) x = block(x)
if i in fuse_list: # fuse after i-th block if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1]) f_size = str(x.shape[-1])
if w>0: if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x out = x
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat return out, logits, lq_feat

View File

@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
''' '''
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels): def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script @torch.jit.script
def swish(x): def swish(x):
@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention # compute attention
b, c, h, w = q.shape b, c, h, w = q.shape
q = q.reshape(b, c, h*w) q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1) q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w) k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k) w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2) w_ = F.softmax(w_, dim=2)
# attend to values # attend to values
v = v.reshape(b, c, h*w) v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1) w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_) h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w) h_ = h_.reshape(b, c, h, w)
@ -272,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
x = block(x) x = block(x)
return x return x
class Generator(nn.Module): class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__() super().__init__()
self.nf = nf self.nf = nf
self.ch_mult = ch_mult self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult) self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks self.num_res_blocks = res_blocks
self.resolution = img_size self.resolution = img_size
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim self.in_channels = emb_dim
self.out_channels = 3 self.out_channels = 3
@ -317,29 +315,29 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
x = block(x) x = block(x)
return x return x
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module): class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__() super().__init__()
logger = get_root_logger() logger = get_root_logger()
self.in_channels = 3 self.in_channels = 3
self.nf = nf self.nf = nf
self.n_blocks = res_blocks self.n_blocks = res_blocks
self.codebook_size = codebook_size self.codebook_size = codebook_size
self.embed_dim = emb_dim self.embed_dim = emb_dim
self.ch_mult = ch_mult self.ch_mult = ch_mult
self.resolution = img_size self.resolution = img_size
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions or [16]
self.quantizer_type = quantizer self.quantizer_type = quantizer
self.encoder = Encoder( self.encoder = Encoder(
self.in_channels, self.in_channels,
@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight self.kl_weight
) )
self.generator = Generator( self.generator = Generator(
self.nf, self.nf,
self.embed_dim, self.embed_dim,
self.ch_mult, self.ch_mult,
self.n_blocks, self.n_blocks,
self.resolution, self.resolution,
self.attn_resolutions self.attn_resolutions
) )
@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!') raise ValueError('Wrong params!')
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x)

View File

@ -33,11 +33,9 @@ def setup_model(dirname):
try: try:
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer from modules.codeformer.codeformer_arch import CodeFormer
from basicsr.utils.download_util import load_file_from_url from basicsr.utils import img2tensor, tensor2img
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts
net_class = CodeFormer net_class = CodeFormer
@ -96,7 +94,7 @@ def setup_model(dirname):
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face() self.face_helper.align_warp_face()
for idx, cropped_face in enumerate(self.face_helper.cropped_faces): for cropped_face in self.face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

View File

@ -14,7 +14,7 @@ from collections import OrderedDict
import git import git
from modules import shared, extensions from modules import shared, extensions
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir from modules.paths_internal import script_path, config_states_dir
all_config_states = OrderedDict() all_config_states = OrderedDict()
@ -35,7 +35,7 @@ def list_config_states():
j["filepath"] = path j["filepath"] = path
config_states.append(j) config_states.append(j)
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)) config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
for cs in config_states: for cs in config_states:
timestamp = time.asctime(time.gmtime(cs["created_at"])) timestamp = time.asctime(time.gmtime(cs["created_at"]))

View File

@ -2,7 +2,6 @@ import os
import re import re
import torch import torch
from PIL import Image
import numpy as np import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared from modules import modelloader, paths, deepbooru_model, devices, images, shared
@ -79,7 +78,7 @@ class DeepDanbooru:
res = [] res = []
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")]) filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
for tag in [x for x in tags if x not in filtertags]: for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag] probability = probability_dict[tag]

View File

@ -65,7 +65,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]): if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True

View File

@ -6,7 +6,7 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
@ -16,9 +16,7 @@ def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer # this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict: if 'conv_first.weight' in state_dict:
crt_net = {} crt_net = {}
items = [] items = list(state_dict)
for k, v in state_dict.items():
items.append(k)
crt_net['model.0.weight'] = state_dict['conv_first.weight'] crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias'] crt_net['model.0.bias'] = state_dict['conv_first.bias']
@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0 re8x = 0
crt_net = {} crt_net = {}
items = [] items = list(state_dict)
for k, v in state_dict.items():
items.append(k)
crt_net['model.0.weight'] = state_dict['conv_first.weight'] crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias'] crt_net['model.0.bias'] = state_dict['conv_first.bias']
@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
if "http" in path: if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, filename = load_file_from_url(
file_name="%s.pth" % self.model_name, url=self.model_url,
progress=True) model_dir=self.model_path,
file_name=f"{self.model_name}.pth",
progress=True,
)
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None: if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename)) print(f"Unable to load {self.model_path} from {filename}")
return None return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

View File

@ -2,7 +2,6 @@
from collections import OrderedDict from collections import OrderedDict
import math import math
import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
elif upsample_mode == 'pixelshuffle': elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block upsample_block = pixelshuffle_block
else: else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
if upscale == 3: if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else: else:
@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
Modified options that can be used: Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718 - "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957 - "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo} {Rakotonirina} and A. {Rasoanaivo}
""" """
@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise x = x + sampled_noise
return x return x
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@ -261,10 +260,10 @@ class Upsample(nn.Module):
def extra_repr(self): def extra_repr(self):
if self.scale_factor is not None: if self.scale_factor is not None:
info = 'scale_factor=' + str(self.scale_factor) info = f'scale_factor={self.scale_factor}'
else: else:
info = 'size=' + str(self.size) info = f'size={self.size}'
info += ', mode=' + self.mode info += f', mode={self.mode}'
return info return info
@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
elif act_type == 'sigmoid': # [0, 1] range output elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid() layer = nn.Sigmoid()
else: else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type)) raise NotImplementedError(f'activation layer [{act_type}] is not found')
return layer return layer
@ -372,7 +371,7 @@ def norm(norm_type, nc):
elif norm_type == 'none': elif norm_type == 'none':
def norm_layer(x): return Identity() def norm_layer(x): return Identity()
else: else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type)) raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return layer return layer
@ -388,7 +387,7 @@ def pad(pad_type, padding):
elif pad_type == 'zero': elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding) layer = nn.ZeroPad2d(padding)
else: else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type)) raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
return layer return layer
@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False): spectral_norm=False):
""" Conv layer with padding, normalization, activation """ """ Conv layer with padding, normalization, activation """
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode) assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
padding = get_valid_padding(kernel_size, dilation) padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0 padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D': if convtype=='PartialConv2D':
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups) dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D': elif convtype=='DeformConv2D':
from torchvision.ops import DeformConv2d # not tested
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups) dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D': elif convtype=='Conv3D':

View File

@ -3,11 +3,10 @@ import sys
import traceback import traceback
import time import time
from datetime import datetime
import git import git
from modules import shared from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
extensions = [] extensions = []

View File

@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call """call deactivate for extra networks in extra_network_data in specified order, then call
deactivate for all remaining registered networks""" deactivate for all remaining registered networks"""
for extra_network_name, extra_network_args in extra_network_data.items(): for extra_network_name in extra_network_data:
extra_network = extra_network_registry.get(extra_network_name, None) extra_network = extra_network_registry.get(extra_network_name, None)
if extra_network is None: if extra_network is None:
continue continue

View File

@ -1,4 +1,4 @@
from modules import extra_networks, shared, extra_networks from modules import extra_networks, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
@ -10,7 +10,8 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
additional = shared.opts.sd_hypernetwork additional = shared.opts.sd_hypernetwork
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = [] names = []

View File

@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_instruct_pix2pix_model = False result_is_instruct_pix2pix_model = False
if theta_func2: if theta_func2:
shared.state.textinfo = f"Loading B" shared.state.textinfo = "Loading B"
print(f"Loading {secondary_model_info.filename}...") print(f"Loading {secondary_model_info.filename}...")
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
else: else:
theta_1 = None theta_1 = None
if theta_func1: if theta_func1:
shared.state.textinfo = f"Loading C" shared.state.textinfo = "Loading C"
print(f"Loading {tertiary_model_info.filename}...") print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_inpainting_model = True result_is_inpainting_model = True
else: else:
theta_0[key] = theta_func2(a, b, multiplier) theta_0[key] = theta_func2(a, b, multiplier)
theta_0[key] = to_half(theta_0[key], save_as_half) theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1 shared.state.sampling_step += 1

View File

@ -1,15 +1,11 @@
import base64 import base64
import html
import io import io
import math
import os import os
import re import re
from pathlib import Path
import gradio as gr import gradio as gr
from modules.paths import data_path from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks from modules import shared, ui_tempdir, script_callbacks
import tempfile
from PIL import Image from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@ -23,14 +19,14 @@ registered_param_bindings = []
class ParamBinding: class ParamBinding:
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]): def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
self.paste_button = paste_button self.paste_button = paste_button
self.tabname = tabname self.tabname = tabname
self.source_text_component = source_text_component self.source_text_component = source_text_component
self.source_image_component = source_image_component self.source_image_component = source_image_component
self.source_tabname = source_tabname self.source_tabname = source_tabname
self.override_settings_component = override_settings_component self.override_settings_component = override_settings_component
self.paste_field_names = paste_field_names self.paste_field_names = paste_field_names or []
def reset(): def reset():
@ -59,6 +55,7 @@ def image_from_url_text(filedata):
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
filename = filename.rsplit('?', 1)[0]
return Image.open(filename) return Image.open(filename)
if type(filedata) == list: if type(filedata) == list:
@ -129,6 +126,7 @@ def connect_paste_params_buttons():
_js=jsfunc, _js=jsfunc,
inputs=[binding.source_image_component], inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
show_progress=False,
) )
if binding.source_text_component is not None and fields is not None: if binding.source_text_component is not None and fields is not None:
@ -140,6 +138,7 @@ def connect_paste_params_buttons():
fn=lambda *x: x, fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names], outputs=[field for field, name in fields if name in paste_field_names],
show_progress=False,
) )
binding.paste_button.click( binding.paste_button.click(
@ -147,6 +146,7 @@ def connect_paste_params_buttons():
_js=f"switch_to_{binding.tabname}", _js=f"switch_to_{binding.tabname}",
inputs=None, inputs=None,
outputs=None, outputs=None,
show_progress=False,
) )
@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
lines.append(lastline) lines.append(lastline)
lastline = '' lastline = ''
for i, line in enumerate(lines): for line in lines:
line = line.strip() line = line.strip()
if line.startswith("Negative prompt:"): if line.startswith("Negative prompt:"):
done_with_prompt = True done_with_prompt = True
@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v) m = re_imagesize.match(v)
if m is not None: if m is not None:
res[k+"-1"] = m.group(1) res[f"{k}-1"] = m.group(1)
res[k+"-2"] = m.group(2) res[f"{k}-2"] = m.group(2)
else: else:
res[k] = v res[k] = v
@ -308,6 +308,8 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'), ('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'), ('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'), ('UniPC lower order final', 'uni_pc_lower_order_final'),
('Token merging ratio', 'token_merging_ratio'),
('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'), ('RNG', 'randn_source'),
('NGMS', 's_min_uncond'), ('NGMS', 's_min_uncond'),
] ]
@ -409,12 +411,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
fn=paste_func, fn=paste_func,
inputs=[input_comp], inputs=[input_comp],
outputs=[x[0] for x in paste_fields], outputs=[x[0] for x in paste_fields],
show_progress=False,
) )
button.click( button.click(
fn=None, fn=None,
_js=f"recalculate_prompts_{tabname}", _js=f"recalculate_prompts_{tabname}",
inputs=[], inputs=[],
outputs=[], outputs=[],
show_progress=False,
) )

View File

@ -78,7 +78,7 @@ def setup_model(dirname):
try: try:
from gfpgan import GFPGANer from gfpgan import GFPGANer
from facexlib import detection, parsing from facexlib import detection, parsing # noqa: F401
global user_path global user_path
global have_gfpgan global have_gfpgan
global gfpgan_constructor global gfpgan_constructor

View File

@ -13,7 +13,7 @@ cache_data = None
def dump_cache(): def dump_cache():
with filelock.FileLock(cache_filename+".lock"): with filelock.FileLock(f"{cache_filename}.lock"):
with open(cache_filename, "w", encoding="utf8") as file: with open(cache_filename, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4) json.dump(cache_data, file, indent=4)
@ -22,7 +22,7 @@ def cache(subsection):
global cache_data global cache_data
if cache_data is None: if cache_data is None:
with filelock.FileLock(cache_filename+".lock"): with filelock.FileLock(f"{cache_filename}.lock"):
if not os.path.isfile(cache_filename): if not os.path.isfile(cache_filename):
cache_data = {} cache_data = {}
else: else:

View File

@ -1,4 +1,3 @@
import csv
import datetime import datetime
import glob import glob
import html import html
@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
from collections import defaultdict, deque from collections import deque
from statistics import stdev, mean from statistics import stdev, mean
@ -178,34 +177,34 @@ class Hypernetwork:
def weights(self): def weights(self):
res = [] res = []
for k, layers in self.layers.items(): for layers in self.layers.values():
for layer in layers: for layer in layers:
res += layer.parameters() res += layer.parameters()
return res return res
def train(self, mode=True): def train(self, mode=True):
for k, layers in self.layers.items(): for layers in self.layers.values():
for layer in layers: for layer in layers:
layer.train(mode=mode) layer.train(mode=mode)
for param in layer.parameters(): for param in layer.parameters():
param.requires_grad = mode param.requires_grad = mode
def to(self, device): def to(self, device):
for k, layers in self.layers.items(): for layers in self.layers.values():
for layer in layers: for layer in layers:
layer.to(device) layer.to(device)
return self return self
def set_multiplier(self, multiplier): def set_multiplier(self, multiplier):
for k, layers in self.layers.items(): for layers in self.layers.values():
for layer in layers: for layer in layers:
layer.multiplier = multiplier layer.multiplier = multiplier
return self return self
def eval(self): def eval(self):
for k, layers in self.layers.items(): for layers in self.layers.values():
for layer in layers: for layer in layers:
layer.eval() layer.eval()
for param in layer.parameters(): for param in layer.parameters():
@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
k = self.to_k(context_k) k = self.to_k(context_k)
v = self.to_v(context_v) v = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
if clip_grad: if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
print(e) print(e)
scaler = torch.cuda.amp.GradScaler() scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size batch_size = ds.batch_size
gradient_step = ds.gradient_step gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed # n steps = batch_size * gradient_step * n image processed
@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
try: try:
sd_hijack_checkpoint.add() sd_hijack_checkpoint.add()
for i in range((steps-initial_step) * gradient_step): for _ in range((steps-initial_step) * gradient_step):
if scheduler.finished: if scheduler.finished:
break break
if shared.state.interrupted: if shared.state.interrupted:
@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad: if clip_grad:
clip_grad_sched.step(hypernetwork.step) clip_grad_sched.step(hypernetwork.step)
with devices.autocast(): with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight: if use_weight:
@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step += loss.item() _loss_step += loss.item()
scaler.scale(loss).backward() scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0: if (j + 1) % gradient_step != 0:
continue continue
loss_logging.append(_loss_step) loss_logging.append(_loss_step)
if clip_grad: if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate) clip_grad(weights, clip_grad_sched.learn_rate)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
hypernetwork.step += 1 hypernetwork.step += 1
@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0 _loss_step = 0
steps_done = hypernetwork.step + 1 steps_done = hypernetwork.step + 1
epoch_num = hypernetwork.step // steps_per_epoch epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch

View File

@ -1,19 +1,17 @@
import html import html
import os
import re
import gradio as gr import gradio as gr
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"] not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
def train_hypernetwork(*args): def train_hypernetwork(*args):

View File

@ -13,17 +13,24 @@ import numpy as np
import piexif import piexif
import piexif.helper import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string import string
import json import json
import hashlib import hashlib
from modules import sd_samplers, shared, script_callbacks, errors from modules import sd_samplers, shared, script_callbacks, errors
from modules.shared import opts, cmd_opts from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
def get_font(fontsize: int):
try:
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
except Exception:
return ImageFont.truetype(roboto_ttf_file, fontsize)
def image_grid(imgs, batch_size=1, rows=None): def image_grid(imgs, batch_size=1, rows=None):
if rows is None: if rows is None:
if opts.n_rows > 0: if opts.n_rows > 0:
@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
lines.append(word) lines.append(word)
return lines return lines
def get_font(fontsize):
try:
return ImageFont.truetype(opts.font or Roboto, fontsize)
except Exception:
return ImageFont.truetype(Roboto, fontsize)
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize): def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for i, line in enumerate(lines): for line in lines:
fnt = initial_fnt fnt = initial_fnt
fontsize = initial_fontsize fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0: while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@ -357,6 +358,7 @@ class FilenameGenerator:
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..] 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
} }
default_time_format = '%Y%m%d%H%M%S' default_time_format = '%Y%m%d%H%M%S'
@ -365,7 +367,7 @@ class FilenameGenerator:
self.seed = seed self.seed = seed
self.prompt = prompt self.prompt = prompt
self.image = image self.image = image
def hasprompt(self, *args): def hasprompt(self, *args):
lower = self.prompt.lower() lower = self.prompt.lower()
if self.p is None or self.prompt is None: if self.p is None or self.prompt is None:
@ -408,13 +410,13 @@ class FilenameGenerator:
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
try: try:
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
except pytz.exceptions.UnknownTimeZoneError as _: except pytz.exceptions.UnknownTimeZoneError:
time_zone = None time_zone = None
time_zone_time = time_datetime.astimezone(time_zone) time_zone_time = time_datetime.astimezone(time_zone)
try: try:
formatted_time = time_zone_time.strftime(time_format) formatted_time = time_zone_time.strftime(time_format)
except (ValueError, TypeError) as _: except (ValueError, TypeError):
formatted_time = time_zone_time.strftime(self.default_time_format) formatted_time = time_zone_time.strftime(self.default_time_format)
return sanitize_filename_part(formatted_time, replace_spaces=False) return sanitize_filename_part(formatted_time, replace_spaces=False)
@ -466,14 +468,14 @@ def get_next_sequence_number(path, basename):
""" """
result = -1 result = -1
if basename != '': if basename != '':
basename = basename + "-" basename = f"{basename}-"
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
if p.startswith(basename): if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try: try:
result = max(int(l[0]), result) result = max(int(parts[0]), result)
except ValueError: except ValueError:
pass pass
@ -535,7 +537,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
add_number = opts.save_images_add_number or file_decoration == '' add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number: if file_decoration != "" and add_number:
file_decoration = "-" + file_decoration file_decoration = f"-{file_decoration}"
file_decoration = namegen.apply(file_decoration) + suffix file_decoration = namegen.apply(file_decoration) + suffix
@ -565,7 +567,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
def _atomically_save_image(image_to_save, filename_without_extension, extension): def _atomically_save_image(image_to_save, filename_without_extension, extension):
# save image with .tmp extension to avoid race condition when another process detects new image in the directory # save image with .tmp extension to avoid race condition when another process detects new image in the directory
temp_file_path = filename_without_extension + ".tmp" temp_file_path = f"{filename_without_extension}.tmp"
image_format = Image.registered_extensions()[extension] image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png': if extension.lower() == '.png':
@ -625,7 +627,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if opts.save_txt and info is not None: if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt" txt_fullfn = f"{fullfn_without_extension}.txt"
with open(txt_fullfn, "w", encoding="utf8") as file: with open(txt_fullfn, "w", encoding="utf8") as file:
file.write(info + "\n") file.write(f"{info}\n")
else: else:
txt_fullfn = None txt_fullfn = None

View File

@ -1,19 +1,15 @@
import math
import os import os
import sys
import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
from modules import devices, sd_samplers from modules import sd_samplers
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
import modules.processing as processing import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts import modules.scripts
@ -48,7 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
try: try:
img = Image.open(image) img = Image.open(image)
except UnidentifiedImageError: except UnidentifiedImageError as e:
print(e)
continue continue
# Use the EXIF orientation of photos taken by smartphones. # Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img) img = ImageOps.exif_transpose(img)
@ -58,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
# try to find corresponding mask for an image using simple filename matching # try to find corresponding mask for an image using simple filename matching
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
# if not found use first one ("same mask for all images" use-case) # if not found use first one ("same mask for all images" use-case)
if not mask_image_path in inpaint_masks: if mask_image_path not in inpaint_masks:
mask_image_path = inpaint_masks[0] mask_image_path = inpaint_masks[0]
mask_image = Image.open(mask_image_path) mask_image = Image.open(mask_image_path)
p.image_mask = mask_image p.image_mask = mask_image

View File

@ -11,7 +11,6 @@ import torch.hub
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared
from modules import devices, paths, shared, lowvram, modelloader, errors from modules import devices, paths, shared, lowvram, modelloader, errors
blip_image_eval_size = 384 blip_image_eval_size = 384
@ -28,7 +27,7 @@ def category_types():
def download_default_clip_interrogate_categories(content_dir): def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...") print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp" tmpdir = f"{content_dir}_tmp"
category_types = ["artists", "flavors", "mediums", "movements"] category_types = ["artists", "flavors", "mediums", "movements"]
try: try:
@ -160,7 +159,7 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array)) top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate) text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
@ -208,13 +207,13 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
for name, topn, items in self.categories(): for cat in self.categories():
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, cat.items, top_count=cat.topn)
for match, score in matches: for match, score in matches:
if shared.opts.interrogate_return_ranks: if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})" res += f", ({match}:{score/100:.3f})"
else: else:
res += ", " + match res += f", {match}"
except Exception: except Exception:
print("Error interrogating", file=sys.stderr) print("Error interrogating", file=sys.stderr)

View File

@ -23,7 +23,7 @@ def list_localizations(dirname):
localizations[fn] = file.path localizations[fn] = file.path
def localization_js(current_localization_name): def localization_js(current_localization_name: str) -> str:
fn = localizations.get(current_localization_name, None) fn = localizations.get(current_localization_name, None)
data = {} data = {}
if fn is not None: if fn is not None:
@ -34,4 +34,4 @@ def localization_js(current_localization_name):
print(f"Error loading localization from {fn}:", file=sys.stderr) print(f"Error loading localization from {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
return f"var localization = {json.dumps(data)}\n" return f"window.localization = {json.dumps(data)}"

View File

@ -1,6 +1,5 @@
import torch import torch
import platform import platform
from modules import paths
from modules.sd_hijack_utils import CondFunc from modules.sd_hijack_utils import CondFunc
from packaging import version from packaging import version
@ -43,7 +42,7 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@ -54,6 +53,11 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None) CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
if version.parse(torch.__version__) == version.parse("2.0"):
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113 # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6) CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

View File

@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0): def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
h, w = mask.shape h, w = mask.shape
crop_left = 0 crop_left = 0

View File

@ -1,4 +1,3 @@
import glob
import os import os
import shutil import shutil
import importlib import importlib
@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
""" """
output = [] output = []
if ext_filter is None:
ext_filter = []
try: try:
places = [] places = []
@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(model_path) places.append(model_path)
for place in places: for place in places:
if os.path.exists(place): for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
for file in glob.iglob(place + '**/**', recursive=True): if os.path.islink(full_path) and not os.path.exists(full_path):
full_path = file print(f"Skipping broken symlink: {full_path}")
if os.path.isdir(full_path): continue
continue if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
if os.path.islink(full_path) and not os.path.exists(full_path): continue
print(f"Skipping broken symlink: {full_path}") if full_path not in output:
continue output.append(full_path)
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
print(f"Moving {file} from {src_path} to {dest_path}.") print(f"Moving {file} from {src_path} to {dest_path}.")
try: try:
shutil.move(fullpath, dest_path) shutil.move(fullpath, dest_path)
except: except Exception:
pass pass
if len(os.listdir(src_path)) == 0: if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}") print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True) shutil.rmtree(src_path, True)
except: except Exception:
pass pass
builtin_upscaler_classes = []
forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
forbidden_upscaler_classes.add(cls)
def load_upscalers(): def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__ # so we'll try to import any _model.py files before looking in __subclasses__
@ -155,15 +126,22 @@ def load_upscalers():
full_model = f"modules.{model_name}_model" full_model = f"modules.{model_name}_model"
try: try:
importlib.import_module(full_model) importlib.import_module(full_model)
except: except Exception:
pass pass
datas = [] datas = []
commandline_options = vars(shared.cmd_opts) commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
if cls in forbidden_upscaler_classes:
continue
# some of upscaler classes will not go away after reloading their modules, and we'll end
# up with two copies of those classes. The newest copy will always be the last in the list,
# so we go from end to beginning and ignore duplicates
used_classes = {}
for cls in reversed(Upscaler.__subclasses__()):
classname = str(cls)
if classname not in used_classes:
used_classes[classname] = cls
for cls in reversed(used_classes.values()):
name = cls.__name__ name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None)) scaler = cls(commandline_options.get(cmd_name, None))

View File

@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
beta_schedule="linear", beta_schedule="linear",
loss_type="l2", loss_type="l2",
ckpt_path=None, ckpt_path=None,
ignore_keys=[], ignore_keys=None,
load_only_unet=False, load_only_unet=False,
monitor="val/loss", monitor="val/loss",
use_ema=True, use_ema=True,
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
# If initialing from EMA-only checkpoint, create EMA model after loading. # If initialing from EMA-only checkpoint, create EMA model after loading.
if self.use_ema and not load_ema: if self.use_ema and not load_ema:
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
if context is not None: if context is not None:
print(f"{context}: Restored training weights") print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
ignore_keys = ignore_keys or []
sd = torch.load(path, map_location="cpu") sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()): if "state_dict" in list(sd.keys()):
sd = sd["state_dict"] sd = sd["state_dict"]
@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
for k in keys: for k in keys:
for ik in ignore_keys: for ik in ignore_keys:
if k.startswith(ik): if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k)) print(f"Deleting key {k} from state_dict.")
del sd[k] del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False) sd, strict=False)
@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
_, loss_dict_no_ema = self.shared_step(batch) _, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope(): with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch) _, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
@torch.no_grad() @torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict() log = {}
x = self.get_input(batch, self.first_stage_key) x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N) N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row) n_row = min(x.shape[0], n_row)
@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
log["inputs"] = x log["inputs"] = x
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
x_start = x[:n_row] x_start = x[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
conditioning_key = None conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None) ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", []) ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
self.concat_mode = concat_mode self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key self.cond_stage_key = cond_stage_key
try: try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except: except Exception:
self.num_downs = 0 self.num_downs = 0
if not scale_by_std: if not scale_by_std:
self.scale_factor = scale_factor self.scale_factor = scale_factor
@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs) return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False): def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict): if isinstance(cond, dict):
@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial) intermediates.append(x0_partial)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates return img, intermediates
@torch.no_grad() @torch.no_grad()
@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1: if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img) intermediates.append(img)
if callback: callback(i) if callback:
if img_callback: img_callback(img, i) callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates: if return_intermediates:
return img, intermediates return img, intermediates
@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
if cond is not None: if cond is not None:
if isinstance(cond, dict): if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond} [x[:batch_size] for x in cond[key]] for key in cond}
else: else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond, return self.p_sample_loop(cond,
@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
use_ddim = False use_ddim = False
log = dict() log = {}
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True, return_first_stage_outputs=True,
force_c_encode=True, force_c_encode=True,
@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
if plot_diffusion_rows: if plot_diffusion_rows:
# get diffusion row # get diffusion row
diffusion_row = list() diffusion_row = []
z_start = z[:n_row] z_start = z[:n_row]
for t in range(self.num_timesteps): for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1: if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
if inpaint: if inpaint:
# make a simple center square # make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3] h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device) mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in # zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class # TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs): def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs): def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs) logs = super().log_images(*args, batch=batch, N=N, **kwargs)
key = 'train' if self.training else 'validation' key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key] dset = self.trainer.datamodule.datasets[key]

View File

@ -1 +1 @@
from .sampler import UniPCSampler from .sampler import UniPCSampler # noqa: F401

View File

@ -54,7 +54,8 @@ class UniPCSampler(object):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]] ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0] while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0] cbs = ctmp.shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

View File

@ -1,7 +1,6 @@
import torch import torch
import torch.nn.functional as F
import math import math
from tqdm.auto import trange import tqdm
class NoiseScheduleVP: class NoiseScheduleVP:
@ -94,7 +93,7 @@ class NoiseScheduleVP:
""" """
if schedule not in ['discrete', 'linear', 'cosine']: if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
self.schedule = schedule self.schedule = schedule
if schedule == 'discrete': if schedule == 'discrete':
@ -179,13 +178,13 @@ def model_wrapper(
model, model,
noise_schedule, noise_schedule,
model_type="noise", model_type="noise",
model_kwargs={}, model_kwargs=None,
guidance_type="uncond", guidance_type="uncond",
#condition=None, #condition=None,
#unconditional_condition=None, #unconditional_condition=None,
guidance_scale=1., guidance_scale=1.,
classifier_fn=None, classifier_fn=None,
classifier_kwargs={}, classifier_kwargs=None,
): ):
"""Create a wrapper function for the noise prediction model. """Create a wrapper function for the noise prediction model.
@ -276,6 +275,9 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs. A noise prediction model that accepts the noised data and the continuous time as the inputs.
""" """
model_kwargs = model_kwargs or {}
classifier_kwargs = classifier_kwargs or {}
def get_model_input_time(t_continuous): def get_model_input_time(t_continuous):
""" """
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
@ -342,7 +344,7 @@ def model_wrapper(
t_in = torch.cat([t_continuous] * 2) t_in = torch.cat([t_continuous] * 2)
if isinstance(condition, dict): if isinstance(condition, dict):
assert isinstance(unconditional_condition, dict) assert isinstance(unconditional_condition, dict)
c_in = dict() c_in = {}
for k in condition: for k in condition:
if isinstance(condition[k], list): if isinstance(condition[k], list):
c_in[k] = [torch.cat([ c_in[k] = [torch.cat([
@ -353,7 +355,7 @@ def model_wrapper(
unconditional_condition[k], unconditional_condition[k],
condition[k]]) condition[k]])
elif isinstance(condition, list): elif isinstance(condition, list):
c_in = list() c_in = []
assert isinstance(unconditional_condition, list) assert isinstance(unconditional_condition, list)
for i in range(len(condition)): for i in range(len(condition)):
c_in.append(torch.cat([unconditional_condition[i], condition[i]])) c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
@ -469,7 +471,7 @@ class UniPC:
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
return t return t
else: else:
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
""" """
@ -757,40 +759,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)] model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t] t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver. with tqdm.tqdm(total=steps) as pbar:
for init_order in range(1, order): # Init the first `order` values by lower order multistep DPM-Solver.
vec_t = timesteps[init_order].expand(x.shape[0]) for init_order in range(1, order):
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) vec_t = timesteps[init_order].expand(x.shape[0])
if model_x is None: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
pbar.update()
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero:

View File

@ -7,13 +7,13 @@ def connect(token, port, region):
else: else:
if ':' in token: if ':' in token:
# token = authtoken:username:password # token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1] token, username, password = token.split(':', 2)
token = token.split(':')[0] account = f"{username}:{password}"
config = conf.PyngrokConfig( config = conf.PyngrokConfig(
auth_token=token, region=region auth_token=token, region=region
) )
# Guard for existing tunnels # Guard for existing tunnels
existing = ngrok.get_tunnels(pyngrok_config=config) existing = ngrok.get_tunnels(pyngrok_config=config)
if existing: if existing:
@ -24,7 +24,7 @@ def connect(token, port, region):
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n' print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
'You can use this link after the launch is complete.') 'You can use this link after the launch is complete.')
return return
try: try:
if account is None: if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url

View File

@ -1,8 +1,8 @@
import os import os
import sys import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
import modules.safe import modules.safe # noqa: F401
# data_path = cmd_opts_pre.data # data_path = cmd_opts_pre.data
@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
sd_path = os.path.abspath(possible_sd_path) sd_path = os.path.abspath(possible_sd_path)
break break
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths) assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
path_dirs = [ path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []), (sd_path, 'ldm', 'Stable Diffusion', []),

View File

@ -2,8 +2,14 @@
import argparse import argparse
import os import os
import sys
import shlex
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
sd_configs_path = os.path.join(script_path, "configs") sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values # Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False) parser_pre = argparse.ArgumentParser(add_help=False)
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
cmd_opts_pre = parser_pre.parse_known_args()[0] cmd_opts_pre = parser_pre.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir data_path = cmd_opts_pre.data_dir
@ -21,3 +27,5 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions") extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states") config_states_dir = os.path.join(script_path, "config_states")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')

View File

@ -2,7 +2,6 @@ import json
import math import math
import os import os
import sys import sys
import warnings
import hashlib import hashlib
import torch import torch
@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List, Optional from typing import Any, Dict, List
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -30,6 +29,13 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType from blendmodes.blend import blendLayers, BlendType
import tomesd
# add a logger for the processing module
logger = logging.getLogger(__name__)
# manually set output level here since there is no option to do so yet through launch options
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4 opt_C = 4
@ -165,7 +171,7 @@ class StableDiffusionProcessing:
self.all_subseeds = None self.all_subseeds = None
self.iteration = 0 self.iteration = 0
self.is_hr_pass = False self.is_hr_pass = False
@property @property
def sd_model(self): def sd_model(self):
@ -458,10 +464,21 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed) p.subseed = get_fixed_seed(p.subseed)
def program_version():
import launch
res = launch.git_tag()
if res == "<none>":
res = None
return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)
generation_params = { generation_params = {
"Steps": p.steps, "Steps": p.steps,
@ -480,16 +497,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None), "Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None, "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
"Version": program_version() if opts.add_version_to_infotext else None,
} }
generation_params.update(p.extra_generation_params) generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else "" negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@ -512,9 +532,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() sd_vae.reload_vae_weights()
if opts.token_merging_ratio > 0:
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")
res = process_images_inner(p) res = process_images_inner(p)
finally: finally:
# undo model optimizations made by tomesd
if opts.token_merging_ratio > 0:
tomesd.remove_patch(p.sd_model)
logger.debug('Token merging model optimizations removed')
# restore opts to original state # restore opts to original state
if p.override_settings_restore_afterwards: if p.override_settings_restore_afterwards:
for k, v in stored_opts.items(): for k, v in stored_opts.items():
@ -653,7 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if not shared.opts.dont_fix_second_order_samplers_schedule: if not shared.opts.dont_fix_second_order_samplers_schedule:
try: try:
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1 step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
except: except Exception:
pass pass
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc) uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c) c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
@ -769,7 +798,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc() devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) res = Processed(
p,
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
comments="".join(f"\n\n{comment}" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
)
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
@ -958,8 +996,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None x = None
devices.torch_gc() devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass
if opts.token_merging_ratio_hr > 0:
# in case the user has used separate merge ratios
if opts.token_merging_ratio > 0:
tomesd.remove_patch(self.sd_model)
logger.debug('Adjusting token merging ratio for high-res pass')
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
tomesd.remove_patch(self.sd_model)
logger.debug('Removed token merging optimizations from model')
self.is_hr_pass = False self.is_hr_pass = False
return samples return samples

View File

@ -95,8 +95,16 @@ def progressapi(req: ProgressRequest):
image = shared.state.current_image image = shared.state.current_image
if image is not None: if image is not None:
buffered = io.BytesIO() buffered = io.BytesIO()
image.save(buffered, format="png")
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") if opts.live_previews_image_format == "png":
# using optimize for large images takes an enormous amount of time
save_kwargs = {"optimize": max(*image.size) > 256}
else:
save_kwargs = {}
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
id_live_preview = shared.state.id_live_preview id_live_preview = shared.state.id_live_preview
else: else:
live_preview = None live_preview = None

View File

@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
""" """
def collect_steps(steps, tree): def collect_steps(steps, tree):
l = [steps] res = [steps]
class CollectSteps(lark.Visitor): class CollectSteps(lark.Visitor):
def scheduled(self, tree): def scheduled(self, tree):
tree.children[-1] = float(tree.children[-1]) tree.children[-1] = float(tree.children[-1])
if tree.children[-1] < 1: if tree.children[-1] < 1:
tree.children[-1] *= steps tree.children[-1] *= steps
tree.children[-1] = min(steps, int(tree.children[-1])) tree.children[-1] = min(steps, int(tree.children[-1]))
l.append(tree.children[-1]) res.append(tree.children[-1])
def alternate(self, tree): def alternate(self, tree):
l.extend(range(1, steps+1)) res.extend(range(1, steps+1))
CollectSteps().visit(tree) CollectSteps().visit(tree)
return sorted(set(l)) return sorted(set(res))
def at_step(step, tree): def at_step(step, tree):
class AtStep(lark.Transformer): class AtStep(lark.Transformer):
@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
def get_schedule(prompt): def get_schedule(prompt):
try: try:
tree = schedule_parser.parse(prompt) tree = schedule_parser.parse(prompt)
except lark.exceptions.LarkError as e: except lark.exceptions.LarkError:
if 0: if 0:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
conds = model.get_learned_conditioning(texts) conds = model.get_learned_conditioning(texts)
cond_schedule = [] cond_schedule = []
for i, (end_at_step, text) in enumerate(prompt_schedule): for i, (end_at_step, _) in enumerate(prompt_schedule):
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i])) cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
cache[prompt] = cond_schedule cache[prompt] = cond_schedule
@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
for i, cond_schedule in enumerate(c): for i, cond_schedule in enumerate(c):
target_index = 0 target_index = 0
for current, (end_at, cond) in enumerate(cond_schedule): for current, entry in enumerate(cond_schedule):
if current_step <= end_at: if current_step <= entry.end_at_step:
target_index = current target_index = current
break break
res[i] = cond_schedule[target_index].cond res[i] = cond_schedule[target_index].cond
@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
tensors = [] tensors = []
conds_list = [] conds_list = []
for batch_no, composable_prompts in enumerate(c.batch): for composable_prompts in c.batch:
conds_for_batch = [] conds_for_batch = []
for cond_index, composable_prompt in enumerate(composable_prompts): for composable_prompt in composable_prompts:
target_index = 0 target_index = 0
for current, (end_at, cond) in enumerate(composable_prompt.schedules): for current, entry in enumerate(composable_prompt.schedules):
if current_step <= end_at: if current_step <= entry.end_at_step:
target_index = current target_index = current
break break

View File

@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
self.user_path = path self.user_path = path
super().__init__() super().__init__()
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
from realesrgan import RealESRGANer from realesrgan import RealESRGANer # noqa: F401
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
self.enable = True self.enable = True
self.scalers = [] self.scalers = []
scalers = self.load_models(path) scalers = self.load_models(path)
@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler):
for scaler in scalers: for scaler in scalers:
if scaler.local_data_path.startswith("http"): if scaler.local_data_path.startswith("http"):
filename = modelloader.friendly_name(scaler.local_data_path) filename = modelloader.friendly_name(scaler.local_data_path)
local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None) local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
if local: if local_model_candidates:
scaler.local_data_path = local scaler.local_data_path = local_model_candidates[0]
if scaler.name in opts.realesrgan_enabled_models: if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler) self.scalers.append(scaler)
@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
info = self.load_model(path) info = self.load_model(path)
if not os.path.exists(info.local_data_path): if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name) print(f"Unable to load RealESRGAN model: {info.name}")
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(
@ -134,6 +134,6 @@ def get_realesrgan_models(scaler):
), ),
] ]
return models return models
except Exception as e: except Exception:
print("Error making Real-ESRGAN models list:", file=sys.stderr) print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)

View File

@ -40,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)
@ -95,16 +95,16 @@ def check_pt(filename, extra_handler):
except zipfile.BadZipfile: except zipfile.BadZipfile:
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file: with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler unpickler.extra_handler = extra_handler
for i in range(5): for _ in range(5):
unpickler.load() unpickler.load()
def load(filename, *args, **kwargs): def load(filename, *args, **kwargs):
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs): def load_with_extra(filename, extra_handler=None, *args, **kwargs):

View File

@ -32,27 +32,42 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x self.x = x
"""Latent image representation in the process of being denoised""" """Latent image representation in the process of being denoised"""
self.image_cond = image_cond self.image_cond = image_cond
"""Conditioning image""" """Conditioning image"""
self.sigma = sigma self.sigma = sigma
"""Current sigma noise step value""" """Current sigma noise step value"""
self.sampling_step = sampling_step self.sampling_step = sampling_step
"""Current Sampling step number""" """Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned""" """Total number of sampling steps planned"""
self.text_cond = text_cond self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt""" """ Encoder hidden states of text conditioning from prompt"""
self.text_uncond = text_uncond self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt""" """ Encoder hidden states of text conditioning from negative prompt"""
class CFGDenoisedParams: class CFGDenoisedParams:
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
self.x = x
"""Latent image representation in the process of being denoised"""
self.sampling_step = sampling_step
"""Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
self.inner_model = inner_model
"""Inner model reference used for denoising"""
class AfterCFGCallbackParams:
def __init__(self, x, sampling_step, total_sampling_steps): def __init__(self, x, sampling_step, total_sampling_steps):
self.x = x self.x = x
"""Latent image representation in the process of being denoised""" """Latent image representation in the process of being denoised"""
@ -87,6 +102,7 @@ callback_map = dict(
callbacks_image_saved=[], callbacks_image_saved=[],
callbacks_cfg_denoiser=[], callbacks_cfg_denoiser=[],
callbacks_cfg_denoised=[], callbacks_cfg_denoised=[],
callbacks_cfg_after_cfg=[],
callbacks_before_component=[], callbacks_before_component=[],
callbacks_after_component=[], callbacks_after_component=[],
callbacks_image_grid=[], callbacks_image_grid=[],
@ -186,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
report_exception(c, 'cfg_denoised_callback') report_exception(c, 'cfg_denoised_callback')
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_after_cfg_callback')
def before_component_callback(component, **kwargs): def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']: for c in callback_map['callbacks_before_component']:
try: try:
@ -240,7 +264,7 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@ -332,6 +356,14 @@ def on_cfg_denoised(callback):
add_callback(callback_map['callbacks_cfg_denoised'], callback) add_callback(callback_map['callbacks_cfg_denoised'], callback)
def on_cfg_after_cfg(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
The callback is called with one argument:
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
"""
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
def on_before_component(callback): def on_before_component(callback):
"""register a function to be called before a component is created. """register a function to be called before a component is created.
The callback is called with arguments: The callback is called with arguments:

View File

@ -2,7 +2,6 @@ import os
import sys import sys
import traceback import traceback
import importlib.util import importlib.util
from types import ModuleType
def load_module(path): def load_module(path):

View File

@ -163,7 +163,8 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False) need_tabname = self.show(True) == self.show(False)
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" tabkind = 'img2img' if self.is_img2img else 'txt2txt'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{tabname}{title}_{item_id}' return f'script_{tabname}{title}_{item_id}'
@ -230,7 +231,7 @@ def load_scripts():
syspath = sys.path syspath = sys.path
def register_scripts_from_module(module): def register_scripts_from_module(module):
for key, script_class in module.__dict__.items(): for script_class in module.__dict__.values():
if type(script_class) != type: if type(script_class) != type:
continue continue
@ -294,9 +295,9 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: for script_data in auto_processing_scripts + scripts_data:
script = script_class() script = script_data.script_class()
script.filename = path script.filename = script_data.path
script.is_txt2img = not is_img2img script.is_txt2img = not is_img2img
script.is_img2img = is_img2img script.is_img2img = is_img2img
@ -491,7 +492,7 @@ class ScriptRunner:
module = script_loading.load_module(script.filename) module = script_loading.load_module(script.filename)
cache[filename] = module cache[filename] = module
for key, script_class in module.__dict__.items(): for script_class in module.__dict__.values():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class() self.scripts[si] = script_class()
self.scripts[si].filename = filename self.scripts[si].filename = filename
@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp):
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
""" """
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])] comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False): if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect') comp.elem_classes.append('multiselect')

View File

@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
return self.postprocessing_controls.values() return self.postprocessing_controls.values()
def postprocess_image(self, p, script_pp, *args): def postprocess_image(self, p, script_pp, *args):
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} args_dict = dict(zip(self.postprocessing_controls, args))
pp = scripts_postprocessing.PostprocessedImage(script_pp.image) pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
pp.info = {} pp.info = {}

View File

@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
def initialize_scripts(self, scripts_data): def initialize_scripts(self, scripts_data):
self.scripts = [] self.scripts = []
for script_class, path, basedir, script_module in scripts_data: for script_data in scripts_data:
script: ScriptPostprocessing = script_class() script: ScriptPostprocessing = script_data.script_class()
script.filename = path script.filename = script_data.path
if script.name == "Simple Upscale": if script.name == "Simple Upscale":
continue continue
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = {}
for (name, component), value in zip(script.controls.items(), script_args): for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value process_args[name] = value
script.process(pp, **process_args) script.process(pp, **process_args)

View File

@ -61,7 +61,7 @@ class DisableInitialization:
if res is None: if res is None:
res = original(url, *args, local_files_only=False, **kwargs) res = original(url, *args, local_files_only=False, **kwargs)
return res return res
except Exception as e: except Exception:
return original(url, *args, local_files_only=False, **kwargs) return original(url, *args, local_files_only=False, **kwargs)
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):

View File

@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType from types import MethodType
import modules.textual_inversion.textual_inversion import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules import devices, sd_hijack_optimizations, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@ -34,10 +34,10 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
optimization_method = None optimization_method = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
@ -92,12 +92,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True): def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean #Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False) loss = sd_model._old_get_loss(pred, target, mean=False)
#Check if we have weights available #Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None) weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None: if weight is not None:
loss *= weight loss *= weight
#Return the loss, as mean if specified #Return the loss, as mean if specified
return loss.mean() if mean else loss return loss.mean() if mean else loss
@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try: try:
#Temporarily append weights to a place accessible during loss calc #Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w sd_model._custom_loss_weight = w
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'): if not hasattr(sd_model, '_old_get_loss'):
@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try: try:
#Delete temporary weights if appended #Delete temporary weights if appended
del sd_model._custom_loss_weight del sd_model._custom_loss_weight
except AttributeError as e: except AttributeError:
pass pass
#If we have an old loss function, reset the loss function to the original one #If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'): if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss sd_model.get_loss = sd_model._old_get_loss
@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
def undo_weighted_forward(sd_model): def undo_weighted_forward(sd_model):
try: try:
del sd_model.weighted_forward del sd_model.weighted_forward
except AttributeError as e: except AttributeError:
pass pass
@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped

View File

@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack.fixes = [x.fixes for x in batch_chunk] self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes: for fixes in self.hijack.fixes:
for position, embedding in fixes: for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers) z = self.process_tokens(tokens, multipliers)

View File

@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0: if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
self.hijack.fixes = hijack_fixes self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers) return self.process_tokens(remade_batch_tokens, batch_multipliers)

View File

@ -1,16 +1,10 @@
import os
import torch import torch
from einops import repeat
from omegaconf import ListConfig
import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.ddim import noise_like
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding from ldm.models.diffusion.sampling_util import norm_thresholding
@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
if isinstance(c, dict): if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict) assert isinstance(unconditional_conditioning, dict)
c_in = dict() c_in = {}
for k in c: for k in c:
if isinstance(c[k], list): if isinstance(c[k], list):
c_in[k] = [ c_in[k] = [

View File

@ -1,8 +1,5 @@
import collections
import os.path import os.path
import sys
import gc
import time
def should_hijack_ip2p(checkpoint_info): def should_hijack_ip2p(checkpoint_info):
from modules import sd_models_config from modules import sd_models_config
@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
ckpt_basename = os.path.basename(checkpoint_info.filename).lower() ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename

View File

@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
del context, context_k, context_v, x del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
dtype = q.dtype dtype = q.dtype
@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
end = i + 2 end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale s1 *= self.scale
s2 = s1.softmax(dim=-1) s2 = s1.softmax(dim=-1)
del s1 del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2 del s2
del q, k, v del q, k, v
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn): with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale k_in = k_in * self.scale
del context, x del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = get_available_vram() mem_free_total = get_available_vram()
gb = 1024 ** 3 gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5 modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier mem_required = tensor_size * modifier
steps = 1 steps = 1
if mem_required > mem_free_total: if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64: if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype) s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1 del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2 del s2
del q, k, v del q, k, v
r1 = r1.to(dtype) r1 = r1.to(dtype)
@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn): with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale k = k * self.scale
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v) r = einsum_op(q, k, v)
r = r.to(dtype) r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
if q.device.type == 'mps':
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k = q.float(), k.float()
@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk, # the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path # i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype): with devices.without_autocast(disable=q.dtype == v.dtype):
@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k) k_in = self.to_k(context_k)
v_in = self.to_v(context_v) v_in = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
dtype = q.dtype dtype = q.dtype
@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in del q_in, k_in, v_in
dtype = q.dtype dtype = q.dtype
@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
h3 += x h3 += x
return h3 return h3
def xformers_attnblock_forward(self, x): def xformers_attnblock_forward(self, x):
try: try:
h_ = x h_ = x
@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_) k = self.k(h_)
v = self.v(h_) v = self.v(h_)
b, c, h, w = q.shape b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k = q.float(), k.float()
@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_) k = self.k(h_)
v = self.v(h_) v = self.v(h_)
b, c, h, w = q.shape b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype dtype = q.dtype
if shared.opts.upcast_attn: if shared.opts.upcast_attn:
q, k = q.float(), k.float() q, k = q.float(), k.float()
@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_) k = self.k(h_)
v = self.v(h_) v = self.v(h_)
b, c, h, w = q.shape b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous() q = q.contiguous()
k = k.contiguous() k = k.contiguous()
v = v.contiguous() v = v.contiguous()

View File

@ -18,7 +18,7 @@ class TorchHijackForUnet:
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def cat(self, tensors, *args, **kwargs): def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2: if len(tensors) == 2:

View File

@ -1,8 +1,6 @@
import open_clip.tokenizer
import torch import torch
from modules import sd_hijack_clip, devices from modules import sd_hijack_clip, devices
from modules.shared import opts
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):

View File

@ -2,6 +2,8 @@ import collections
import os.path import os.path
import sys import sys
import gc import gc
import threading
import torch import torch
import re import re
import safetensors.torch import safetensors.torch
@ -13,9 +15,9 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@ -45,7 +47,7 @@ class CheckpointInfo:
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename) self.hash = model_hash(filename)
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
@ -67,7 +69,7 @@ class CheckpointInfo:
checkpoint_alisases[id] = self checkpoint_alisases[id] = self
def calculate_shorthash(self): def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None: if self.sha256 is None:
return return
@ -85,8 +87,7 @@ class CheckpointInfo:
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging, CLIPModel # noqa: F401
from transformers import logging, CLIPModel
logging.set_verbosity_error() logging.set_verbosity_error()
except Exception: except Exception:
@ -165,7 +166,7 @@ def model_hash(filename):
def select_checkpoint(): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info
@ -237,7 +238,7 @@ def read_metadata_from_safetensors(filename):
if isinstance(v, str) and v[0:1] == '{': if isinstance(v, str) and v[0:1] == '{':
try: try:
res[k] = json.loads(v) res[k] = json.loads(v)
except Exception as e: except Exception:
pass pass
return res return res
@ -372,7 +373,7 @@ def enable_midas_autodownload():
if not os.path.exists(path): if not os.path.exists(path):
if not os.path.exists(midas_path): if not os.path.exists(midas_path):
mkdir(midas_path) mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}") print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path) request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded") print(f"{model_type} downloaded")
@ -404,13 +405,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
class SdModelData:
def __init__(self):
self.sd_model = None
self.lock = threading.Lock()
def get_sd_model(self):
if self.sd_model is None:
with self.lock:
try:
load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
return self.sd_model
def set_sd_model(self, v):
self.sd_model = v
model_data = SdModelData()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
if shared.sd_model: if model_data.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
shared.sd_model = None model_data.sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
@ -439,7 +466,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
try: try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
except Exception as e: except Exception:
pass pass
if sd_model is None: if sd_model is None:
@ -464,7 +491,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
timer.record("hijack") timer.record("hijack")
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@ -484,7 +511,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = model_data.sd_model
if sd_model is None: # previous model load failed if sd_model is None: # previous model load failed
current_checkpoint_info = None current_checkpoint_info = None
@ -512,11 +539,11 @@ def reload_model_weights(sd_model=None, info=None):
del sd_model del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict) load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model return model_data.sd_model
try: try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
except Exception as e: except Exception:
print("Failed to load checkpoint, restoring previous") print("Failed to load checkpoint, restoring previous")
load_model_weights(sd_model, current_checkpoint_info, None, timer) load_model_weights(sd_model, current_checkpoint_info, None, timer)
raise raise
@ -535,17 +562,15 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model return sd_model
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import devices, sd_hijack
timer = Timer() timer = Timer()
if shared.sd_model: if model_data.sd_model:
model_data.sd_model.to(devices.cpu)
# shared.sd_model.cond_stage_model.to(devices.cpu) sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
# shared.sd_model.first_stage_model.to(devices.cpu) model_data.sd_model = None
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
sd_model = None sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
@ -554,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.") print(f"Unloaded weights {timer.summary()}.")
return sd_model return sd_model
def apply_token_merging(sd_model, hr: bool):
"""
Applies speed and memory optimizations from tomesd.
Args:
hr (bool): True if called in the context of a high-res pass
"""
ratio = shared.opts.token_merging_ratio
if hr:
ratio = shared.opts.token_merging_ratio_hr
tomesd.apply_patch(
sd_model,
ratio=ratio,
use_rand=False, # can cause issues with some samplers
merge_attn=True,
merge_crossattn=False,
merge_mlp=False
)

View File

@ -1,4 +1,3 @@
import re
import os import os
import torch import torch
@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
if info is None: if info is None:
return None return None
config = os.path.splitext(info.filename)[0] + ".yaml" config = f"{os.path.splitext(info.filename)[0]}.yaml"
if os.path.exists(config): if os.path.exists(config):
return config return config

View File

@ -1,7 +1,7 @@
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
# imports for functions that previously were here and are used by other modules # imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
all_samplers = [ all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_kdiffusion.samplers_data_k_diffusion,

View File

@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently; # for DDIM, shapes must match, we can't just process cond and uncond independently;

View File

@ -1,7 +1,6 @@
from collections import deque from collections import deque
import torch import torch
import inspect import inspect
import einops
import k_diffusion.sampling import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common from modules import prompt_parser, devices, sd_samplers_common
@ -9,6 +8,7 @@ from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [ samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@ -87,17 +87,17 @@ class CFGDenoiser(torch.nn.Module):
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list) batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm": if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond) image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else: else:
image_uncond = image_cond image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model: if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@ -161,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params) cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet") devices.test_for_nans(x_out, "unet")
@ -181,6 +181,10 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None: if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised denoised = self.init_latent * self.mask + self.nmask * denoised
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
denoised = after_cfg_callback_params.x
self.step += 1 self.step += 1
return denoised return denoised
@ -198,7 +202,7 @@ class TorchHijack:
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x): def randn_like(self, x):
if self.sampler_noises: if self.sampler_noises:
@ -317,7 +321,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:] sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0] xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p) extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters parameters = inspect.signature(self.func).parameters
@ -340,9 +344,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
self.last_latent = x self.last_latent = x
extra_args={ extra_args={
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
} }
@ -375,9 +379,9 @@ class KDiffusionSampler:
self.last_latent = x self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) }, disable=False, callback=self.callback_state, **extra_params_kwargs))

View File

@ -1,8 +1,5 @@
import torch
import safetensors.torch
import os import os
import collections import collections
from collections import namedtuple
from modules import paths, shared, devices, script_callbacks, sd_models from modules import paths, shared, devices, script_callbacks, sd_models
import glob import glob
from copy import deepcopy from copy import deepcopy

View File

@ -1,12 +1,9 @@
import argparse
import datetime import datetime
import json import json
import os import os
import sys import sys
import time import time
import requests
from PIL import Image
import gradio as gr import gradio as gr
import tqdm import tqdm
@ -15,7 +12,8 @@ import modules.memmon
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None demo = None
@ -201,8 +199,9 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = [] face_restorers = []
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
@ -211,9 +210,33 @@ class OptionInfo:
self.section = section self.section = section
self.refresh = refresh self.refresh = refresh
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def options_section(section_identifier, options_dict): def options_section(section_identifier, options_dict):
for k, v in options_dict.items(): for v in options_dict.values():
v.section = section_identifier v.section = section_identifier
return options_dict return options_dict
@ -242,7 +265,7 @@ options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), { options_templates.update(options_section(('saving-images', "Saving images/grids"), {
"samples_save": OptionInfo(True, "Always save all generated images"), "samples_save": OptionInfo(True, "Always save all generated images"),
"samples_format": OptionInfo('png', 'File format for images'), "samples_format": OptionInfo('png', 'File format for images'),
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs), "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
"grid_save": OptionInfo(True, "Always save all generated image grids"), "grid_save": OptionInfo(True, "Always save all generated image grids"),
@ -261,10 +284,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"), "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"), "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"),
"img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number),
"target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number),
"img_max_size_mp": OptionInfo(200, "Maximum image size, in megapixels", gr.Number), "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"),
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
@ -292,28 +315,26 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
"save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"),
"grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"),
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
"directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs), "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}),
})) }))
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
"face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."), "show_warnings": OptionInfo(False, "Show warnings in console."),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
@ -338,20 +359,22 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"),
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_hr": OptionInfo(0.0, "Togen merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility"), {
@ -363,80 +386,87 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"),
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"),
})) }))
options_templates.update(options_section(('extra_networks', "Extra Networks"), { options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}), "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
"gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), }))
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
})) }))
options_templates.update(options_section(('ui', "Live previews"), { options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
"live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"),
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}).info("Full = slow but pretty; Approx NN = fast but low quality; Approx cheap = super fast but terrible otherwise"),
"live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
"live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}).needs_restart(),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}), 's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}),
'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}),
'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"),
'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"),
})) }))
@ -453,6 +483,7 @@ options_templates.update(options_section((None, "Hidden options"), {
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
})) }))
options_templates.update() options_templates.update()
@ -542,6 +573,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file: with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file) self.data = json.load(file)
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
bad_settings = 0 bad_settings = 0
for k, v in self.data.items(): for k, v in self.data.items():
info = self.data_labels.get(k, None) info = self.data_labels.get(k, None)
@ -560,7 +595,9 @@ class Options:
func() func()
def dumpjson(self): def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d) return json.dumps(d)
def add_option(self, key, info): def add_option(self, key, info):
@ -571,11 +608,11 @@ class Options:
section_ids = {} section_ids = {}
settings_items = self.data_labels.items() settings_items = self.data_labels.items()
for k, item in settings_items: for _, item in settings_items:
if item.section not in section_ids: if item.section not in section_ids:
section_ids[item.section] = len(section_ids) section_ids[item.section] = len(section_ids)
self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value): def cast_value(self, key, value):
"""casts an arbitrary to the same type as this setting's value with key """casts an arbitrary to the same type as this setting's value with key
@ -600,13 +637,37 @@ class Options:
return value return value
opts = Options() opts = Options()
if os.path.exists(config_filename): if os.path.exists(config_filename):
opts.load(config_filename) opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared
settings_components = None settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings""" """assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent" latent_upscale_default_mode = "Latent"
latent_upscale_modes = { latent_upscale_modes = {
@ -620,8 +681,6 @@ latent_upscale_modes = {
sd_upscalers = [] sd_upscalers = []
sd_model = None
clip_model = None clip_model = None
progress_print_out = sys.stdout progress_print_out = sys.stdout
@ -634,14 +693,19 @@ def reload_gradio_theme(theme_name=None):
if not theme_name: if not theme_name:
theme_name = opts.gradio_theme theme_name = opts.gradio_theme
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
if theme_name == "Default": if theme_name == "Default":
gradio_theme = gr.themes.Default() gradio_theme = gr.themes.Default(**default_theme_args)
else: else:
try: try:
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
except requests.exceptions.ConnectionError: except Exception as e:
print("Can't access HuggingFace Hub, falling back to default Gradio theme") errors.display(e, "changing gradio theme")
gradio_theme = gr.themes.Default() gradio_theme = gr.themes.Default(**default_theme_args)
@ -701,3 +765,20 @@ def html(filename):
return file.read() return file.read()
return "" return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
for root, _, files in os.walk(path, followlinks=True):
for filename in files:
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
yield os.path.join(root, filename)

View File

@ -1,18 +1,9 @@
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
from __future__ import annotations
import csv import csv
import os import os
import os.path import os.path
import typing import typing
import collections.abc as abc
import tempfile
import shutil import shutil
if typing.TYPE_CHECKING:
# Only import this when code is being type-checked, it doesn't have any effect at runtime
from .processing import StableDiffusionProcessing
class PromptStyle(typing.NamedTuple): class PromptStyle(typing.NamedTuple):
name: str name: str
@ -74,7 +65,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None: def save_styles(self, path: str) -> None:
# Always keep a backup file around # Always keep a backup file around
if os.path.exists(path): if os.path.exists(path):
shutil.copy(path, path + ".bak") shutil.copy(path, f"{path}.bak")
fd = os.open(path, os.O_RDWR|os.O_CREAT) fd = os.open(path, os.O_RDWR|os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:

View File

@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx, chunk_idx,
min(query_chunk_size, q_tokens) min(query_chunk_size, q_tokens)
) )
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -201,14 +201,15 @@ def efficient_dot_product_attention(
key=key, key=key,
value=value, value=value,
) )
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, res = torch.zeros_like(query)
# and pass slices to be mutated, instead of torch.cat()ing the returned slices for i in range(math.ceil(q_tokens / query_chunk_size)):
res = torch.cat([ attn_scores = compute_query_chunk_attn(
compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size), query=get_query_chunk(i * query_chunk_size),
key=key, key=key,
value=value, value=value,
) for i in range(math.ceil(q_tokens / query_chunk_size)) )
], dim=1)
res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
return res return res

View File

@ -1,10 +1,8 @@
import cv2 import cv2
import requests import requests
import os import os
from collections import defaultdict
from math import log, sqrt
import numpy as np import numpy as np
from PIL import Image, ImageDraw from PIL import ImageDraw
GREEN = "#0F0" GREEN = "#0F0"
BLUE = "#00F" BLUE = "#00F"
@ -12,63 +10,64 @@ RED = "#F00"
def crop_image(im, settings): def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """ """ Intelligently crop an image to the subject matter """
scale_by = 1 scale_by = 1
if is_landscape(im.width, im.height): if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height): elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height): elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height): if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height): elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height): elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
focus = focal_point(im_debug, settings) im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
# take the focal point and turn it into crop coordinates that try to center over the focal focus = focal_point(im_debug, settings)
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)
x1 = focus.x - x_half # take the focal point and turn it into crop coordinates that try to center over the focal
if x1 < 0: # point but then get adjusted back into the frame
x1 = 0 y_half = int(settings.crop_height / 2)
elif x1 + settings.crop_width > im.width: x_half = int(settings.crop_width / 2)
x1 = im.width - settings.crop_width
y1 = focus.y - y_half x1 = focus.x - x_half
if y1 < 0: if x1 < 0:
y1 = 0 x1 = 0
elif y1 + settings.crop_height > im.height: elif x1 + settings.crop_width > im.width:
y1 = im.height - settings.crop_height x1 = im.width - settings.crop_width
x2 = x1 + settings.crop_width y1 = focus.y - y_half
y2 = y1 + settings.crop_height if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height
crop = [x1, y1, x2, y2] x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height
results = [] crop = [x1, y1, x2, y2]
results.append(im.crop(tuple(crop))) results = []
if settings.annotate_image: results.append(im.crop(tuple(crop)))
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results
def focal_point(im, settings): def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@ -88,7 +87,7 @@ def focal_point(im, settings):
corner_centroid = None corner_centroid = None
if len(corner_points) > 0: if len(corner_points) > 0:
corner_centroid = centroid(corner_points) corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid) pois.append(corner_centroid)
entropy_centroid = None entropy_centroid = None
@ -100,7 +99,7 @@ def focal_point(im, settings):
face_centroid = None face_centroid = None
if len(face_points) > 0: if len(face_points) > 0:
face_centroid = centroid(face_points) face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid) pois.append(face_centroid)
average_point = poi_average(pois, settings) average_point = poi_average(pois, settings)
@ -111,7 +110,7 @@ def focal_point(im, settings):
if corner_centroid is not None: if corner_centroid is not None:
color = BLUE color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight) box = corner_centroid.bounding(max_size * corner_centroid.weight)
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(corner_points) > 1: if len(corner_points) > 1:
for f in corner_points: for f in corner_points:
@ -119,7 +118,7 @@ def focal_point(im, settings):
if entropy_centroid is not None: if entropy_centroid is not None:
color = "#ff0" color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight) box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(entropy_points) > 1: if len(entropy_points) > 1:
for f in entropy_points: for f in entropy_points:
@ -127,14 +126,14 @@ def focal_point(im, settings):
if face_centroid is not None: if face_centroid is not None:
color = RED color = RED
box = face_centroid.bounding(max_size * face_centroid.weight) box = face_centroid.bounding(max_size * face_centroid.weight)
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color) d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color) d.ellipse(box, outline=color)
if len(face_points) > 1: if len(face_points) > 1:
for f in face_points: for f in face_points:
d.rectangle(f.bounding(4), outline=color) d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN) d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point return average_point
@ -185,7 +184,7 @@ def image_face_points(im, settings):
try: try:
faces = classifier.detectMultiScale(gray, scaleFactor=1.1, faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE) minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
except: except Exception:
continue continue
if len(faces) > 0: if len(faces) > 0:
@ -262,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0] hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum() return -np.log2(hist / hist.sum()).sum()
def centroid(pois): def centroid(pois):
x = [poi.x for poi in pois] x = [poi.x for poi in pois]
y = [poi.y for poi in pois] y = [poi.y for poi in pois]
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
def poi_average(pois, settings): def poi_average(pois, settings):
@ -283,59 +283,59 @@ def poi_average(pois, settings):
def is_landscape(w, h): def is_landscape(w, h):
return w > h return w > h
def is_portrait(w, h): def is_portrait(w, h):
return h > w return h > w
def is_square(w, h): def is_square(w, h):
return w == h return w == h
def download_and_cache_models(dirname): def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx' model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name) cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file): if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'") print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url) response = requests.get(download_url)
with open(cache_file, "wb") as f: with open(cache_file, "wb") as f:
f.write(response.content) f.write(response.content)
if os.path.exists(cache_file): if os.path.exists(cache_file):
return cache_file return cache_file
return None return None
class PointOfInterest: class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10): def __init__(self, x, y, weight=1.0, size=10):
self.x = x self.x = x
self.y = y self.y = y
self.weight = weight self.weight = weight
self.size = size self.size = size
def bounding(self, size): def bounding(self, size):
return [ return [
self.x - size//2, self.x - size // 2,
self.y - size//2, self.y - size // 2,
self.x + size//2, self.x + size // 2,
self.y + size//2 self.y + size // 2
] ]
class Settings: class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width self.crop_width = crop_width
self.crop_height = crop_height self.crop_height = crop_height
self.corner_points_weight = corner_points_weight self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight self.face_points_weight = face_points_weight
self.annotate_image = annotate_image self.annotate_image = annotate_image
self.destop_view_image = False self.destop_view_image = False
self.dnn_model_path = dnn_model_path self.dnn_model_path = dnn_model_path

Some files were not shown because too many files have changed in this diff Show More