Merge branch 'dev' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into tomesd
This commit is contained in: 75b3692920
43  .github/workflows/on_pull_request.yaml (vendored)

@@ -18,22 +18,29 @@ jobs:
     steps:
       - name: Checkout Code
         uses: actions/checkout@v3
-      - name: Set up Python 3.10
-        uses: actions/setup-python@v4
+      - uses: actions/setup-python@v4
         with:
-          python-version: 3.10.6
-          cache: pip
-          cache-dependency-path: |
-            **/requirements*txt
-      - name: Install PyLint
-        run: |
-          python -m pip install --upgrade pip
-          pip install pylint
-      # This lets PyLint check to see if it can resolve imports
-      - name: Install dependencies
-        run: |
-          export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
-          python launch.py
-      - name: Analysing the code with pylint
-        run: |
-          pylint $(git ls-files '*.py')
+          python-version: 3.11
+          # NB: there's no cache: pip here since we're not installing anything
+          # from the requirements.txt file(s) in the repository; it's faster
+          # not to have GHA download an (at the time of writing) 4 GB cache
+          # of PyTorch and other dependencies.
+      - name: Install Ruff
+        run: pip install ruff==0.0.265
+      - name: Run Ruff
+        run: ruff .
+      # The rest are currently disabled pending fixing of e.g. installing the torch dependency.
+      # - name: Install PyLint
+      #   run: |
+      #     python -m pip install --upgrade pip
+      #     pip install pylint
+      # # This lets PyLint check to see if it can resolve imports
+      # - name: Install dependencies
+      #   run: |
+      #     export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
+      #     python launch.py
+      # - name: Analysing the code with pylint
+      #   run: |
+      #     pylint $(git ls-files '*.py')
6  .github/workflows/run_tests.yaml (vendored)

@@ -17,8 +17,14 @@ jobs:
           cache: pip
           cache-dependency-path: |
             **/requirements*txt
+            launch.py
       - name: Run tests
         run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
+        env:
+          PIP_DISABLE_PIP_VERSION_CHECK: "1"
+          PIP_PROGRESS_BAR: "off"
+          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+          WEBUI_LAUNCH_LIVE_OUTPUT: "1"
       - name: Upload main app stdout-stderr
         uses: actions/upload-artifact@v3
         if: always()
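The TORCH_INDEX_URL variable added above points pip at PyTorch's CPU-only wheel index, so the test job avoids downloading CUDA wheels. A hedged sketch of how launch-time code can honor such a variable; the real launch.py logic may differ, and the fallback URL here is purely illustrative:

import os

# use the CPU wheel index when CI sets TORCH_INDEX_URL
torch_index_url = os.environ.get("TORCH_INDEX_URL", "https://pypi.org/simple")
torch_command = f"pip install torch --index-url {torch_index_url}"
print(torch_command)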
44  CHANGELOG.md

@@ -1,3 +1,47 @@
+## Upcoming 1.2.0
+
+### Features:
+ * do not wait for stable diffusion model to load at startup
+ * add filename patterns: [denoising]
+ * directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
+ * Lora: for the `<...>` text in prompt, use name of Lora that is in the metadata of the file, if present, instead of filename (both can be used to activate lora)
+ * Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
+ * Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
+ * Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
+ * add version to infotext, footer and console output when starting
+ * add links to wiki for filename pattern settings
+ * add extended info for quicksettings setting and use multiselect input instead of a text field
+
+### Minor:
+ * gradio bumped to 3.29.0
+ * torch bumped to 2.0.1
+ * --subpath option for gradio for use with reverse proxy
+ * linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
+ * possible frontend optimization: do not apply localizations if there are none
+ * Add extra `None` option for VAE in XYZ plot
+ * print error to console when batch processing in img2img fails
+ * create HTML for extra network pages only on demand
+ * allow directories starting with . to still list their models for lora, checkpoints, etc
+ * put infotext options into their own category in settings tab
+ * do not show licenses page when user selects Show all pages in settings
+
+### Extensions:
+ * Tooltip localization support
+ * Add api method to get LoRA models with prompt
+
+### Bug Fixes:
+ * re-add /docs endpoint
+ * fix gamepad navigation
+ * make the lightbox fullscreen image function properly
+ * fix squished thumbnails in extras tab
+ * keep "search" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed)
+ * fix webui showing the same image if you configure the generation to always save results into same file
+ * fix bug with upscalers not working properly
+ * Fix MPS on PyTorch 2.0.1, Intel Macs
+ * make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
+ * prevent Reload UI button/link from reloading the page when it's not yet ready
+
+
 ## 1.1.1
 ### Bug Fixes:
  * fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
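Two of the Lora entries above — the 3x3-convolution fix and the option to use the old application method — correspond to code later in this diff (lora_calc_updown and the new lora_functional setting). A minimal sketch, not the extension's exact code, of the difference between the two strategies; the function names here are illustrative:

import torch

def apply_baked(weight, up, down, scale):
    # new method: merge the low-rank delta into the layer weight once;
    # later forward passes run at full speed no matter how many Loras are active
    return weight + scale * (up @ down)

def apply_functional(x, weight, up, down, scale):
    # old (kohya-ss-compatible) method: keep the weight untouched and add the
    # low-rank branch on every forward call; stacking many Loras this way
    # degrades performance, as the lora_forward docstring later in this diff notes
    return x @ weight.T + scale * ((x @ down.T) @ up.T)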
extensions-builtin/LDSR/ldsr_model_arch.py

@@ -88,7 +88,7 @@ class LDSR:
 
         x_t = None
         logs = None
-        for n in range(n_runs):
+        for _ in range(n_runs):
             if custom_shape is not None:
                 x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                 x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
@@ -110,7 +110,6 @@ class LDSR:
         diffusion_steps = int(steps)
         eta = 1.0
 
-        down_sample_method = 'Lanczos'
 
         gc.collect()
         if torch.cuda.is_available:
@@ -131,11 +130,11 @@ class LDSR:
             im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
         else:
             print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
 
         # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
         pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
         im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
 
         logs = self.run(model["model"], im_padded, diffusion_steps, eta)
 
         sample = logs["sample"]
@@ -158,7 +157,7 @@ class LDSR:
 
 
 def get_cond(selected_path):
-    example = dict()
+    example = {}
     up_f = 4
     c = selected_path.convert('RGB')
     c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
@@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
 @torch.no_grad()
 def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
                               corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
-    log = dict()
+    log = {}
 
     z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                         return_first_stage_outputs=True,
@@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
                 x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
                 log["sample_noquant"] = x_sample_noquant
                 log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
-            except:
+            except Exception:
                 pass
 
     log["sample"] = x_sample
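A quick numeric check of the pad-to-multiple-of-64 computation shown above, under the assumption of a 300x200 input (the expression mirrors the diff; only the concrete size is invented):

import numpy as np

size = np.array([300, 200])  # PIL's im_og.size is (width, height)
pad_w, pad_h = np.max(((2, 2), np.ceil(size / 64).astype(int)), axis=0) * 64 - size
# ceil(300/64)=5 and ceil(200/64)=4; the (2, 2) floor enforces at least 128px
assert (pad_w, pad_h) == (20, 56)  # pads the image up to 320x256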
extensions-builtin/LDSR/scripts/ldsr_model.py

@@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks
-import sd_hijack_autoencoder, sd_hijack_ddpm_v1
+import sd_hijack_autoencoder  # noqa: F401
+import sd_hijack_ddpm_v1  # noqa: F401
 
 
 class UpscalerLDSR(Upscaler):
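The new `# noqa: F401` markers tell Ruff these imports are intentional even though the names are never used: the hijack modules do all their work at import time. A self-contained sketch of that pattern (the module and class names here are stand-ins, not the real ones):

import sys
import types

fake_ldm = types.ModuleType("fake_ldm")  # stand-in for ldm.models.autoencoder
sys.modules["fake_ldm"] = fake_ldm

class VQModel:  # stand-in for the class the hijack defines
    pass

fake_ldm.VQModel = VQModel  # module-level statement: runs once, on first import
assert sys.modules["fake_ldm"].VQModel is VQModel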
extensions-builtin/LDSR/sd_hijack_autoencoder.py

@@ -1,16 +1,21 @@
 # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
 # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
 # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
+import numpy as np
 import torch
 import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ldm.modules.ema import LitEma
 from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
 from ldm.modules.diffusionmodules.model import Encoder, Decoder
 from ldm.util import instantiate_from_config
 
 import ldm.models.autoencoder
+from packaging import version
 
 class VQModel(pl.LightningModule):
     def __init__(self,
@@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
                  n_embed,
                  embed_dim,
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  image_key="image",
                  colorize_nlabels=None,
                  monitor=None,
@@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
         self.scheduler_config = scheduler_config
         self.lr_g_factor = lr_g_factor
 
@@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list()):
+    def init_from_ckpt(self, path, ignore_keys=None):
         sd = torch.load(path, map_location="cpu")["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         log_dict = self._validation_step(batch, batch_idx)
         with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+            self._validation_step(batch, batch_idx, suffix="_ema")
         return log_dict
 
     def _validation_step(self, batch, batch_idx, suffix=""):
@@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
         return self.decoder.conv_out.weight
 
     def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.image_key)
         x = x.to(self.device)
         if only_inputs:
@@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
         if plot_ema:
             with self.ema_scope():
                 xrec_ema, _ = self(x)
-                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+                if x.shape[1] > 3:
+                    xrec_ema = self.to_rgb(xrec_ema)
                 log["reconstructions_ema"] = xrec_ema
         return log
 
@@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
 
 class VQModelInterface(VQModel):
     def __init__(self, embed_dim, *args, **kwargs):
-        super().__init__(embed_dim=embed_dim, *args, **kwargs)
+        super().__init__(*args, embed_dim=embed_dim, **kwargs)
         self.embed_dim = embed_dim
 
     def encode(self, x):
@@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
         dec = self.decoder(quant)
         return dec
 
-setattr(ldm.models.autoencoder, "VQModel", VQModel)
-setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
+ldm.models.autoencoder.VQModel = VQModel
+ldm.models.autoencoder.VQModelInterface = VQModelInterface
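The `ignore_keys=[]` to `ignore_keys=None` changes above fix Python's mutable-default-argument pitfall (flake8-bugbear B006, which Ruff enforces): a default list is created once, at function definition time, and shared across calls. A minimal demonstration, independent of this codebase:

def bad(item, acc=[]):
    acc.append(item)
    return acc

def good(item, acc=None):
    acc = acc or []  # the same idiom the diff uses: ignore_keys or []
    acc.append(item)
    return acc

assert bad(1) == [1]
assert bad(2) == [1, 2]   # surprise: state leaked from the first call
assert good(1) == [1]
assert good(2) == [2]     # each call gets a fresh list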
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py

@@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,
@@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
         if monitor is not None:
             self.monitor = monitor
         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
 
         self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                                linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
@@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")
 
-    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
             sd = sd["state_dict"]
         keys = list(sd.keys())
         for k in keys:
-            for ik in ignore_keys:
+            for ik in ignore_keys or []:
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
@@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
 
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
-        log = dict()
+        log = {}
         x = self.get_input(batch, self.first_stage_key)
         N = min(x.shape[0], N)
         n_row = min(x.shape[0], n_row)
@@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
         log["inputs"] = x
 
         # get diffusion row
-        diffusion_row = list()
+        diffusion_row = []
         x_start = x[:n_row]
 
         for t in range(self.num_timesteps):
@@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
             conditioning_key = None
         ckpt_path = kwargs.pop("ckpt_path", None)
         ignore_keys = kwargs.pop("ignore_keys", [])
-        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
         self.concat_mode = concat_mode
         self.cond_stage_trainable = cond_stage_trainable
         self.cond_stage_key = cond_stage_key
         try:
             self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
-        except:
+        except Exception:
             self.num_downs = 0
         if not scale_by_std:
             self.scale_factor = scale_factor
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None
 
         self.restarted_from_ckpt = False
         if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
             z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
 
             # 2. apply model loop over last dim
             if isinstance(self.first_stage_model, VQModelInterface):
                 output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                              force_not_quantize=predict_cids or force_not_quantize)
                                for i in range(z.shape[-1])]
@@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
             c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
         return self.p_losses(x, c, t, *args, **kwargs)
 
-    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
-        def rescale_bbox(bbox):
-            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
-            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
-            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
-            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
-            return x0, y0, w, h
-
-        return [rescale_bbox(b) for b in bboxes]
-
     def apply_model(self, x_noisy, t, cond, return_ids=False):
 
         if isinstance(cond, dict):
@@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if hasattr(self, "split_input_params"):
             assert len(cond) == 1  # todo can only deal with one conditioning atm
             assert not return_ids
             ks = self.split_input_params["ks"]  # eg. (128, 128)
             stride = self.split_input_params["stride"]  # eg. (64, 64)
 
@@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
 
@@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(x0_partial)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
         return img, intermediates
 
     @torch.no_grad()
@@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
 
             if i % log_every_t == 0 or i == timesteps - 1:
                 intermediates.append(img)
-            if callback: callback(i)
-            if img_callback: img_callback(img, i)
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(img, i)
 
         if return_intermediates:
             return img, intermediates
@@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
         if cond is not None:
             if isinstance(cond, dict):
                 cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
-                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+                [x[:batch_size] for x in cond[key]] for key in cond}
             else:
                 cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
         return self.p_sample_loop(cond,
@@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
 
         use_ddim = ddim_steps is not None
 
-        log = dict()
+        log = {}
         z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                            return_first_stage_outputs=True,
                                            force_c_encode=True,
@@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if plot_diffusion_rows:
             # get diffusion row
-            diffusion_row = list()
+            diffusion_row = []
             z_start = z[:n_row]
             for t in range(self.num_timesteps):
                 if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
@@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
 
         if inpaint:
             # make a simple center square
-            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            h, w = z.shape[2], z.shape[3]
             mask = torch.ones(N, h, w).to(self.device)
             # zeros will be filled in
             mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
@@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
     # TODO: move all layout-specific hacks to this class
     def __init__(self, cond_stage_key, *args, **kwargs):
         assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
-        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
 
     def log_images(self, batch, N=8, *args, **kwargs):
-        logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+        logs = super().log_images(*args, batch=batch, N=N, **kwargs)
 
         key = 'train' if self.training else 'validation'
         dset = self.trainer.datamodule.datasets[key]
@@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
         logs['bbox_image'] = cond_img
         return logs
 
-setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
-setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
-setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
-setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
+ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
+ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
+ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
+ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
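The reordered super().__init__ calls above address star-arg unpacking after a keyword argument (flake8-bugbear B026, which Ruff enforces): `f(key=..., *args)` is legal syntax but misleading, and it raises at runtime if the unpacked positionals collide with the keyword. A standalone illustration:

def f(a, b=None):
    return (a, b)

args = ("positional",)
f(b=1, *args)        # works, but reads as if keywords were bound first
try:
    f(a=1, *args)    # TypeError: f() got multiple values for argument 'a'
except TypeError:
    pass
f(*args, b=1)        # the unambiguous ordering the diff switches to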
extensions-builtin/Lora/extra_networks_lora.py

@@ -1,6 +1,7 @@
 from modules import extra_networks, shared
 import lora
 
+
 class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def __init__(self):
         super().__init__('lora')
extensions-builtin/Lora/lora.py

@@ -1,10 +1,9 @@
-import glob
 import os
 import re
 import torch
 from typing import Union
 
-from modules import shared, devices, sd_models, errors
+from modules import shared, devices, sd_models, errors, scripts
 
 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
 
@@ -93,6 +92,7 @@ class LoraOnDisk:
         self.metadata = m
 
         self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text
+        self.alias = self.metadata.get('ss_output_name', self.name)
 
 
 class LoraModule:
@@ -165,12 +165,14 @@ def load_lora(name, filename):
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.MultiheadAttention:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d:
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
             print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
-            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
 
         with torch.no_grad():
             module.weight.copy_(weight)
@@ -182,7 +184,7 @@ def load_lora(name, filename):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -199,11 +201,11 @@ def load_loras(names, multipliers=None):
 
     loaded_loras.clear()
 
-    loras_on_disk = [available_loras.get(name, None) for name in names]
-    if any([x is None for x in loras_on_disk]):
+    loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+    if any(x is None for x in loras_on_disk):
         list_available_loras()
 
-        loras_on_disk = [available_loras.get(name, None) for name in names]
+        loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
 
     for i, name in enumerate(names):
        lora = already_loaded.get(name, None)
@@ -232,6 +234,8 @@ def lora_calc_updown(lora, module, target):
 
         if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
             updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
         else:
             updown = up @ down
 
@@ -240,6 +244,19 @@ def lora_calc_updown(lora, module, target):
     return updown
 
 
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "lora_weights_backup", None)
+
+    if weights_backup is None:
+        return
+
+    if isinstance(self, torch.nn.MultiheadAttention):
+        self.in_proj_weight.copy_(weights_backup[0])
+        self.out_proj.weight.copy_(weights_backup[1])
+    else:
+        self.weight.copy_(weights_backup)
+
+
 def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
@@ -264,12 +281,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
         self.lora_weights_backup = weights_backup
 
     if current_names != wanted_names:
-        if weights_backup is not None:
-            if isinstance(self, torch.nn.MultiheadAttention):
-                self.in_proj_weight.copy_(weights_backup[0])
-                self.out_proj.weight.copy_(weights_backup[1])
-            else:
-                self.weight.copy_(weights_backup)
+        lora_restore_weights_from_backup(self)
 
         for lora in loaded_loras:
             module = lora.modules.get(lora_layer_name, None)
@@ -297,15 +309,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
 
             print(f'failed to calculate lora weights for layer {lora_layer_name}')
 
-    setattr(self, "lora_current_names", wanted_names)
+    self.lora_current_names = wanted_names
+
+
+def lora_forward(module, input, original_forward):
+    """
+    Old way of applying Lora by executing operations during layer's forward.
+    Stacking many loras this way results in big performance degradation.
+    """
+
+    if len(loaded_loras) == 0:
+        return original_forward(module, input)
+
+    input = devices.cond_cast_unet(input)
+
+    lora_restore_weights_from_backup(module)
+    lora_reset_cached_weight(module)
+
+    res = original_forward(module, input)
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is None:
+            continue
+
+        module.up.to(device=devices.device)
+        module.down.to(device=devices.device)
+
+        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+    return res
 
 
 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
-    setattr(self, "lora_current_names", ())
-    setattr(self, "lora_weights_backup", None)
+    self.lora_current_names = ()
+    self.lora_weights_backup = None
 
 
 def lora_Linear_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Linear_forward_before_lora(self, input)
@@ -318,6 +363,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
 
 
 def lora_Conv2d_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Conv2d_forward_before_lora(self, input)
@@ -343,24 +391,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
 
 def list_available_loras():
     available_loras.clear()
+    available_lora_aliases.clear()
 
     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
 
-    candidates = \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
-        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
+    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
 
     for filename in sorted(candidates, key=str.lower):
         if os.path.isdir(filename):
             continue
 
         name = os.path.splitext(os.path.basename(filename))[0]
+        entry = LoraOnDisk(name, filename)
 
-        available_loras[name] = LoraOnDisk(name, filename)
+        available_loras[name] = entry
+
+        available_lora_aliases[name] = entry
+        available_lora_aliases[entry.alias] = entry
+
+
+re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+
+
+def infotext_pasted(infotext, params):
+    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
+        return  # if the other extension is active, it will handle those fields, no need to do anything
+
+    added = []
+
+    for k in params:
+        if not k.startswith("AddNet Model "):
+            continue
+
+        num = k[13:]
+
+        if params.get("AddNet Module " + num) != "LoRA":
+            continue
+
+        name = params.get("AddNet Model " + num)
+        if name is None:
+            continue
+
+        m = re_lora_name.match(name)
+        if m:
+            name = m.group(1)
+
+        multiplier = params.get("AddNet Weight A " + num, "1.0")
+
+        added.append(f"<lora:{name}:{multiplier}>")
+
+    if added:
+        params["Prompt"] += "\n" + "".join(added)
 
 
 available_loras = {}
+available_lora_aliases = {}
 loaded_loras = []
 
 list_available_loras()
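A small check of the conv-Lora composition introduced in lora_calc_updown above: composing the 3x3 "down" conv with the 1x1 "up" conv equals a single 3x3 conv whose kernel is the permuted-conv2d product. The dimensions below are made up for illustration; the expression for updown is the one from the diff:

import torch
import torch.nn.functional as F

c_in, rank, c_out = 4, 2, 8
down = torch.randn(rank, c_in, 3, 3)  # weight of the 3x3 down conv
up = torch.randn(c_out, rank, 1, 1)   # weight of the 1x1 up conv

# merged kernel, as computed in the diff
updown = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

x = torch.randn(1, c_in, 16, 16)
two_step = F.conv2d(F.conv2d(x, down), up)  # apply down, then up
one_step = F.conv2d(x, updown)              # apply the merged kernel once
assert torch.allclose(two_step, one_step, atol=1e-5)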
extensions-builtin/Lora/scripts/lora_script.py

@@ -1,12 +1,12 @@
 import torch
 import gradio as gr
+from fastapi import FastAPI
 
 import lora
 import extra_networks_lora
 import ui_extra_networks_lora
 from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
     torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
@@ -49,8 +49,33 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(lora.infotext_pasted)
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
 }))
+
+
+shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
+    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+}))
+
+
+def create_lora_json(obj: lora.LoraOnDisk):
+    return {
+        "name": obj.name,
+        "alias": obj.alias,
+        "path": obj.filename,
+        "metadata": obj.metadata,
+    }
+
+
+def api_loras(_: gr.Blocks, app: FastAPI):
+    @app.get("/sdapi/v1/loras")
+    async def get_loras():
+        return [create_lora_json(obj) for obj in lora.available_loras.values()]
+
+
+script_callbacks.on_app_started(api_loras)
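Usage sketch for the /sdapi/v1/loras endpoint registered above, assuming a webui instance listening on the default 127.0.0.1:7860; the JSON field names come from create_lora_json:

import requests

loras = requests.get("http://127.0.0.1:7860/sdapi/v1/loras").json()
for entry in loras:
    print(entry["name"], entry["alias"], entry["path"])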
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "preview": self.find_preview(path),
                 "description": self.find_description(path),
                 "search_term": self.search_terms_from_path(lora_on_disk.filename),
-                "prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
+                "prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
                 "local_preview": f"{path}.{shared.opts.samples_format}",
                 "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
             }
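The "prompt" field above is not plain text: json.dumps emits quoted string literals, so the value is a JavaScript expression that the extra-networks frontend appears to evaluate with the user's default multiplier. A sketch of what it produces, with a made-up alias:

import json

alias = "my_lora"
prompt = json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">")
print(prompt)  # "<lora:my_lora:" + opts.extra_networks_default_multiplier + ">"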
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -13,7 +13,6 @@ import modules.upscaler
 from modules import devices, modelloader
 from scunet_model_arch import SCUNet as net
 from modules.shared import opts
-from modules import images
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -133,7 +132,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
-        for k, v in model.named_parameters():
+        for _, v in model.named_parameters():
             v.requires_grad = False
         model = model.to(device)
extensions-builtin/ScuNET/scunet_model_arch.py

@@ -61,7 +61,9 @@ class WMSA(nn.Module):
         Returns:
             output: tensor shape [b h w c]
         """
-        if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+        if self.type != 'W':
+            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
+
         x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
         h_windows = x.size(1)
         w_windows = x.size(2)
@@ -85,8 +87,9 @@ class WMSA(nn.Module):
         output = self.linear(output)
         output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
 
-        if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
-                                                 dims=(1, 2))
+        if self.type != 'W':
+            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
 
         return output
 
     def relative_embedding(self):
@@ -262,4 +265,4 @@ class SCUNet(nn.Module):
                 nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.LayerNorm):
                 nn.init.constant_(m.bias, 0)
                 nn.init.constant_(m.weight, 1.0)
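The multi-line roll reformatted above implements the cyclic shift of shifted-window attention: the feature grid is rolled by half a window before attention and rolled back afterwards, so the two operations cancel exactly. A standalone round-trip check with toy sizes:

import torch

window_size = 4
x = torch.arange(64.).reshape(1, 8, 8, 1)  # [b h w c], as in WMSA.forward
shifted = torch.roll(x, shifts=(-(window_size // 2), -(window_size // 2)), dims=(1, 2))
restored = torch.roll(shifted, shifts=(window_size // 2, window_size // 2), dims=(1, 2))
assert torch.equal(x, restored)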
extensions-builtin/SwinIR/scripts/swinir_model.py

@@ -1,4 +1,3 @@
-import contextlib
 import os
 
 import numpy as np
@@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts, state
+from modules.shared import opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
         img = upscale(img, model)
         try:
             torch.cuda.empty_cache()
-        except:
+        except Exception:
             pass
         return img
 
@@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
         for w_idx in w_idx_list:
             if state.interrupted or state.skipped:
                 break
 
             in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
             out_patch = model(in_patch)
             out_patch_mask = torch.ones_like(out_patch)
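The `except:` to `except Exception:` change above (and the same fix in the LDSR hijacks earlier in this diff) matters because a bare except also swallows KeyboardInterrupt and SystemExit, which derive from BaseException rather than Exception. A standalone demonstration:

assert not issubclass(KeyboardInterrupt, Exception)
assert issubclass(KeyboardInterrupt, BaseException)

handled = None
try:
    raise KeyboardInterrupt
except Exception:
    handled = True    # never reached: Ctrl-C is not an Exception
except BaseException:
    handled = False   # a bare `except:` would behave like this branch
assert handled is False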
extensions-builtin/SwinIR/swinir_model_arch.py

@@ -644,7 +644,7 @@ class SwinIR(nn.Module):
     """
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
     def forward(self, x):
         H, W = x.shape[2:]
         x = self.check_image_size(x)
 
         self.mean = self.mean.type_as(x)
         x = (x - self.mean) * self.img_range
 
@@ -844,7 +844,7 @@ class SwinIR(nn.Module):
         H, W = self.patches_resolution
         flops += H * W * 3 * self.embed_dim * 9
         flops += self.patch_embed.flops()
-        for i, layer in enumerate(self.layers):
+        for layer in self.layers:
             flops += layer.flops()
         flops += H * W * 3 * self.embed_dim * self.embed_dim
         flops += self.upsample.flops()
extensions-builtin/SwinIR/swinir_model_arch_v2.py

@@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
     """
 
     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=[0, 0]):
+                 pretrained_window_size=(0, 0)):
 
         super().__init__()
         self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
             attn_mask = None
 
         self.register_buffer("attn_mask", attn_mask)
 
     def calculate_mask(self, x_size):
         # calculate attention mask for SW-MSA
         H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
         attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
         attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
 
         return attn_mask
 
     def forward(self, x, x_size):
         H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
             attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
         else:
             attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
 
         # merge windows
         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
         shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
         H, W = self.input_resolution
         flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
         flops += H * W * self.dim // 2
         return flops
 
 class BasicLayer(nn.Module):
     """ A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
             nn.init.constant_(blk.norm1.weight, 0)
             nn.init.constant_(blk.norm2.bias, 0)
             nn.init.constant_(blk.norm2.weight, 0)
 
 class PatchEmbed(nn.Module):
     r""" Image to Patch Embedding
     Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
         return flops
 
 class RSTB(nn.Module):
     """Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
                          num_heads=num_heads,
                          window_size=window_size,
                          mlp_ratio=mlp_ratio,
                          qkv_bias=qkv_bias,
                          drop=drop, attn_drop=attn_drop,
                          drop_path=drop_path,
                          norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
         else:
             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
         super(Upsample, self).__init__(*m)
 
 class Upsample_hf(nn.Sequential):
     """Upsample module.
 
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
             m.append(nn.PixelShuffle(3))
         else:
             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
         super(Upsample_hf, self).__init__(*m)
 
 
 class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
         H, W = self.input_resolution
         flops = H * W * self.num_feat * 3 * 9
         return flops
 
 
 
 class Swin2SR(nn.Module):
     r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
     """
 
     def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
                          num_heads=num_heads[i_layer],
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers.append(layer)
|
self.layers.append(layer)
|
||||||
|
|
||||||
if self.upsampler == 'pixelshuffle_hf':
|
if self.upsampler == 'pixelshuffle_hf':
|
||||||
self.layers_hf = nn.ModuleList()
|
self.layers_hf = nn.ModuleList()
|
||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers_hf.append(layer)
|
self.layers_hf.append(layer)
|
||||||
|
|
||||||
self.norm = norm_layer(self.num_features)
|
self.norm = norm_layer(self.num_features)
|
||||||
|
|
||||||
# build the last conv layer in deep feature extraction
|
# build the last conv layer in deep feature extraction
|
||||||
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
|
|||||||
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
self.conv_after_aux = nn.Sequential(
|
self.conv_after_aux = nn.Sequential(
|
||||||
nn.Conv2d(3, num_feat, 3, 1, 1),
|
nn.Conv2d(3, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.upsample = Upsample(upscale, num_feat)
|
self.upsample = Upsample(upscale, num_feat)
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffle_hf':
|
elif self.upsampler == 'pixelshuffle_hf':
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
|
|||||||
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
elif self.upsampler == 'pixelshuffledirect':
|
||||||
# for lightweight SR (to save parameters)
|
# for lightweight SR (to save parameters)
|
||||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
||||||
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_features_hf(self, x):
|
def forward_features_hf(self, x):
|
||||||
x_size = (x.shape[2], x.shape[3])
|
x_size = (x.shape[2], x.shape[3])
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.norm(x) # B L C
|
x = self.norm(x) # B L C
|
||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
x = self.conv_after_body(self.forward_features(x)) + x
|
||||||
x_before = self.conv_before_upsample(x)
|
x_before = self.conv_before_upsample(x)
|
||||||
x_out = self.conv_last(self.upsample(x_before))
|
x_out = self.conv_last(self.upsample(x_before))
|
||||||
|
|
||||||
x_hf = self.conv_first_hf(x_before)
|
x_hf = self.conv_first_hf(x_before)
|
||||||
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
||||||
x_hf = self.conv_before_upsample_hf(x_hf)
|
x_hf = self.conv_before_upsample_hf(x_hf)
|
||||||
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
|
|||||||
x_first = self.conv_first(x)
|
x_first = self.conv_first(x)
|
||||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
||||||
x = x + self.conv_last(res)
|
x = x + self.conv_last(res)
|
||||||
|
|
||||||
x = x / self.img_range + self.mean
|
x = x / self.img_range + self.mean
|
||||||
if self.upsampler == "pixelshuffle_aux":
|
if self.upsampler == "pixelshuffle_aux":
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
||||||
|
|
||||||
elif self.upsampler == "pixelshuffle_hf":
|
elif self.upsampler == "pixelshuffle_hf":
|
||||||
x_out = x_out / self.img_range + self.mean
|
x_out = x_out / self.img_range + self.mean
|
||||||
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
return x[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
x = torch.randn((1, 3, height, width))
|
x = torch.randn((1, 3, height, width))
|
||||||
x = model(x)
|
x = model(x)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
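The list-to-tuple changes in the `__init__` signatures above address Python's mutable-default-argument trap: a default list is created once at function definition and shared across every call. A minimal, self-contained sketch (not from the patch) of the failure mode:

```python
# Default lists are evaluated once, at def time, and then shared by all calls.
def append_bad(item, bucket=[]):
    bucket.append(item)
    return bucket

# A tuple default is immutable, so it cannot accumulate state between calls.
def append_good(item, bucket=()):
    return list(bucket) + [item]

print(append_bad(1))   # [1]
print(append_bad(2))   # [1, 2] - state leaked from the first call
print(append_good(1))  # [1]
print(append_good(2))  # [2]
```

Swin2SR does not appear to mutate these defaults, so the change here reads as defensive lint cleanup rather than a behavior fix.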
@@ -6,7 +6,7 @@
 			<ul>
 				<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
 			</ul>
-			<span style="display:none" class='search_term'>{search_term}</span>
+			<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
 		</div>
 		<span class='name'>{name}</span>
 		<span class='description'>{description}</span>
@@ -92,8 +92,7 @@ contextMenuInit = function(){
             return;
         }
         gradioApp().addEventListener("click", function(e) {
-            let source = e.composedPath()[0]
-            if(source.id && source.id.indexOf('check_progress')>-1){
+            if(! e.isTrusted){
                 return
             }
 
@@ -1,4 +1,3 @@
-
 function setupExtraNetworksForTab(tabname){
     gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
 
@@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
     tabs.appendChild(search)
     tabs.appendChild(refresh)
 
-    search.addEventListener("input", function(){
+    var applyFilter = function(){
         var searchTerm = search.value.toLowerCase()
 
         gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
+            var searchOnly = elem.querySelector('.search_only')
             var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
-            elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
+
+            var visible = text.indexOf(searchTerm) != -1
+
+            if(searchOnly && searchTerm.length < 4){
+                visible = false
+            }
+
+            elem.style.display = visible ? "" : "none"
         })
-    });
+    }
+
+    search.addEventListener("input", applyFilter);
+    applyFilter();
+
+    extraNetworksApplyFilter[tabname] = applyFilter;
 }
 
+function applyExtraNetworkFilter(tabname){
+    setTimeout(extraNetworksApplyFilter[tabname], 1);
+}
+
+var extraNetworksApplyFilter = {}
 var activePromptTextarea = {};
 
 function setupExtraNetworks(){
@@ -66,8 +66,8 @@ titles = {
 
     "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
 
-    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
-    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
+    "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
+    "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
     "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
 
     "Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
@@ -120,16 +120,16 @@ onUiUpdate(function(){
     gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
         if (span.title) return; // already has a title
 
-        let tooltip = titles[span.textContent];
+        let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
 
         if(!tooltip){
-            tooltip = titles[span.value];
+            tooltip = localization[titles[span.value]] || titles[span.value];
        }
 
         if(!tooltip){
             for (const c of span.classList) {
                 if (c in titles) {
-                    tooltip = titles[c];
+                    tooltip = localization[titles[c]] || titles[c];
                     break;
                 }
             }
@@ -144,7 +144,7 @@ onUiUpdate(function(){
         if (select.onchange != null) return;
 
         select.onchange = function(){
-            select.title = titles[select.value] || "";
+            select.title = localization[titles[select.value]] || titles[select.value] || "";
        }
     })
 })
@@ -1,36 +1,57 @@
-let delay = 350//ms
-window.addEventListener('gamepadconnected', (e) => {
-    console.log("Gamepad connected!")
-    const gamepad = e.gamepad;
-    setInterval(() => {
-        const xValue = gamepad.axes[0].toFixed(2);
-        if (xValue < -0.3) {
-            modalPrevImage(e);
-        } else if (xValue > 0.3) {
-            modalNextImage(e);
-        }
-    }, delay);
-});
-
-/*
-Primarily for vr controller type pointer devices.
-I use the wheel event because there's currently no way to do it properly with web xr.
-*/
-
-let isScrolling = false;
-window.addEventListener('wheel', (e) => {
-    if (isScrolling) return;
-    isScrolling = true;
-
-    if (e.deltaX <= -0.6) {
-        modalPrevImage(e);
-    } else if (e.deltaX >= 0.6) {
-        modalNextImage(e);
-    }
-
-    setTimeout(() => {
-        isScrolling = false;
-    }, delay);
-});
+window.addEventListener('gamepadconnected', (e) => {
+    const index = e.gamepad.index;
+    let isWaiting = false;
+    setInterval(async () => {
+        if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
+        const gamepad = navigator.getGamepads()[index];
+        const xValue = gamepad.axes[0];
+        if (xValue <= -0.3) {
+            modalPrevImage(e);
+            isWaiting = true;
+        } else if (xValue >= 0.3) {
+            modalNextImage(e);
+            isWaiting = true;
+        }
+        if (isWaiting) {
+            await sleepUntil(() => {
+                const xValue = navigator.getGamepads()[index].axes[0]
+                if (xValue < 0.3 && xValue > -0.3) {
+                    return true;
+                }
+            }, opts.js_modal_lightbox_gamepad_repeat);
+            isWaiting = false;
+        }
+    }, 10);
+});
+
+/*
+Primarily for vr controller type pointer devices.
+I use the wheel event because there's currently no way to do it properly with web xr.
+*/
+let isScrolling = false;
+window.addEventListener('wheel', (e) => {
+    if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
+    isScrolling = true;
+
+    if (e.deltaX <= -0.6) {
+        modalPrevImage(e);
+    } else if (e.deltaX >= 0.6) {
+        modalNextImage(e);
+    }
+
+    setTimeout(() => {
+        isScrolling = false;
+    }, opts.js_modal_lightbox_gamepad_repeat);
+});
+
+function sleepUntil(f, timeout) {
+    return new Promise((resolve) => {
+        const timeStart = new Date();
+        const wait = setInterval(function() {
+            if (f() || new Date() - timeStart > timeout) {
+                clearInterval(wait);
+                resolve();
+            }
+        }, 20);
+    });
+}
@@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
 original_lines = {}
 translated_lines = {}
 
+function hasLocalization() {
+    return window.localization && Object.keys(window.localization).length > 0;
+}
+
 function textNodesUnder(el){
     var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
     while(n=walk.nextNode()) a.push(n);
@@ -119,37 +123,6 @@ function dumpTranslations(){
     return dumped
 }
 
-onUiUpdate(function(m){
-    m.forEach(function(mutation){
-        mutation.addedNodes.forEach(function(node){
-            processNode(node)
-        })
-    });
-})
-
-
-document.addEventListener("DOMContentLoaded", function() {
-    processNode(gradioApp())
-
-    if (localization.rtl) { // if the language is from right to left,
-        (new MutationObserver((mutations, observer) => { // wait for the style to load
-            mutations.forEach(mutation => {
-                mutation.addedNodes.forEach(node => {
-                    if (node.tagName === 'STYLE') {
-                        observer.disconnect();
-
-                        for (const x of node.sheet.rules) { // find all rtl media rules
-                            if (Array.from(x.media || []).includes('rtl')) {
-                                x.media.appendMedium('all'); // enable them
-                            }
-                        }
-                    }
-                })
-            });
-        })).observe(gradioApp(), { childList: true });
-    }
-})
-
-
 function download_localization() {
     var text = JSON.stringify(dumpTranslations(), null, 4)
 
@@ -163,3 +136,36 @@ function download_localization() {
 
     document.body.removeChild(element);
 }
+
+if(hasLocalization()) {
+    onUiUpdate(function (m) {
+        m.forEach(function (mutation) {
+            mutation.addedNodes.forEach(function (node) {
+                processNode(node)
+            })
+        });
+    })
+
+
+    document.addEventListener("DOMContentLoaded", function () {
+        processNode(gradioApp())
+
+        if (localization.rtl) { // if the language is from right to left,
+            (new MutationObserver((mutations, observer) => { // wait for the style to load
+                mutations.forEach(mutation => {
+                    mutation.addedNodes.forEach(node => {
+                        if (node.tagName === 'STYLE') {
+                            observer.disconnect();
+
+                            for (const x of node.sheet.rules) { // find all rtl media rules
+                                if (Array.from(x.media || []).includes('rtl')) {
+                                    x.media.appendMedium('all'); // enable them
+                                }
+                            }
+                        }
+                    })
+                });
+            })).observe(gradioApp(), { childList: true });
+        }
+    })
+}
@@ -348,6 +348,9 @@ onUiUpdate(function(){
     settings_tabs.appendChild(show_all_pages)
     show_all_pages.onclick = function(){
         gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
+            if(elem.id == "settings_tab_licenses")
+                return;
+
             elem.style.display = "block";
         })
     }
@@ -392,7 +395,16 @@ function update_token_counter(button_id) {
 
 function restart_reload(){
     document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
-    setTimeout(function(){location.reload()},2000)
+
+    var requestPing = function(){
+        requestGet("./internal/ping", {}, function(data){
+            location.reload();
+        }, function(){
+            setTimeout(requestPing, 500);
+        })
+    }
+
+    setTimeout(requestPing, 2000);
 
     return []
 }
41  javascript/ui_settings_hints.js  Normal file
@@ -0,0 +1,41 @@
+// various hints and extra info for the settings tab
+
+onUiLoaded(function(){
+    createLink = function(elem_id, text, href){
+        var a = document.createElement('A')
+        a.textContent = text
+        a.target = '_blank';
+
+        elem = gradioApp().querySelector('#'+elem_id)
+        elem.insertBefore(a, elem.querySelector('label'))
+
+        return a
+    }
+
+    createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
+    createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
+
+    createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
+        requestGet("./internal/quicksettings-hint", {}, function(data){
+            var table = document.createElement('table')
+            table.className = 'settings-value-table'
+
+            data.forEach(function(obj){
+                var tr = document.createElement('tr')
+                var td = document.createElement('td')
+                td.textContent = obj.name
+                tr.appendChild(td)
+
+                var td = document.createElement('td')
+                td.textContent = obj.label
+                tr.appendChild(td)
+
+                table.appendChild(tr)
+            })
+
+            popup(table);
+        })
+    });
+})
+
+
73  launch.py
@@ -19,8 +19,12 @@ python = sys.executable
 git = os.environ.get('GIT', "git")
 index_url = os.environ.get('INDEX_URL', "")
 stored_commit_hash = None
+stored_git_tag = None
 dir_repos = "repositories"
 
+# Whether to default to printing command output
+default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
+
 if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
     os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
 
@@ -70,32 +74,50 @@ def commit_hash():
     return stored_commit_hash
 
 
-def run(command, desc=None, errdesc=None, custom_env=None, live=False):
+def git_tag():
+    global stored_git_tag
+
+    if stored_git_tag is not None:
+        return stored_git_tag
+
+    try:
+        stored_git_tag = run(f"{git} describe --tags").strip()
+    except Exception:
+        stored_git_tag = "<none>"
+
+    return stored_git_tag
+
+
+def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
     if desc is not None:
         print(desc)
 
-    if live:
-        result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
-        if result.returncode != 0:
-            raise RuntimeError(f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}""")
-
-        return ""
-
-    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
+    run_kwargs = {
+        "args": command,
+        "shell": True,
+        "env": os.environ if custom_env is None else custom_env,
+        "encoding": 'utf8',
+        "errors": 'ignore',
+    }
+
+    if not live:
+        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
+
+    result = subprocess.run(**run_kwargs)
 
     if result.returncode != 0:
-
-        message = f"""{errdesc or 'Error running command'}.
-Command: {command}
-Error code: {result.returncode}
-stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
-stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
-"""
-        raise RuntimeError(message)
-
-    return result.stdout.decode(encoding="utf8", errors="ignore")
+        error_bits = [
+            f"{errdesc or 'Error running command'}.",
+            f"Command: {command}",
+            f"Error code: {result.returncode}",
+        ]
+        if result.stdout:
+            error_bits.append(f"stdout: {result.stdout}")
+        if result.stderr:
+            error_bits.append(f"stderr: {result.stderr}")
+        raise RuntimeError("\n".join(error_bits))
+
+    return (result.stdout or "")
 
 
 def check_run(command):
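The rewritten `run()` above collapses two near-duplicate `subprocess.run` call sites into one kwargs dict, and passing `encoding`/`errors` makes captured output a `str`, removing the manual `.decode(...)` calls. A minimal standalone sketch of the same pattern (not the webui code itself):

```python
import os
import subprocess

def run_sketch(command, live=False):
    # Build the shared keyword arguments once; text decoding is handled
    # by subprocess itself via encoding/errors.
    run_kwargs = {
        "args": command,
        "shell": True,
        "env": os.environ,
        "encoding": "utf8",
        "errors": "ignore",
    }
    if not live:
        # Capture output instead of streaming it to the console.
        run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
    result = subprocess.run(**run_kwargs)
    return result.stdout or ""  # stdout is None when live=True

print(run_sketch("echo hello"))
```

When `live=True`, output flows straight to the terminal and `result.stdout` is `None`, which is why the return falls back to the empty string.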
@@ -120,7 +142,7 @@ def run_python(code, desc=None, errdesc=None):
     return run(f'"{python}" -c "{code}"', desc, errdesc)
 
 
-def run_pip(command, desc=None, live=False):
+def run_pip(command, desc=None, live=default_command_live):
     if args.skip_install:
         return
 
@@ -222,13 +244,14 @@ def run_extensions_installers(settings_file):
 
 
 def prepare_environment():
-    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118")
+    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
+    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
 
     xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
-    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
-    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
-    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
+    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
+    clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
+    openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
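The new `TORCH_INDEX_URL` variable splits the wheel index out of `TORCH_COMMAND`, so the index can be retargeted (for example, to the CPU wheels this commit uses in its test workflow) without restating the pinned versions. A small sketch of how the override composes:

```python
import os

# Hypothetical override, mirroring what the test workflow sets.
os.environ["TORCH_INDEX_URL"] = "https://download.pytorch.org/whl/cpu"

torch_index_url = os.environ.get("TORCH_INDEX_URL", "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get("TORCH_COMMAND", f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
print(torch_command)
# pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cpu
```

Setting `TORCH_COMMAND` directly still wins, preserving the old escape hatch for fully custom installs.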
@@ -246,8 +269,10 @@ def prepare_environment():
     check_python_version()
 
     commit = commit_hash()
+    tag = git_tag()
 
     print(f"Python {sys.version}")
+    print(f"Version: {tag}")
     print(f"Commit hash: {commit}")
 
     if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
@@ -302,7 +327,7 @@ def prepare_environment():
 
     if args.update_all_extensions:
         git_pull_recursive(extensions_dir)
 
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)

BIN  modules/Roboto-Regular.ttf  Normal file
Binary file not shown.
@@ -15,7 +15,8 @@ from secrets import compare_digest
 
 import modules.shared as shared
 from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
-from modules.api.models import *
+from modules.api import models
+from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
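Replacing `from modules.api.models import *` with a namespaced import removes a class of silent name collisions: with a star import, a later binding of the same name quietly wins. A self-contained sketch of the hazard (stand-in namespaces, not the real modules):

```python
import types

# Two stand-in "modules" exporting a clashing public name.
models = types.SimpleNamespace(TextToImageResponse=lambda **kw: ("models", kw))
other = types.SimpleNamespace(TextToImageResponse=lambda **kw: ("other", kw))

# Star-import equivalent: the second binding silently shadows the first.
TextToImageResponse = models.TextToImageResponse
TextToImageResponse = other.TextToImageResponse
print(TextToImageResponse(images=[])[0])         # "other" - surprise

# Namespaced equivalent of "from modules.api import models": unambiguous.
print(models.TextToImageResponse(images=[])[0])  # "models"
```

The cost is the `models.` prefix on every reference, which is exactly what the long run of `add_api_route` changes further down applies.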
@@ -25,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import List
+from typing import Dict, List, Any
 import piexif
 import piexif.helper
 
 
 def upscaler_to_index(name: str):
     try:
         return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
-    except:
-        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
 
 
 def script_name_to_index(name, scripts):
     try:
         return [script.title().lower() for script in scripts].index(name.lower())
-    except:
-        raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
+    except Exception as e:
+        raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
 
 
 def validate_sampler_name(name):
     config = sd_samplers.all_samplers_map.get(name, None)
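The bare `except:` clauses above become `except Exception as e: ... raise ... from e`, which stops swallowing `KeyboardInterrupt`/`SystemExit` and keeps the original traceback attached as `__cause__`. A minimal illustration of the chaining:

```python
# Sketch of the "raise ... from e" pattern adopted in the hunk above.
def lookup(name, names):
    try:
        return names.index(name)
    except ValueError as e:
        # Chaining preserves the underlying error for debugging.
        raise KeyError(f"'{name}' not found") from e

try:
    lookup("missing", ["a", "b"])
except KeyError as err:
    print(err)                  # 'missing' not found
    print(type(err.__cause__))  # <class 'ValueError'> - original error preserved
```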
@@ -48,20 +52,23 @@ def validate_sampler_name(name):
 
     return name
 
+
 def setUpscalers(req: dict):
     reqDict = vars(req)
     reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
     reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
     return reqDict
 
+
 def decode_base64_to_image(encoding):
     if encoding.startswith("data:image/"):
         encoding = encoding.split(";")[1].split(",")[1]
     try:
         image = Image.open(BytesIO(base64.b64decode(encoding)))
         return image
-    except Exception as err:
-        raise HTTPException(status_code=500, detail="Invalid encoded image")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="Invalid encoded image") from e
+
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
@@ -92,6 +99,7 @@ def encode_pil_to_base64(image):
 
         return base64.b64encode(bytes_data)
 
+
 def api_middleware(app: FastAPI):
     rich_available = True
     try:
@@ -99,7 +107,7 @@ def api_middleware(app: FastAPI):
         import starlette # importing just so it can be placed on silent list
         from rich.console import Console
         console = Console()
-    except:
+    except Exception:
         import traceback
         rich_available = False
 
@@ -157,7 +165,7 @@ def api_middleware(app: FastAPI):
 class Api:
     def __init__(self, app: FastAPI, queue_lock: Lock):
         if shared.cmd_opts.api_auth:
-            self.credentials = dict()
+            self.credentials = {}
             for auth in shared.cmd_opts.api_auth.split(","):
                 user, password = auth.split(":")
                 self.credentials[user] = password
@@ -166,36 +174,36 @@ class Api:
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
-        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
-        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
-        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
-        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
-        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
+        self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
+        self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
+        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
+        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
+        self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
+        self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
         self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
-        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
+        self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
-        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
+        self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
-        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
-        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
-        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
+        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
+        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
+        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
+        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
-        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
+        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
 
         self.default_script_arg_txt2img = []
         self.default_script_arg_img2img = []
@@ -219,17 +227,17 @@ class Api:
         script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
         script = script_runner.selectable_scripts[script_idx]
         return script, script_idx
 
     def get_scripts_list(self):
         t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
         i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
 
-        return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
+        return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
 
     def get_script(self, script_name, script_runner):
         if script_name is None or script_name == "":
             return None, None
 
         script_idx = script_name_to_index(script_name, script_runner.scripts)
         return script_runner.scripts[script_idx]
 
@@ -264,11 +272,11 @@ class Api:
         if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
             for alwayson_script_name in request.alwayson_scripts.keys():
                 alwayson_script = self.get_script(alwayson_script_name, script_runner)
-                if alwayson_script == None:
+                if alwayson_script is None:
                     raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                 # Selectable script in always on script param check
-                if alwayson_script.alwayson == False:
-                    raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
+                if alwayson_script.alwayson is False:
+                    raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                 # always on script with no arg should always run so you don't really need to add them to the requests
                 if "args" in request.alwayson_scripts[alwayson_script_name]:
                     # min between arg length in scriptrunner and arg length in the request
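The `== None` / `== False` comparisons above become identity checks. Equality dispatches to `__eq__`, which a class can override to answer anything; `is` compares object identity and cannot be fooled. A tiny demonstration:

```python
# Sketch: why "is None" is safer than "== None" for sentinel checks.
class AlwaysEqual:
    def __eq__(self, other):
        return True  # claims equality with everything, including None

script = AlwaysEqual()
print(script == None)  # True - misleading, the object is not None
print(script is None)  # False - identity check gives the right answer
```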
@@ -276,7 +284,7 @@ class Api:
                         script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
         return script_args
 
-    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
+    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
         script_runner = scripts.scripts_txt2img
         if not script_runner.scripts:
             script_runner.initialize_scripts(False)
@@ -310,7 +318,7 @@ class Api:
         p.outpath_samples = opts.outdir_txt2img_samples
 
         shared.state.begin()
-        if selectable_scripts != None:
+        if selectable_scripts is not None:
             p.script_args = script_args
             processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
         else:
@@ -320,9 +328,9 @@ class Api:
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
-        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
+        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
 
-    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")
@@ -367,7 +375,7 @@ class Api:
         p.outpath_samples = opts.outdir_img2img_samples
 
         shared.state.begin()
-        if selectable_scripts != None:
+        if selectable_scripts is not None:
             p.script_args = script_args
             processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
         else:
@@ -381,9 +389,9 @@ class Api:
         img2imgreq.init_images = None
         img2imgreq.mask = None
 
-        return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
+        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
 
-    def extras_single_image_api(self, req: ExtrasSingleImageRequest):
+    def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
         reqDict = setUpscalers(req)
 
         reqDict['image'] = decode_base64_to_image(reqDict['image'])
@@ -391,9 +399,9 @@ class Api:
         with self.queue_lock:
             result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
+        return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
 
-    def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
+    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
         reqDict = setUpscalers(req)
 
         image_list = reqDict.pop('imageList', [])
@@ -402,15 +410,15 @@ class Api:
         with self.queue_lock:
             result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
 
-        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
+        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
 
-    def pnginfoapi(self, req: PNGInfoRequest):
+    def pnginfoapi(self, req: models.PNGInfoRequest):
         if(not req.image.strip()):
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         image = decode_base64_to_image(req.image.strip())
         if image is None:
-            return PNGInfoResponse(info="")
+            return models.PNGInfoResponse(info="")
 
         geninfo, items = images.read_info_from_image(image)
         if geninfo is None:
@@ -418,13 +426,13 @@ class Api:
 
         items = {**{'parameters': geninfo}, **items}
 
-        return PNGInfoResponse(info=geninfo, items=items)
+        return models.PNGInfoResponse(info=geninfo, items=items)
 
-    def progressapi(self, req: ProgressRequest = Depends()):
+    def progressapi(self, req: models.ProgressRequest = Depends()):
         # copy from check_progress_call of ui.py
 
         if shared.state.job_count == 0:
-            return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
+            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
 
         # avoid dividing zero
         progress = 0.01
@@ -446,9 +454,9 @@ class Api:
         if shared.state.current_image and not req.skip_current_image:
             current_image = encode_pil_to_base64(shared.state.current_image)
 
-        return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
+        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
 
-    def interrogateapi(self, interrogatereq: InterrogateRequest):
+    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
         image_b64 = interrogatereq.image
         if image_b64 is None:
             raise HTTPException(status_code=404, detail="Image not found")
@@ -465,7 +473,7 @@ class Api:
         else:
             raise HTTPException(status_code=404, detail="Model not found")
 
-        return InterrogateResponse(caption=processed)
+        return models.InterrogateResponse(caption=processed)
 
     def interruptapi(self):
         shared.state.interrupt()
@@ -570,36 +578,36 @@ class Api:
             filename = create_embedding(**args) # create empty embedding
             sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
             shared.state.end()
-            return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
+            return models.CreateResponse(info=f"create embedding filename: {filename}")
         except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info = 'preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
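The `.format()`-to-f-string rewrites in the hunk above are behavior-preserving; the f-string evaluates the same expression inline. A minimal sketch of the equivalence (the filename value is invented for illustration):

    filename = "emb-0001.pt"  # hypothetical value
    old = "create embedding filename: {filename}".format(filename=filename)
    new = f"create embedding filename: {filename}"
    assert old == new  # same output, less ceremony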
@@ -617,10 +625,10 @@ class Api:
 if not apply_optimizations:
 sd_hijack.apply_optimizations()
 shared.state.end()
-return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
 except AssertionError as msg:
 shared.state.end()
-return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
+return models.TrainResponse(info=f"train embedding error: {msg}")

 def train_hypernetwork(self, args: dict):
 try:
@@ -641,14 +649,15 @@ class Api:
 if not apply_optimizations:
 sd_hijack.apply_optimizations()
 shared.state.end()
-return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
+return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
-except AssertionError as msg:
+except AssertionError:
 shared.state.end()
-return TrainResponse(info="train embedding error: {error}".format(error=error))
+return models.TrainResponse(info=f"train embedding error: {error}")

 def get_memory(self):
 try:
-import os, psutil
+import os
+import psutil
 process = psutil.Process(os.getpid())
 res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
 ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
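The `ram_total` line above leans on psutil's definition of `memory_percent()` as 100 * rss / total physical RAM, so the total can be recovered by rearranging: total = 100 * rss / percent. A small standalone sketch of that arithmetic (output varies by machine):

    import os
    import psutil

    process = psutil.Process(os.getpid())
    rss = process.memory_info().rss     # resident set size, bytes; the one cross-platform field
    percent = process.memory_percent()  # 100 * rss / total physical RAM
    ram_total = 100 * rss / percent     # rearranged: total physical RAM in bytes
    print(f"rss={rss / 2**20:.1f} MiB, total={ram_total / 2**30:.1f} GiB")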
@@ -675,10 +684,10 @@ class Api:
 'events': warnings,
 }
 else:
-cuda = { 'error': 'unavailable' }
+cuda = {'error': 'unavailable'}
 except Exception as err:
-cuda = { 'error': f'{err}' }
+cuda = {'error': f'{err}'}
-return MemoryResponse(ram = ram, cuda = cuda)
+return models.MemoryResponse(ram=ram, cuda=cuda)

 def launch(self, server_name, port):
 self.app.include_router(self.router)
@@ -223,8 +223,9 @@ for key in _options:
 if(_options[key].dest != 'help'):
 flag = _options[key]
 _type = str
-if _options[key].default is not None: _type = type(_options[key].default)
-flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
+if _options[key].default is not None:
+_type = type(_options[key].default)
+flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})

 FlagsModel = create_model("Flags", **flags)

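The hunk above builds a pydantic model out of argparse options: infer each field's type from its default, wrap the default and help text in a `Field`, and hand the mapping to `create_model`. A self-contained sketch of the same pattern, assuming pydantic v1-style `create_model` and walking argparse's `_actions` list rather than the webui's `_options` dict (the `--port`/`--listen` flags are invented):

    import argparse
    from pydantic import Field, create_model

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860, help="server port")
    parser.add_argument("--listen", action="store_true", help="bind 0.0.0.0")

    flags = {}
    for option in parser._actions:  # _actions is internal to argparse; fine for a sketch
        if option.dest == "help":
            continue
        _type = str
        if option.default is not None:
            _type = type(option.default)  # infer the field type from the default value
        flags[option.dest] = (_type, Field(default=option.default, description=option.help))

    FlagsModel = create_model("Flags", **flags)
    print(FlagsModel())  # Flags(port=7860, listen=False)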
@@ -288,4 +289,4 @@ class MemoryResponse(BaseModel):

 class ScriptsList(BaseModel):
 txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
 img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")

@@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
 max_debug_str_len = 131072 # (1024*1024)/8

 print("Error completing request", file=sys.stderr)
-argStr = f"Arguments: {str(args)} {str(kwargs)}"
+argStr = f"Arguments: {args} {kwargs}"
 print(argStr[:max_debug_str_len], file=sys.stderr)
 if len(argStr) > max_debug_str_len:
 print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
@@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
 if extra_outputs_array is None:
 extra_outputs_array = [None, '']

-res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
+error_message = f'{type(e).__name__}: {e}'
+res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

 shared.state.skipped = False
 shared.state.interrupted = False
@@ -1,6 +1,6 @@
 import argparse
 import os
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401

 parser = argparse.ArgumentParser()

@@ -1,14 +1,12 @@
 # this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py

 import math
-import numpy as np
 import torch
 from torch import nn, Tensor
 import torch.nn.functional as F
-from typing import Optional, List
+from typing import Optional

-from modules.codeformer.vqgan_arch import *
+from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
-from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY

 def calc_mean_std(feat, eps=1e-5):
@@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
 tgt_mask: Optional[Tensor] = None,
 tgt_key_padding_mask: Optional[Tensor] = None,
 query_pos: Optional[Tensor] = None):

 # self attention
 tgt2 = self.norm1(tgt)
 q = k = self.with_pos_embed(tgt2, query_pos)
@@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):

 @ARCH_REGISTRY.register()
 class CodeFormer(VQAutoEncoder):
 def __init__(self, dim_embd=512, n_head=8, n_layers=9,
 codebook_size=1024, latent_size=256,
-connect_list=['32', '64', '128', '256'],
+connect_list=('32', '64', '128', '256'),
-fix_modules=['quantize','generator']):
+fix_modules=('quantize', 'generator')):
 super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

 if fix_modules is not None:
@@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
 self.feat_emb = nn.Linear(256, self.dim_embd)

 # transformer
 self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
 for _ in range(self.n_layers)])

 # logits_predict head
 self.idx_pred_layer = nn.Sequential(
 nn.LayerNorm(dim_embd),
 nn.Linear(dim_embd, codebook_size, bias=False))

 self.channels = {
 '16': 512,
 '32': 256,
@@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
 enc_feat_dict = {}
 out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
 for i, block in enumerate(self.encoder.blocks):
 x = block(x)
 if i in out_list:
 enc_feat_dict[str(x.shape[-1])] = x.clone()

@@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
 fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

 for i, block in enumerate(self.generator.blocks):
 x = block(x)
 if i in fuse_list: # fuse after i-th block
 f_size = str(x.shape[-1])
 if w>0:
 x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
 out = x
 # logits doesn't need softmax before cross_entropy loss
 return out, logits, lq_feat
@@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
 https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

 '''
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
 from basicsr.utils import get_root_logger
 from basicsr.utils.registry import ARCH_REGISTRY

 def normalize(in_channels):
 return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


 @torch.jit.script
 def swish(x):
@@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
 # compute attention
 b, c, h, w = q.shape
 q = q.reshape(b, c, h*w)
 q = q.permute(0, 2, 1)
 k = k.reshape(b, c, h*w)
 w_ = torch.bmm(q, k)
 w_ = w_ * (int(c)**(-0.5))
 w_ = F.softmax(w_, dim=2)

 # attend to values
 v = v.reshape(b, c, h*w)
 w_ = w_.permute(0, 2, 1)
 h_ = torch.bmm(v, w_)
 h_ = h_.reshape(b, c, h, w)

@@ -272,18 +270,18 @@ class Encoder(nn.Module):
 def forward(self, x):
 for block in self.blocks:
 x = block(x)

 return x


 class Generator(nn.Module):
 def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
 super().__init__()
 self.nf = nf
 self.ch_mult = ch_mult
 self.num_resolutions = len(self.ch_mult)
 self.num_res_blocks = res_blocks
 self.resolution = img_size
 self.attn_resolutions = attn_resolutions
 self.in_channels = emb_dim
 self.out_channels = 3
@@ -317,29 +315,29 @@ class Generator(nn.Module):
 blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

 self.blocks = nn.ModuleList(blocks)


 def forward(self, x):
 for block in self.blocks:
 x = block(x)

 return x


 @ARCH_REGISTRY.register()
 class VQAutoEncoder(nn.Module):
-def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
 super().__init__()
 logger = get_root_logger()
 self.in_channels = 3
 self.nf = nf
 self.n_blocks = res_blocks
 self.codebook_size = codebook_size
 self.embed_dim = emb_dim
 self.ch_mult = ch_mult
 self.resolution = img_size
-self.attn_resolutions = attn_resolutions
+self.attn_resolutions = attn_resolutions or [16]
 self.quantizer_type = quantizer
 self.encoder = Encoder(
 self.in_channels,
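The `attn_resolutions=[16]` → `attn_resolutions=None` change above (paired with `attn_resolutions or [16]` in the body) sidesteps Python's shared-mutable-default pitfall: a default list is created once, when the `def` runs, and is shared by every call that omits the argument. A minimal sketch of the failure mode and the sentinel fix (function names are illustrative):

    def append_bad(x, acc=[]):  # one list object, created at definition time
        acc.append(x)
        return acc

    def append_good(x, acc=None):  # None sentinel, fresh list per call
        acc = acc or []
        acc.append(x)
        return acc

    print(append_bad(1), append_bad(2))    # [1, 2] [1, 2] -- state leaks across calls
    print(append_good(1), append_good(2))  # [1] [2]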
@@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
 self.kl_weight
 )
 self.generator = Generator(
 self.nf,
 self.embed_dim,
 self.ch_mult,
 self.n_blocks,
 self.resolution,
 self.attn_resolutions
 )

@@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
 raise ValueError('Wrong params!')

 def forward(self, x):
 return self.main(x)
@@ -33,11 +33,9 @@ def setup_model(dirname):
 try:
 from torchvision.transforms.functional import normalize
 from modules.codeformer.codeformer_arch import CodeFormer
-from basicsr.utils.download_util import load_file_from_url
-from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils import img2tensor, tensor2img
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
 from facelib.detection.retinaface import retinaface
-from modules.shared import cmd_opts

 net_class = CodeFormer

@@ -96,7 +94,7 @@ def setup_model(dirname):
 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
 self.face_helper.align_warp_face()

-for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+for cropped_face in self.face_helper.cropped_faces:
 cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
 normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
 cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
@@ -14,7 +14,7 @@ from collections import OrderedDict
 import git

 from modules import shared, extensions
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
+from modules.paths_internal import script_path, config_states_dir


 all_config_states = OrderedDict()
@@ -35,7 +35,7 @@ def list_config_states():
 j["filepath"] = path
 config_states.append(j)

-config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
+config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)

 for cs in config_states:
 timestamp = time.asctime(time.gmtime(cs["created_at"]))
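The rewrite above works because `sorted()` already returns a new list for any iterable, so the outer `list()` was a no-op:

    data = [{"created_at": 1}, {"created_at": 2}]
    newest_first = sorted(data, key=lambda cs: cs["created_at"], reverse=True)
    assert isinstance(newest_first, list)  # no list() wrapper needed
    print(newest_first)  # [{'created_at': 2}, {'created_at': 1}]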
@@ -2,7 +2,6 @@ import os
 import re

 import torch
-from PIL import Image
 import numpy as np

 from modules import modelloader, paths, deepbooru_model, devices, images, shared
@@ -79,7 +78,7 @@ class DeepDanbooru:

 res = []

-filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
+filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}

 for tag in [x for x in tags if x not in filtertags]:
 probability = probability_dict[tag]
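`{... for x in ...}` above is a set comprehension, the direct spelling of what `set([...])` did via a throwaway list. A sketch with a made-up tag string:

    raw = "portrait, 1girl , portrait"
    filtertags = {x.strip().replace(' ', '_') for x in raw.split(",")}
    print(filtertags)  # {'portrait', '1girl'} -- duplicates collapse, membership tests are O(1)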
@@ -65,7 +65,7 @@ def enable_tf32():

 # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
 # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
 torch.backends.cudnn.benchmark = True

 torch.backends.cuda.matmul.allow_tf32 = True
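Dropping the square brackets inside `any()` above turns a list comprehension into a generator expression, so the check can stop at the first matching device instead of querying every GPU up front. A CPU-only sketch of the short-circuit, with `probe()` standing in for `get_device_capability`:

    def probe(devid):
        print(f"probing device {devid}")
        return devid == 0

    any([probe(d) for d in range(3)])  # probes 0, 1 and 2: the whole list is built first
    any(probe(d) for d in range(3))    # probes only 0, then short-circuits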
@@ -6,7 +6,7 @@ from PIL import Image
 from basicsr.utils.download_util import load_file_from_url

 import modules.esrgan_model_arch as arch
-from modules import shared, modelloader, images, devices
+from modules import modelloader, images, devices
 from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts

@@ -16,9 +16,7 @@ def mod2normal(state_dict):
 # this code is copied from https://github.com/victorca25/iNNfer
 if 'conv_first.weight' in state_dict:
 crt_net = {}
-items = []
-for k, v in state_dict.items():
-items.append(k)
+items = list(state_dict)

 crt_net['model.0.weight'] = state_dict['conv_first.weight']
 crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
|||||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||||
re8x = 0
|
re8x = 0
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if "http" in path:
|
if "http" in path:
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
filename = load_file_from_url(
|
||||||
file_name="%s.pth" % self.model_name,
|
url=self.model_url,
|
||||||
progress=True)
|
model_dir=self.model_path,
|
||||||
|
file_name=f"{self.model_name}.pth",
|
||||||
|
progress=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
if not os.path.exists(filename) or filename is None:
|
||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print(f"Unable to load {self.model_path} from {filename}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
@@ -2,7 +2,6 @@

 from collections import OrderedDict
 import math
-import functools
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
 elif upsample_mode == 'pixelshuffle':
 upsample_block = pixelshuffle_block
 else:
-raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
 if upscale == 3:
 upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
 else:
@@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
 Modified options that can be used:
 - "Partial Convolution based Padding" arXiv:1811.11718
 - "Spectral normalization" arXiv:1802.05957
 - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
 {Rakotonirina} and A. {Rasoanaivo}
 """

@@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
 scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
 sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
 x = x + sampled_noise
 return x

 def conv1x1(in_planes, out_planes, stride=1):
 return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -261,10 +260,10 @@ class Upsample(nn.Module):

 def extra_repr(self):
 if self.scale_factor is not None:
-info = 'scale_factor=' + str(self.scale_factor)
+info = f'scale_factor={self.scale_factor}'
 else:
-info = 'size=' + str(self.size)
+info = f'size={self.size}'
-info += ', mode=' + self.mode
+info += f', mode={self.mode}'
 return info


@@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
 elif act_type == 'sigmoid': # [0, 1] range output
 layer = nn.Sigmoid()
 else:
-raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+raise NotImplementedError(f'activation layer [{act_type}] is not found')
 return layer


@@ -372,7 +371,7 @@ def norm(norm_type, nc):
 elif norm_type == 'none':
 def norm_layer(x): return Identity()
 else:
-raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
 return layer


@@ -388,7 +387,7 @@ def pad(pad_type, padding):
 elif pad_type == 'zero':
 layer = nn.ZeroPad2d(padding)
 else:
-raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
 return layer


@@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
 pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
 spectral_norm=False):
 """ Conv layer with padding, normalization, activation """
-assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
 padding = get_valid_padding(kernel_size, dilation)
 p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
 padding = padding if pad_type == 'zero' else 0

 if convtype=='PartialConv2D':
+from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
 c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
 dilation=dilation, bias=bias, groups=groups)
 elif convtype=='DeformConv2D':
+from torchvision.ops import DeformConv2d # not tested
 c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
 dilation=dilation, bias=bias, groups=groups)
 elif convtype=='Conv3D':
@@ -3,11 +3,10 @@ import sys
 import traceback

 import time
-from datetime import datetime
 import git

 from modules import shared
-from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
+from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401

 extensions = []

@@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
 """call deactivate for extra networks in extra_network_data in specified order, then call
 deactivate for all remaining registered networks"""

-for extra_network_name, extra_network_args in extra_network_data.items():
+for extra_network_name in extra_network_data:
 extra_network = extra_network_registry.get(extra_network_name, None)
 if extra_network is None:
 continue
@@ -1,4 +1,4 @@
-from modules import extra_networks, shared, extra_networks
+from modules import extra_networks, shared
 from modules.hypernetworks import hypernetwork


@@ -10,7 +10,8 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
 additional = shared.opts.sd_hypernetwork

 if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
-p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
+hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
+p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
 params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

 names = []
@@ -136,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
 result_is_instruct_pix2pix_model = False

 if theta_func2:
-shared.state.textinfo = f"Loading B"
+shared.state.textinfo = "Loading B"
 print(f"Loading {secondary_model_info.filename}...")
 theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 else:
 theta_1 = None

 if theta_func1:
-shared.state.textinfo = f"Loading C"
+shared.state.textinfo = "Loading C"
 print(f"Loading {tertiary_model_info.filename}...")
 theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

@@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
 result_is_inpainting_model = True
 else:
 theta_0[key] = theta_func2(a, b, multiplier)

 theta_0[key] = to_half(theta_0[key], save_as_half)

 shared.state.sampling_step += 1
@@ -1,15 +1,11 @@
 import base64
-import html
 import io
-import math
 import os
 import re
-from pathlib import Path

 import gradio as gr
 from modules.paths import data_path
 from modules import shared, ui_tempdir, script_callbacks
-import tempfile
 from PIL import Image

 re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -23,14 +19,14 @@ registered_param_bindings = []


 class ParamBinding:
-def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
+def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
 self.paste_button = paste_button
 self.tabname = tabname
 self.source_text_component = source_text_component
 self.source_image_component = source_image_component
 self.source_tabname = source_tabname
 self.override_settings_component = override_settings_component
-self.paste_field_names = paste_field_names
+self.paste_field_names = paste_field_names or []


 def reset():
@@ -59,6 +55,7 @@ def image_from_url_text(filedata):
 is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
 assert is_in_right_dir, 'trying to open image file outside of allowed directories'

+filename = filename.rsplit('?', 1)[0]
 return Image.open(filename)

 if type(filedata) == list:
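The added `rsplit` line appears to strip a query-string suffix (e.g. a cache-busting `?...` appended to a temp-file URL) before the path is opened; splitting once from the right removes only a trailing suffix. Sketch with an invented path:

    filename = "/tmp/gradio/tmpabc123.png?1683300000"  # hypothetical temp path
    print(filename.rsplit('?', 1)[0])     # /tmp/gradio/tmpabc123.png
    print("plain.png".rsplit('?', 1)[0])  # plain.png -- no '?' means no change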
@@ -129,6 +126,7 @@ def connect_paste_params_buttons():
 _js=jsfunc,
 inputs=[binding.source_image_component],
 outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
+show_progress=False,
 )

 if binding.source_text_component is not None and fields is not None:
@@ -140,6 +138,7 @@ def connect_paste_params_buttons():
 fn=lambda *x: x,
 inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
 outputs=[field for field, name in fields if name in paste_field_names],
+show_progress=False,
 )

 binding.paste_button.click(
@@ -147,6 +146,7 @@ def connect_paste_params_buttons():
 _js=f"switch_to_{binding.tabname}",
 inputs=None,
 outputs=None,
+show_progress=False,
 )


@@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 lines.append(lastline)
 lastline = ''

-for i, line in enumerate(lines):
+for line in lines:
 line = line.strip()
 if line.startswith("Negative prompt:"):
 done_with_prompt = True
@@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
 v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
 m = re_imagesize.match(v)
 if m is not None:
-res[k+"-1"] = m.group(1)
+res[f"{k}-1"] = m.group(1)
-res[k+"-2"] = m.group(2)
+res[f"{k}-2"] = m.group(2)
 else:
 res[k] = v

@@ -447,12 +447,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
 fn=paste_func,
 inputs=[input_comp],
 outputs=[x[0] for x in paste_fields],
+show_progress=False,
 )
 button.click(
 fn=None,
 _js=f"recalculate_prompts_{tabname}",
 inputs=[],
 outputs=[],
+show_progress=False,
 )


@@ -78,7 +78,7 @@ def setup_model(dirname):

 try:
 from gfpgan import GFPGANer
-from facexlib import detection, parsing
+from facexlib import detection, parsing # noqa: F401
 global user_path
 global have_gfpgan
 global gfpgan_constructor
@@ -13,7 +13,7 @@ cache_data = None


 def dump_cache():
-with filelock.FileLock(cache_filename+".lock"):
+with filelock.FileLock(f"{cache_filename}.lock"):
 with open(cache_filename, "w", encoding="utf8") as file:
 json.dump(cache_data, file, indent=4)

@@ -22,7 +22,7 @@ def cache(subsection):
 global cache_data

 if cache_data is None:
-with filelock.FileLock(cache_filename+".lock"):
+with filelock.FileLock(f"{cache_filename}.lock"):
 if not os.path.isfile(cache_filename):
 cache_data = {}
 else:
@@ -1,4 +1,3 @@
-import csv
 import datetime
 import glob
 import html
@@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

-from collections import defaultdict, deque
+from collections import deque
 from statistics import stdev, mean


@@ -178,34 +177,34 @@ class Hypernetwork:

 def weights(self):
 res = []
-for k, layers in self.layers.items():
+for layers in self.layers.values():
 for layer in layers:
 res += layer.parameters()
 return res

 def train(self, mode=True):
-for k, layers in self.layers.items():
+for layers in self.layers.values():
 for layer in layers:
 layer.train(mode=mode)
 for param in layer.parameters():
 param.requires_grad = mode

 def to(self, device):
-for k, layers in self.layers.items():
+for layers in self.layers.values():
 for layer in layers:
 layer.to(device)

 return self

 def set_multiplier(self, multiplier):
-for k, layers in self.layers.items():
+for layers in self.layers.values():
 for layer in layers:
 layer.multiplier = multiplier

 return self

 def eval(self):
-for k, layers in self.layers.items():
+for layers in self.layers.values():
 for layer in layers:
 layer.eval()
 for param in layer.parameters():
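Each of the loops above bound `k` without using it; `.values()` says directly that only the layer groups matter. A tiny sketch:

    layers = {"320": ["a", "b"], "768": ["c"]}
    flat = [layer for group in layers.values() for layer in group]
    print(flat)  # ['a', 'b', 'c'] -- same traversal, no unused key variable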
@@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
 k = self.to_k(context_k)
 v = self.to_v(context_v)

-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))

 sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

@@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 return hypernetwork, filename

 scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

 clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
 if clip_grad:
 clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 print(e)

 scaler = torch.cuda.amp.GradScaler()

 batch_size = ds.batch_size
 gradient_step = ds.gradient_step
 # n steps = batch_size * gradient_step * n image processed
@@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 try:
 sd_hijack_checkpoint.add()

-for i in range((steps-initial_step) * gradient_step):
+for _ in range((steps-initial_step) * gradient_step):
 if scheduler.finished:
 break
 if shared.state.interrupted:
@@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

 if clip_grad:
 clip_grad_sched.step(hypernetwork.step)

 with devices.autocast():
 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
 if use_weight:
@@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

 _loss_step += loss.item()
 scaler.scale(loss).backward()

 # go back until we reach gradient accumulation steps
 if (j + 1) % gradient_step != 0:
 continue
 loss_logging.append(_loss_step)
 if clip_grad:
 clip_grad(weights, clip_grad_sched.learn_rate)

 scaler.step(optimizer)
 scaler.update()
 hypernetwork.step += 1
@@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
 _loss_step = 0

 steps_done = hypernetwork.step + 1

 epoch_num = hypernetwork.step // steps_per_epoch
 epoch_step = hypernetwork.step % steps_per_epoch

@@ -1,19 +1,17 @@
 import html
-import os
-import re

 import gradio as gr
 import modules.hypernetworks.hypernetwork
 from modules import devices, sd_hijack, shared

 not_available = ["hardswish", "multiheadattention"]
-keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
+keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]


 def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
     filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)

-    return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
+    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""


 def train_hypernetwork(*args):
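Both replacements above lean on the same idiom: iterating a dict yields its keys, so `.keys()` and the wrapping list are redundant. For example:

    d = {"relu": 1, "swish": 2}
    assert sorted(d) == sorted(d.keys()) == sorted([x for x in d.keys()])  # all give ['relu', 'swish']
    keys = [k for k in d if k != "swish"]  # filter while iterating, no .keys() needed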
@@ -13,17 +13,24 @@ import numpy as np
 import piexif
 import piexif.helper
 from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
-from fonts.ttf import Roboto
 import string
 import json
 import hashlib

 from modules import sd_samplers, shared, script_callbacks, errors
-from modules.shared import opts, cmd_opts
+from modules.paths_internal import roboto_ttf_file
+from modules.shared import opts

 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


+def get_font(fontsize: int):
+    try:
+        return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
+    except Exception:
+        return ImageFont.truetype(roboto_ttf_file, fontsize)
+
+
 def image_grid(imgs, batch_size=1, rows=None):
     if rows is None:
         if opts.n_rows > 0:
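The new module-level `get_font` replaces a copy that lived inside `draw_grid_annotations` (removed in the next hunk) and falls back to the bundled Roboto file when the user-configured font fails to load. A hedged sketch of the same fallback pattern, not the repository's exact helper (`fallback_ttf` is a placeholder path):

    from PIL import Image, ImageDraw, ImageFont

    def get_font(fontsize, user_font="", fallback_ttf="Roboto-Regular.ttf"):
        # prefer the user's font; fall back to the bundled file, then Pillow's built-in default
        for candidate in (user_font, fallback_ttf):
            try:
                if candidate:
                    return ImageFont.truetype(candidate, fontsize)
            except Exception:
                continue
        return ImageFont.load_default()

    img = Image.new("RGB", (256, 64), "white")
    ImageDraw.Draw(img).text((8, 8), "annotation", font=get_font(24), fill="black")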
@@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0)
             lines.append(word)
         return lines

-    def get_font(fontsize):
-        try:
-            return ImageFont.truetype(opts.font or Roboto, fontsize)
-        except Exception:
-            return ImageFont.truetype(Roboto, fontsize)
-
     def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
-        for i, line in enumerate(lines):
+        for line in lines:
             fnt = initial_fnt
             fontsize = initial_fontsize
             while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
@@ -357,6 +358,7 @@ class FilenameGenerator:
         'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
         'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]
         'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
+        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
     }
     default_time_format = '%Y%m%d%H%M%S'

@@ -365,7 +367,7 @@ class FilenameGenerator:
         self.seed = seed
         self.prompt = prompt
         self.image = image

     def hasprompt(self, *args):
         lower = self.prompt.lower()
         if self.p is None or self.prompt is None:
@@ -408,13 +410,13 @@ class FilenameGenerator:
         time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
         try:
             time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
-        except pytz.exceptions.UnknownTimeZoneError as _:
+        except pytz.exceptions.UnknownTimeZoneError:
             time_zone = None

         time_zone_time = time_datetime.astimezone(time_zone)
         try:
             formatted_time = time_zone_time.strftime(time_format)
-        except (ValueError, TypeError) as _:
+        except (ValueError, TypeError):
             formatted_time = time_zone_time.strftime(self.default_time_format)

         return sanitize_filename_part(formatted_time, replace_spaces=False)
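The dropped `as _` bindings were never used, so behavior is unchanged: unknown time zones fall back to local time, and bad format strings fall back to the default format. The same pattern in isolation, assuming pytz is installed:

    import datetime
    import pytz

    def format_time(dt, fmt, tz_name=None, default_fmt="%Y%m%d%H%M%S"):
        try:
            tz = pytz.timezone(tz_name) if tz_name else None
        except pytz.exceptions.UnknownTimeZoneError:
            tz = None  # unknown zone: keep local time
        local = dt.astimezone(tz)
        try:
            return local.strftime(fmt)
        except (ValueError, TypeError):
            return local.strftime(default_fmt)  # bad format string: use the default

    print(format_time(datetime.datetime.now(), "%Y-%m-%d %H:%M", "Europe/Berlin"))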
@@ -466,14 +468,14 @@ def get_next_sequence_number(path, basename):
     """
     result = -1
     if basename != '':
-        basename = basename + "-"
+        basename = f"{basename}-"

     prefix_length = len(basename)
     for p in os.listdir(path):
         if p.startswith(basename):
-            l = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+            parts = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
             try:
-                result = max(int(l[0]), result)
+                result = max(int(parts[0]), result)
             except ValueError:
                 pass

@@ -535,7 +537,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
         add_number = opts.save_images_add_number or file_decoration == ''

         if file_decoration != "" and add_number:
-            file_decoration = "-" + file_decoration
+            file_decoration = f"-{file_decoration}"

         file_decoration = namegen.apply(file_decoration) + suffix

@@ -565,7 +567,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i

     def _atomically_save_image(image_to_save, filename_without_extension, extension):
         # save image with .tmp extension to avoid race condition when another process detects new image in the directory
-        temp_file_path = filename_without_extension + ".tmp"
+        temp_file_path = f"{filename_without_extension}.tmp"
         image_format = Image.registered_extensions()[extension]

         if extension.lower() == '.png':
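The `.tmp` suffix is the point of this helper: a watcher polling the output directory must never observe a half-written image. Write-to-temp plus rename is the standard way to get that, since a rename within one volume is atomic. A minimal sketch, not the repository's exact function:

    import os

    def atomic_write_bytes(data: bytes, path: str):
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            f.write(data)          # readers never see this partial file under the final name
        os.replace(tmp_path, path)  # atomic on POSIX and Windows within one volume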
@@ -625,7 +627,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
     if opts.save_txt and info is not None:
         txt_fullfn = f"{fullfn_without_extension}.txt"
         with open(txt_fullfn, "w", encoding="utf8") as file:
-            file.write(info + "\n")
+            file.write(f"{info}\n")
     else:
         txt_fullfn = None

@@ -1,19 +1,15 @@
-import math
 import os
-import sys
-import traceback

 import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError

-from modules import devices, sd_samplers
+from modules import sd_samplers
 from modules.generation_parameters_copypaste import create_override_settings_dict
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html
-import modules.images as images
 import modules.scripts

@@ -48,7 +44,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):

         try:
             img = Image.open(image)
-        except UnidentifiedImageError:
+        except UnidentifiedImageError as e:
+            print(e)
             continue
         # Use the EXIF orientation of photos taken by smartphones.
         img = ImageOps.exif_transpose(img)

@@ -58,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
             # try to find corresponding mask for an image using simple filename matching
             mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
             # if not found use first one ("same mask for all images" use-case)
-            if not mask_image_path in inpaint_masks:
+            if mask_image_path not in inpaint_masks:
                 mask_image_path = inpaint_masks[0]
             mask_image = Image.open(mask_image_path)
             p.image_mask = mask_image
@@ -11,7 +11,6 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

-import modules.shared as shared
 from modules import devices, paths, shared, lowvram, modelloader, errors

 blip_image_eval_size = 384
@@ -28,7 +27,7 @@ def category_types():
 def download_default_clip_interrogate_categories(content_dir):
     print("Downloading CLIP categories...")

-    tmpdir = content_dir + "_tmp"
+    tmpdir = f"{content_dir}_tmp"
     category_types = ["artists", "flavors", "mediums", "movements"]

     try:
@@ -160,7 +159,7 @@ class InterrogateModels:
         text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

         top_count = min(top_count, len(text_array))
-        text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
+        text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
         text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
         text_features /= text_features.norm(dim=-1, keepdim=True)

@@ -208,13 +207,13 @@ class InterrogateModels:

                 image_features /= image_features.norm(dim=-1, keepdim=True)

-            for name, topn, items in self.categories():
-                matches = self.rank(image_features, items, top_count=topn)
+            for cat in self.categories():
+                matches = self.rank(image_features, cat.items, top_count=cat.topn)
                 for match, score in matches:
                     if shared.opts.interrogate_return_ranks:
                         res += f", ({match}:{score/100:.3f})"
                     else:
-                        res += ", " + match
+                        res += f", {match}"

         except Exception:
             print("Error interrogating", file=sys.stderr)
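The switch from tuple unpacking to `cat.items` and `cat.topn` suggests `categories()` now yields objects with named fields; a named tuple would keep both access styles working. A sketch under that assumption (the `Category` definition here is illustrative):

    from typing import List, NamedTuple

    class Category(NamedTuple):
        name: str
        topn: int
        items: List[str]

    cat = Category("mediums", 3, ["oil painting", "watercolor"])
    name, topn, items = cat     # old tuple-unpacking style still works
    print(cat.topn, cat.items)  # new attribute style is self-documenting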
@@ -1,6 +1,5 @@
 import torch
 import platform
-from modules import paths
 from modules.sd_hijack_utils import CondFunc
 from packaging import version

@@ -43,7 +42,7 @@ if has_mps:
         # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
         CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
             lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
         # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
         CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
             lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
         # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -54,6 +53,11 @@ if has_mps:
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
         CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
-        if version.parse(torch.__version__) == version.parse("2.0"):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
+        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+
+    # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
+    if platform.processor() == 'i386':
+        for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
+            CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
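All of these MPS workarounds go through `CondFunc`, which monkeypatches a function with a wrapper that only takes effect when a predicate on the arguments holds. A simplified sketch of that mechanism, not the repository's implementation:

    def patch_conditionally(owner, name, replacement, predicate):
        orig = getattr(owner, name)

        def wrapper(*args, **kwargs):
            if predicate is None or predicate(orig, *args, **kwargs):
                return replacement(orig, *args, **kwargs)  # patched path
            return orig(*args, **kwargs)                   # untouched path
        setattr(owner, name, wrapper)

    class Box:
        def __init__(self, v):
            self.v = v
        def get(self):
            return self.v

    # double the result, but only for negative values:
    patch_conditionally(Box, "get", lambda orig, self: orig(self) * 2,
                        lambda orig, self: self.v < 0)
    assert Box(3).get() == 3 and Box(-3).get() == -6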
@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
 def get_crop_region(mask, pad=0):
     """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
     For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""

     h, w = mask.shape

     crop_left = 0
@@ -1,4 +1,3 @@
-import glob
 import os
 import shutil
 import importlib

@@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
     """
     output = []

-    if ext_filter is None:
-        ext_filter = []
-
     try:
         places = []

|
|||||||
places.append(model_path)
|
places.append(model_path)
|
||||||
|
|
||||||
for place in places:
|
for place in places:
|
||||||
if os.path.exists(place):
|
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
||||||
for file in glob.iglob(place + '**/**', recursive=True):
|
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||||
full_path = file
|
print(f"Skipping broken symlink: {full_path}")
|
||||||
if os.path.isdir(full_path):
|
continue
|
||||||
continue
|
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
continue
|
||||||
print(f"Skipping broken symlink: {full_path}")
|
if full_path not in output:
|
||||||
continue
|
output.append(full_path)
|
||||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
|
||||||
continue
|
|
||||||
if len(ext_filter) != 0:
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
if extension not in ext_filter:
|
|
||||||
continue
|
|
||||||
if file not in output:
|
|
||||||
output.append(full_path)
|
|
||||||
|
|
||||||
if model_url is not None and len(output) == 0:
|
if model_url is not None and len(output) == 0:
|
||||||
if download_name is not None:
|
if download_name is not None:
|
||||||
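`shared.walk_files` is defined elsewhere in this commit; conceptually it replaces the manual `glob.iglob` walk with a generator yielding files under a root, optionally filtered by extension. A plausible sketch of such a helper, under the assumptions the call site implies:

    import os

    def walk_files(path, allowed_extensions=None):
        if not os.path.exists(path):
            return
        if allowed_extensions is not None:
            allowed_extensions = {ext.lower() for ext in allowed_extensions}
        for root, _dirs, files in os.walk(path):
            for filename in sorted(files):
                if allowed_extensions is not None:
                    _, ext = os.path.splitext(filename)
                    if ext.lower() not in allowed_extensions:
                        continue
                yield os.path.join(root, filename)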
@@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
                     print(f"Moving {file} from {src_path} to {dest_path}.")
                     try:
                         shutil.move(fullpath, dest_path)
-                    except:
+                    except Exception:
                         pass
             if len(os.listdir(src_path)) == 0:
                 print(f"Removing empty folder: {src_path}")
                 shutil.rmtree(src_path, True)
-    except:
+    except Exception:
         pass


-builtin_upscaler_classes = []
-forbidden_upscaler_classes = set()
-
-
-def list_builtin_upscalers():
-    load_upscalers()
-
-    builtin_upscaler_classes.clear()
-    builtin_upscaler_classes.extend(Upscaler.__subclasses__())
-
-
-def forbid_loaded_nonbuiltin_upscalers():
-    for cls in Upscaler.__subclasses__():
-        if cls not in builtin_upscaler_classes:
-            forbidden_upscaler_classes.add(cls)
-
-
 def load_upscalers():
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
|
|||||||
full_model = f"modules.{model_name}_model"
|
full_model = f"modules.{model_name}_model"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(full_model)
|
importlib.import_module(full_model)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
commandline_options = vars(shared.cmd_opts)
|
commandline_options = vars(shared.cmd_opts)
|
||||||
for cls in Upscaler.__subclasses__():
|
|
||||||
if cls in forbidden_upscaler_classes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||||
|
# up with two copies of those classes. The newest copy will always be the last in the list,
|
||||||
|
# so we go from end to beginning and ignore duplicates
|
||||||
|
used_classes = {}
|
||||||
|
for cls in reversed(Upscaler.__subclasses__()):
|
||||||
|
classname = str(cls)
|
||||||
|
if classname not in used_classes:
|
||||||
|
used_classes[classname] = cls
|
||||||
|
|
||||||
|
for cls in reversed(used_classes.values()):
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||||
scaler = cls(commandline_options.get(cmd_name, None))
|
scaler = cls(commandline_options.get(cmd_name, None))
|
||||||
|
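The deduplication above relies on `__subclasses__()` preserving registration order: walking it from the end records the newest copy of each reloaded class first and skips older copies with the same `str(cls)`. A compact demonstration:

    class Registry:
        pass

    def make_plugin():
        class Plugin(Registry):
            pass
        return Plugin

    old_copy, new_copy = make_plugin(), make_plugin()  # same str(cls), two distinct objects

    used = {}
    for cls in reversed(Registry.__subclasses__()):  # newest copies come last, so walk backwards
        used.setdefault(str(cls), cls)               # the first (newest) copy seen wins
    assert list(used.values()) == [new_copy]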
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
                  beta_schedule="linear",
                  loss_type="l2",
                  ckpt_path=None,
-                 ignore_keys=[],
+                 ignore_keys=None,
                  load_only_unet=False,
                  monitor="val/loss",
                  use_ema=True,

@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

         if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)

         # If initialing from EMA-only checkpoint, create EMA model after loading.
         if self.use_ema and not load_ema:
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
|
ignore_keys = ignore_keys or []
|
||||||
|
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
|
|||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print(f"Deleting key {k} from state_dict.")
|
||||||
del sd[k]
|
del sd[k]
|
||||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||||
sd, strict=False)
|
sd, strict=False)
|
||||||
@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
|
|||||||
_, loss_dict_no_ema = self.shared_step(batch)
|
_, loss_dict_no_ema = self.shared_step(batch)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
_, loss_dict_ema = self.shared_step(batch)
|
_, loss_dict_ema = self.shared_step(batch)
|
||||||
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
||||||
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
|
|
||||||
@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
use_ddim = False
|
use_ddim = False
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
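The `super().__init__` fixes in the last two hunks address the same oddity: `f(key=value, *args)` is legal, but the unpacked positionals are bound before the keyword, so a positional argument can collide with the keyword at call time. Leading with `*args` makes the order unambiguous:

    def f(a, b=None, c=None):
        return (a, b, c)

    args = ("A",)
    print(f(c=3, *args))  # legal: ('A', None, 3), but binds *args before the keyword
    print(f(*args, c=3))  # same result, conventional order

    args = ("A", "B", "C")
    # f(c=3, *args) would now raise TypeError: multiple values for argument 'c'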
@@ -1 +1 @@
-from .sampler import UniPCSampler
+from .sampler import UniPCSampler  # noqa: F401

@@ -54,7 +54,8 @@ class UniPCSampler(object):
         if conditioning is not None:
             if isinstance(conditioning, dict):
                 ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
+                while isinstance(ctmp, list):
+                    ctmp = ctmp[0]
                 cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
@@ -1,7 +1,6 @@
 import torch
-import torch.nn.functional as F
 import math
-from tqdm.auto import trange
+import tqdm


 class NoiseScheduleVP:

@@ -94,7 +93,7 @@ class NoiseScheduleVP:
         """

         if schedule not in ['discrete', 'linear', 'cosine']:
-            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+            raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")

         self.schedule = schedule
         if schedule == 'discrete':

@@ -179,13 +178,13 @@ def model_wrapper(
     model,
     noise_schedule,
     model_type="noise",
-    model_kwargs={},
+    model_kwargs=None,
     guidance_type="uncond",
     #condition=None,
     #unconditional_condition=None,
     guidance_scale=1.,
     classifier_fn=None,
-    classifier_kwargs={},
+    classifier_kwargs=None,
 ):
     """Create a wrapper function for the noise prediction model.

|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or {}
|
||||||
|
classifier_kwargs = classifier_kwargs or {}
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||||
@ -342,7 +344,7 @@ def model_wrapper(
|
|||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
assert isinstance(unconditional_condition, dict)
|
assert isinstance(unconditional_condition, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in condition:
|
for k in condition:
|
||||||
if isinstance(condition[k], list):
|
if isinstance(condition[k], list):
|
||||||
c_in[k] = [torch.cat([
|
c_in[k] = [torch.cat([
|
||||||
@ -353,7 +355,7 @@ def model_wrapper(
|
|||||||
unconditional_condition[k],
|
unconditional_condition[k],
|
||||||
condition[k]])
|
condition[k]])
|
||||||
elif isinstance(condition, list):
|
elif isinstance(condition, list):
|
||||||
c_in = list()
|
c_in = []
|
||||||
assert isinstance(unconditional_condition, list)
|
assert isinstance(unconditional_condition, list)
|
||||||
for i in range(len(condition)):
|
for i in range(len(condition)):
|
||||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||||
@ -469,7 +471,7 @@ class UniPC:
|
|||||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||||
return t
|
return t
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||||
|
|
||||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||||
"""
|
"""
|
||||||
@ -757,40 +759,44 @@ class UniPC:
|
|||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
t_prev_list = [vec_t]
|
t_prev_list = [vec_t]
|
||||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
with tqdm.tqdm(total=steps) as pbar:
|
||||||
for init_order in range(1, order):
|
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
for init_order in range(1, order):
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||||
if model_x is None:
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
model_x = self.model_fn(x, vec_t)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
model_prev_list.append(model_x)
|
|
||||||
t_prev_list.append(vec_t)
|
|
||||||
for step in trange(order, steps + 1):
|
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
|
||||||
if lower_order_final:
|
|
||||||
step_order = min(order, steps + 1 - step)
|
|
||||||
else:
|
|
||||||
step_order = order
|
|
||||||
#print('this step order:', step_order)
|
|
||||||
if step == steps:
|
|
||||||
#print('do not run corrector at the last step')
|
|
||||||
use_corrector = False
|
|
||||||
else:
|
|
||||||
use_corrector = True
|
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
for i in range(order - 1):
|
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
|
||||||
t_prev_list[-1] = vec_t
|
|
||||||
# We do not need to evaluate the final model value.
|
|
||||||
if step < steps:
|
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
model_prev_list.append(model_x)
|
||||||
|
t_prev_list.append(vec_t)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
for step in range(order, steps + 1):
|
||||||
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
|
if lower_order_final:
|
||||||
|
step_order = min(order, steps + 1 - step)
|
||||||
|
else:
|
||||||
|
step_order = order
|
||||||
|
#print('this step order:', step_order)
|
||||||
|
if step == steps:
|
||||||
|
#print('do not run corrector at the last step')
|
||||||
|
use_corrector = False
|
||||||
|
else:
|
||||||
|
use_corrector = True
|
||||||
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
for i in range(order - 1):
|
||||||
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
t_prev_list[-1] = vec_t
|
||||||
|
# We do not need to evaluate the final model value.
|
||||||
|
if step < steps:
|
||||||
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
|
pbar.update()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
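The rewrite trades `trange`, which only covered the main loop, for one manually driven progress bar spanning both the warm-up iterations and the main solver loop. The general pattern, sketched minimally:

    import time
    import tqdm

    steps, warmup = 10, 3
    with tqdm.tqdm(total=steps) as pbar:
        for _ in range(warmup):          # phase 1: initialization iterations
            time.sleep(0.01)
            pbar.update()
        for _ in range(steps - warmup):  # phase 2: main loop shares the same bar
            time.sleep(0.01)
            pbar.update()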
@@ -7,13 +7,13 @@ def connect(token, port, region):
     else:
         if ':' in token:
             # token = authtoken:username:password
-            account = token.split(':')[1] + ':' + token.split(':')[-1]
-            token = token.split(':')[0]
+            token, username, password = token.split(':', 2)
+            account = f"{username}:{password}"

     config = conf.PyngrokConfig(
         auth_token=token, region=region
     )

     # Guard for existing tunnels
     existing = ngrok.get_tunnels(pyngrok_config=config)
     if existing:

@@ -24,7 +24,7 @@ def connect(token, port, region):
             print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
                   'You can use this link after the launch is complete.')
             return

     try:
         if account is None:
             public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
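Using `maxsplit=2` here is a real fix, not just a cleanup: colons inside the password now survive, where the old index-based split would have mangled them. For example:

    token = "authtok:alice:pa:ss:word"
    authtoken, username, password = token.split(":", 2)  # split at most twice
    assert (authtoken, username, password) == ("authtok", "alice", "pa:ss:word")
    # old approach: token.split(':')[1] + ':' + token.split(':')[-1] == "alice:word" (wrong)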
@@ -1,8 +1,8 @@
 import os
 import sys
-from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir  # noqa: F401

-import modules.safe
+import modules.safe  # noqa: F401


 # data_path = cmd_opts_pre.data

@@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
         sd_path = os.path.abspath(possible_sd_path)
         break

-assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
+assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
@@ -3,7 +3,8 @@
 import argparse
 import os

-script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+modules_path = os.path.dirname(os.path.realpath(__file__))
+script_path = os.path.dirname(modules_path)

 sd_configs_path = os.path.join(script_path, "configs")
 sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")

@@ -12,7 +13,7 @@ default_sd_model_file = sd_model_file

 # Parse the --data-dir flag first so we can use it as a base for our other argument default values
 parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
+parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
 cmd_opts_pre = parser_pre.parse_known_args()[0]

 data_path = cmd_opts_pre.data_dir

@@ -21,3 +22,5 @@ models_path = os.path.join(data_path, "models")
 extensions_dir = os.path.join(data_path, "extensions")
 extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
 config_states_dir = os.path.join(script_path, "config_states")
+
+roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
@@ -2,7 +2,6 @@ import json
 import math
 import os
 import sys
-import warnings
 import hashlib

 import torch

@@ -11,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
 import random
 import cv2
 from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
|
|||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
@ -465,6 +464,16 @@ def fix_seed(p):
|
|||||||
p.subseed = get_fixed_seed(p.subseed)
|
p.subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
|
||||||
|
def program_version():
|
||||||
|
import launch
|
||||||
|
|
||||||
|
res = launch.git_tag()
|
||||||
|
if res == "<none>":
|
||||||
|
res = None
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
@ -499,13 +508,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
@ -678,7 +688,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
||||||
try:
|
try:
|
||||||
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
@ -794,7 +804,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
res = Processed(
|
||||||
|
p,
|
||||||
|
images_list=output_images,
|
||||||
|
seed=p.all_seeds[0],
|
||||||
|
info=infotext(),
|
||||||
|
comments="".join(f"\n\n{comment}" for comment in comments),
|
||||||
|
subseed=p.all_subseeds[0],
|
||||||
|
index_of_first_image=index_of_first_image,
|
||||||
|
infotexts=infotexts,
|
||||||
|
)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess(p, res)
|
p.scripts.postprocess(p, res)
|
||||||
|
@@ -95,8 +95,17 @@ def progressapi(req: ProgressRequest):
     image = shared.state.current_image
     if image is not None:
         buffered = io.BytesIO()
-        image.save(buffered, format="png")
-        live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+        format = opts.live_previews_format
+        save_kwargs = {}
+        if format == "auto":
+            if max(*image.size) > 256:
+                format = "jpeg"
+            else:
+                format = "png"
+                save_kwargs = {"optimize": True}
+        image.save(buffered, format=format, **save_kwargs)
+        base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+        live_preview = f"data:image/{format};base64,{base64_image}"
         id_live_preview = shared.state.id_live_preview
     else:
         live_preview = None
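The preview travels to the browser as a data: URL, so the MIME subtype in the URL has to match the encoder actually used; the new code picks JPEG for large previews and optimized PNG for small ones. A minimal sketch of the encoding step, assuming Pillow:

    import base64
    import io
    from PIL import Image

    def to_data_url(image: Image.Image, fmt: str = "png") -> str:
        buffered = io.BytesIO()
        image.save(buffered, format=fmt)  # Pillow accepts "png" or "jpeg" here
        payload = base64.b64encode(buffered.getvalue()).decode("ascii")
        return f"data:image/{fmt};base64,{payload}"

    print(to_data_url(Image.new("RGB", (64, 64), "gray"), "jpeg")[:40])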
@@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
     """

     def collect_steps(steps, tree):
-        l = [steps]
+        res = [steps]
+
         class CollectSteps(lark.Visitor):
             def scheduled(self, tree):
                 tree.children[-1] = float(tree.children[-1])
                 if tree.children[-1] < 1:
                     tree.children[-1] *= steps
                 tree.children[-1] = min(steps, int(tree.children[-1]))
-                l.append(tree.children[-1])
+                res.append(tree.children[-1])
+
             def alternate(self, tree):
-                l.extend(range(1, steps+1))
+                res.extend(range(1, steps+1))
+
         CollectSteps().visit(tree)
-        return sorted(set(l))
+        return sorted(set(res))

     def at_step(step, tree):
         class AtStep(lark.Transformer):
|
|||||||
def get_schedule(prompt):
|
def get_schedule(prompt):
|
||||||
try:
|
try:
|
||||||
tree = schedule_parser.parse(prompt)
|
tree = schedule_parser.parse(prompt)
|
||||||
except lark.exceptions.LarkError as e:
|
except lark.exceptions.LarkError:
|
||||||
if 0:
|
if 0:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
|
|||||||
conds = model.get_learned_conditioning(texts)
|
conds = model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
cond_schedule = []
|
cond_schedule = []
|
||||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||||
|
|
||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
|
|||||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
for i, cond_schedule in enumerate(c):
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
for current, entry in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
res[i] = cond_schedule[target_index].cond
|
res[i] = cond_schedule[target_index].cond
|
||||||
@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
tensors = []
|
tensors = []
|
||||||
conds_list = []
|
conds_list = []
|
||||||
|
|
||||||
for batch_no, composable_prompts in enumerate(c.batch):
|
for composable_prompts in c.batch:
|
||||||
conds_for_batch = []
|
conds_for_batch = []
|
||||||
|
|
||||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
for composable_prompt in composable_prompts:
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
for current, entry in enumerate(composable_prompt.schedules):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
|
|
||||||
|
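Both `reconstruct_*` loops above now read schedule entries through the `end_at_step` field instead of positional unpacking, so they keep working even if `ScheduledPromptConditioning` grows extra fields. A toy version of the lookup (two-field namedtuple assumed for illustration):

```python
from collections import namedtuple

ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

schedule = [
    ScheduledPromptConditioning(5, "cond-a"),   # active through step 5
    ScheduledPromptConditioning(10, "cond-b"),  # active through step 10
]

current_step = 7
target_index = 0
for current, entry in enumerate(schedule):
    if current_step <= entry.end_at_step:  # field access instead of (end_at, cond) unpacking
        target_index = current
        break

assert schedule[target_index].cond == "cond-b"
```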
@ -17,9 +17,9 @@ class UpscalerRealESRGAN(Upscaler):
         self.user_path = path
         super().__init__()
         try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet
-            from realesrgan import RealESRGANer
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
+            from realesrgan import RealESRGANer  # noqa: F401
+            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
             self.enable = True
             self.scalers = []
             scalers = self.load_models(path)
@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler):
             for scaler in scalers:
                 if scaler.local_data_path.startswith("http"):
                     filename = modelloader.friendly_name(scaler.local_data_path)
-                    local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None)
-                    if local:
-                        scaler.local_data_path = local
+                    local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
+                    if local_model_candidates:
+                        scaler.local_data_path = local_model_candidates[0]

                 if scaler.name in opts.realesrgan_enabled_models:
                     self.scalers.append(scaler)
@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):

         info = self.load_model(path)
         if not os.path.exists(info.local_data_path):
-            print("Unable to load RealESRGAN model: %s" % info.name)
+            print(f"Unable to load RealESRGAN model: {info.name}")
             return img

         upsampler = RealESRGANer(
@ -134,6 +134,6 @@ def get_realesrgan_models(scaler):
             ),
         ]
         return models
-    except Exception as e:
+    except Exception:
         print("Error making Real-ESRGAN models list:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
@ -95,16 +95,16 @@ def check_pt(filename, extra_handler):

     except zipfile.BadZipfile:

-        # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
+        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
         with open(filename, "rb") as file:
             unpickler = RestrictedUnpickler(file)
             unpickler.extra_handler = extra_handler
-            for i in range(5):
+            for _ in range(5):
                 unpickler.load()


 def load(filename, *args, **kwargs):
-    return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
+    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)


 def load_with_extra(filename, extra_handler=None, *args, **kwargs):
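The reordered `load_with_extra` call above is what Ruff's star-arg rule flags: writing a keyword before `*args` is legal, but the unpacked positionals still bind first, so the old spelling implied an ordering that never existed. Both forms behave identically, as a quick standalone check shows:

```python
def f(a, b, c):
    return a, b, c

# the starred positionals bind before the keyword in both spellings;
# the second simply reads in actual binding order
assert f(1, c=4, *[2]) == (1, 2, 4)
assert f(1, *[2], c=4) == (1, 2, 4)
```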
@ -32,22 +32,22 @@ class CFGDenoiserParams:
     def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
         self.x = x
         """Latent image representation in the process of being denoised"""

         self.image_cond = image_cond
         """Conditioning image"""

         self.sigma = sigma
         """Current sigma noise step value"""

         self.sampling_step = sampling_step
         """Current Sampling step number"""

         self.total_sampling_steps = total_sampling_steps
         """Total number of sampling steps planned"""

         self.text_cond = text_cond
         """ Encoder hidden states of text conditioning from prompt"""

         self.text_uncond = text_uncond
         """ Encoder hidden states of text conditioning from negative prompt"""

@ -240,7 +240,7 @@ def add_callback(callbacks, fun):

     callbacks.append(ScriptCallback(filename, fun))


 def remove_current_script_callbacks():
     stack = [x for x in inspect.stack() if x.filename != __file__]
     filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@ -2,7 +2,6 @@ import os
 import sys
 import traceback
 import importlib.util
-from types import ModuleType


 def load_module(path):
@ -163,7 +163,8 @@ class Script:
         """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""

         need_tabname = self.show(True) == self.show(False)
-        tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+        tabkind = 'img2img' if self.is_img2img else 'txt2txt'
+        tabname = f"{tabkind}_" if need_tabname else ""
         title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))

         return f'script_{tabname}{title}_{item_id}'
@ -230,7 +231,7 @@ def load_scripts():
     syspath = sys.path

     def register_scripts_from_module(module):
-        for key, script_class in module.__dict__.items():
+        for script_class in module.__dict__.values():
             if type(script_class) != type:
                 continue

@ -294,9 +295,9 @@ class ScriptRunner:

         auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

-        for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
-            script = script_class()
-            script.filename = path
+        for script_data in auto_processing_scripts + scripts_data:
+            script = script_data.script_class()
+            script.filename = script_data.path
             script.is_txt2img = not is_img2img
             script.is_img2img = is_img2img

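The loops above now take whole `script_data` records and pull `script_class` and `path` out by name, so fields like the base directory or the module no longer need throwaway unpacking variables. A sketch under the assumption that the records are namedtuple-like (the `ScriptClassData` shape here is illustrative):

```python
from collections import namedtuple

# illustrative record; the real one also carries basedir and the loaded module
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])

class DemoScript:
    filename = None

scripts_data = [ScriptClassData(DemoScript, "scripts/demo.py", "scripts", None)]

for script_data in scripts_data:
    script = script_data.script_class()  # unused fields never need naming
    script.filename = script_data.path

assert script.filename == "scripts/demo.py"
```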
@ -491,7 +492,7 @@ class ScriptRunner:
             module = script_loading.load_module(script.filename)
             cache[filename] = module

-        for key, script_class in module.__dict__.items():
+        for script_class in module.__dict__.values():
             if type(script_class) == type and issubclass(script_class, Script):
                 self.scripts[si] = script_class()
                 self.scripts[si].filename = filename
@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp):
     this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
     """

-    comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]

     if getattr(comp, 'multiselect', False):
         comp.elem_classes.append('multiselect')
@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
         return self.postprocessing_controls.values()

     def postprocess_image(self, p, script_pp, *args):
-        args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
+        args_dict = dict(zip(self.postprocessing_controls, args))

         pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
         pp.info = {}
@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
     def initialize_scripts(self, scripts_data):
         self.scripts = []

-        for script_class, path, basedir, script_module in scripts_data:
-            script: ScriptPostprocessing = script_class()
-            script.filename = path
+        for script_data in scripts_data:
+            script: ScriptPostprocessing = script_data.script_class()
+            script.filename = script_data.path

             if script.name == "Simple Upscale":
                 continue
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
             script_args = args[script.args_from:script.args_to]

             process_args = {}
-            for (name, component), value in zip(script.controls.items(), script_args):
+            for (name, _component), value in zip(script.controls.items(), script_args):
                 process_args[name] = value

             script.process(pp, **process_args)
@ -61,7 +61,7 @@ class DisableInitialization:
                 if res is None:
                     res = original(url, *args, local_files_only=False, **kwargs)
                 return res
-            except Exception as e:
+            except Exception:
                 return original(url, *args, local_files_only=False, **kwargs)

         def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
@ -3,7 +3,7 @@ from torch.nn.functional import silu
 from types import MethodType

 import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules import devices, sd_hijack_optimizations, shared
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@ -34,10 +34,10 @@ def apply_optimizations():

     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

     optimization_method = None

-    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))  # not everyone has torch 2.x to use sdp
+    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)  # not everyone has torch 2.x to use sdp

     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
@ -92,12 +92,12 @@ def fix_checkpoint():
 def weighted_loss(sd_model, pred, target, mean=True):
     #Calculate the weight normally, but ignore the mean
     loss = sd_model._old_get_loss(pred, target, mean=False)

     #Check if we have weights available
     weight = getattr(sd_model, '_custom_loss_weight', None)
     if weight is not None:
         loss *= weight

     #Return the loss, as mean if specified
     return loss.mean() if mean else loss

@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
     try:
         #Temporarily append weights to a place accessible during loss calc
         sd_model._custom_loss_weight = w

         #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
         #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
         if not hasattr(sd_model, '_old_get_loss'):
@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
         try:
             #Delete temporary weights if appended
             del sd_model._custom_loss_weight
-        except AttributeError as e:
+        except AttributeError:
             pass

         #If we have an old loss function, reset the loss function to the original one
         if hasattr(sd_model, '_old_get_loss'):
             sd_model.get_loss = sd_model._old_get_loss
@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
 def undo_weighted_forward(sd_model):
     try:
         del sd_model.weighted_forward
-    except AttributeError as e:
+    except AttributeError:
         pass

@ -184,7 +184,7 @@ class StableDiffusionModelHijack:

     def undo_hijack(self, m):
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             m.cond_stage_model = m.cond_stage_model.wrapped

         elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             self.hijack.fixes = [x.fixes for x in batch_chunk]

             for fixes in self.hijack.fixes:
-                for position, embedding in fixes:
+                for _position, embedding in fixes:
                     used_embeddings[embedding.name] = embedding

             z = self.process_tokens(tokens, multipliers)
@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
     self.hijack.comments += hijack_comments

     if len(used_custom_terms) > 0:
-        self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+        embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
+        self.hijack.comments.append(f"Used embeddings: {embedding_names}")

     self.hijack.fixes = hijack_fixes
     return self.process_tokens(remade_batch_tokens, batch_multipliers)
@ -1,16 +1,10 @@
-import os
 import torch

-from einops import repeat
-from omegaconf import ListConfig

 import ldm.models.diffusion.ddpm
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms

-from ldm.models.diffusion.ddpm import LatentDiffusion
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.models.diffusion.ddim import DDIMSampler, noise_like
+from ldm.models.diffusion.ddim import noise_like
 from ldm.models.diffusion.sampling_util import norm_thresholding

@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F

     if isinstance(c, dict):
         assert isinstance(unconditional_conditioning, dict)
-        c_in = dict()
+        c_in = {}
         for k in c:
             if isinstance(c[k], list):
                 c_in[k] = [
@ -1,8 +1,5 @@
-import collections
 import os.path
-import sys
-import gc
-import time

 def should_hijack_ip2p(checkpoint_info):
     from modules import sd_models_config
@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
     ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
     cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()

-    return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
+    return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     v_in = self.to_v(context_v)
     del context, context_k, context_v, x

-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
     del q_in, k_in, v_in

     dtype = q.dtype
@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
         end = i + 2
         s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
         s1 *= self.scale

         s2 = s1.softmax(dim=-1)
         del s1

         r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
         del s2
     del q, k, v
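Replacing `map(lambda ...)` with a generator expression is safe here even though the targets `q, k, v` are the very names being read: the leftmost iterable of a generator expression is evaluated eagerly when the expression is created, and tuple unpacking drains it before the names are rebound. A minimal demonstration:

```python
a, b = 1, 2
# (a, b) on the right is captured immediately, so rebinding a and b
# during unpacking cannot feed stale values back into the generator
a, b = (x * 10 for x in (a, b))
assert (a, b) == (10, 20)
```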
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

     with devices.without_autocast(disable=not shared.opts.upcast_attn):
         k_in = k_in * self.scale

         del context, x

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
         del q_in, k_in, v_in

         r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

         mem_free_total = get_available_vram()

         gb = 1024 ** 3
         tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
         modifier = 3 if q.element_size() == 2 else 2.5
         mem_required = tensor_size * modifier
         steps = 1

         if mem_required > mem_free_total:
             steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
             # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
             #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

         if steps > 64:
             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

             s2 = s1.softmax(dim=-1, dtype=q.dtype)
             del s1

             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
             del s2

         del q, k, v

     r1 = r1.to(dtype)
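For reference, the slicing math in the hunk above picks the smallest power of two that brings one attention slice under the free-VRAM estimate. A worked example with assumed numbers (3 GB required against 1 GB free):

```python
import math

gb = 1024 ** 3
mem_required = 3 * gb    # assumed size of the full attention matmul
mem_free_total = 1 * gb  # assumed free VRAM

steps = 1
if mem_required > mem_free_total:
    # smallest power of two >= mem_required / mem_free_total
    steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

assert steps == 4  # 3x over budget rounds up to 4 slices of the query
```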
@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

     with devices.without_autocast(disable=not shared.opts.upcast_attn):
         k = k * self.scale

-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
         r = einsum_op(q, k, v)
     r = r.to(dtype)
     return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

+    if q.device.type == 'mps':
+        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+
     dtype = q.dtype
     if shared.opts.upcast_attn:
         q, k = q.float(), k.float()
@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
         # the big matmul fits into our memory limit; do everything in 1 chunk,
         # i.e. send it down the unchunked fast-path
-        query_chunk_size = q_tokens
         kv_chunk_size = k_tokens

     with devices.without_autocast(disable=q.dtype == v.dtype):
@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)

-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
     del q_in, k_in, v_in

     dtype = q.dtype
@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
     q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
     k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
     v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

     del q_in, k_in, v_in

     dtype = q.dtype
@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
         h3 += x

         return h3

 def xformers_attnblock_forward(self, x):
     try:
         h_ = x
@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
         k = self.k(h_)
         v = self.v(h_)
         b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
         dtype = q.dtype
         if shared.opts.upcast_attn:
             q, k = q.float(), k.float()
@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
         k = self.k(h_)
         v = self.v(h_)
         b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
         dtype = q.dtype
         if shared.opts.upcast_attn:
             q, k = q.float(), k.float()
@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
         k = self.k(h_)
         v = self.v(h_)
         b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
@ -18,7 +18,7 @@ class TorchHijackForUnet:
         if hasattr(torch, item):
             return getattr(torch, item)

-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

     def cat(self, tensors, *args, **kwargs):
         if len(tensors) == 2:
@ -1,8 +1,6 @@
-import open_clip.tokenizer
 import torch

 from modules import sd_hijack_clip, devices
-from modules.shared import opts


 class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config

 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
-from modules.paths import models_path
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd
@ -48,7 +47,7 @@ class CheckpointInfo:
         self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
         self.hash = model_hash(filename)

-        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
+        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
         self.shorthash = self.sha256[0:10] if self.sha256 else None

         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
@ -70,7 +69,7 @@ class CheckpointInfo:
             checkpoint_alisases[id] = self

     def calculate_shorthash(self):
-        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
         if self.sha256 is None:
             return

@ -88,8 +87,7 @@ class CheckpointInfo:

     try:
         # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-        from transformers import logging, CLIPModel
-
+        from transformers import logging, CLIPModel  # noqa: F401
         logging.set_verbosity_error()
     except Exception:
@ -168,7 +166,7 @@ def model_hash(filename):

 def select_checkpoint():
     model_checkpoint = shared.opts.sd_model_checkpoint

     checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
     if checkpoint_info is not None:
         return checkpoint_info
@ -240,7 +238,7 @@ def read_metadata_from_safetensors(filename):
         if isinstance(v, str) and v[0:1] == '{':
             try:
                 res[k] = json.loads(v)
-            except Exception as e:
+            except Exception:
                 pass

     return res
@ -375,7 +373,7 @@ def enable_midas_autodownload():
         if not os.path.exists(path):
             if not os.path.exists(midas_path):
                 mkdir(midas_path)

             print(f"Downloading midas model weights for {model_type} to {path}")
             request.urlretrieve(midas_urls[model_type], path)
             print(f"{model_type} downloaded")
@ -468,8 +466,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
             sd_model = instantiate_from_config(sd_config.model)
-    except Exception as e:
+    except Exception:
         pass

     if sd_model is None:
@ -546,7 +543,7 @@ def reload_model_weights(sd_model=None, info=None):

     try:
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-    except Exception as e:
+    except Exception:
         print("Failed to load checkpoint, restoring previous")
         load_model_weights(sd_model, current_checkpoint_info, None, timer)
         raise
@ -567,7 +564,7 @@ def reload_model_weights(sd_model=None, info=None):


 def unload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
     timer = Timer()

     if model_data.sd_model:
@ -1,4 +1,3 @@
-import re
 import os

 import torch
@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
     if info is None:
         return None

-    config = os.path.splitext(info.filename)[0] + ".yaml"
+    config = f"{os.path.splitext(info.filename)[0]}.yaml"
     if os.path.exists(config):
         return config

@ -1,7 +1,7 @@
 from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared

 # imports for functions that previously were here and are used by other modules
-from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
+from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401

 all_samplers = [
     *sd_samplers_kdiffusion.samplers_data_k_diffusion,
@ -55,7 +55,7 @@ class VanillaStableDiffusionSampler:
     def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
         x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)

-        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+        res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)

         x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)

@ -83,7 +83,7 @@ class VanillaStableDiffusionSampler:
             conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
             unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

-            assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
+            assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers'
             cond = tensor

             # for DDIM, shapes must match, we can't just process cond and uncond independently;
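Dropping the inner list in `all([...])` is a small behavioral win as well as a style fix: with a bare generator, `all()` stops at the first failing element instead of materializing every comparison up front. A standalone check:

```python
seen = []

def length_is_one(conds):
    seen.append(conds)
    return len(conds) == 1

conds_list = [[1], [1, 2], [1]]
assert not all(length_is_one(conds) for conds in conds_list)
assert len(seen) == 2  # short-circuited: the third entry was never examined
```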
@ -1,7 +1,6 @@
 from collections import deque
 import torch
 import inspect
-import einops
 import k_diffusion.sampling
 from modules import prompt_parser, devices, sd_samplers_common

@ -87,17 +86,17 @@ class CFGDenoiser(torch.nn.Module):
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

-        assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]

         if shared.sd_model.model.conditioning_key == "crossattn-adm":
             image_uncond = torch.zeros_like(image_cond)
             make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
         else:
             image_uncond = image_cond
             make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}

         if not is_edit_model:
             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@ -198,7 +197,7 @@ class TorchHijack:
         if hasattr(torch, item):
             return getattr(torch, item)

-        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

     def randn_like(self, x):
         if self.sampler_noises:
@ -317,7 +316,7 @@ class KDiffusionSampler:

         sigma_sched = sigmas[steps - t_enc - 1:]
         xi = x + noise * sigma_sched[0]

         extra_params_kwargs = self.initialize(p)
         parameters = inspect.signature(self.func).parameters

@ -340,9 +339,9 @@ class KDiffusionSampler:
         self.model_wrap_cfg.init_latent = x
         self.last_latent = x
         extra_args={
             'cond': conditioning,
             'image_cond': image_conditioning,
             'uncond': unconditional_conditioning,
             'cond_scale': p.cfg_scale,
             's_min_uncond': self.s_min_uncond
         }
@ -375,9 +374,9 @@ class KDiffusionSampler:

         self.last_latent = x
         samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
             'cond': conditioning,
             'image_cond': image_conditioning,
             'uncond': unconditional_conditioning,
             'cond_scale': p.cfg_scale,
             's_min_uncond': self.s_min_uncond
         }, disable=False, callback=self.callback_state, **extra_params_kwargs))
@ -1,8 +1,5 @@
-import torch
-import safetensors.torch
 import os
 import collections
-from collections import namedtuple
 from modules import paths, shared, devices, script_callbacks, sd_models
 import glob
 from copy import deepcopy

@ -89,7 +86,7 @@ def refresh_vae_list():

 def find_vae_near_checkpoint(checkpoint_file):
     checkpoint_path = os.path.splitext(checkpoint_file)[0]
-    for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+    for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
         if os.path.isfile(vae_location):
             return vae_location

@ -1,12 +1,9 @@
-import argparse
 import datetime
 import json
 import os
 import sys
 import time
-import requests

-from PIL import Image
 import gradio as gr
 import tqdm

@ -15,7 +12,7 @@ import modules.memmon
 import modules.styles
 import modules.devices as devices
 from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
 from ldm.models.diffusion.ddpm import LatentDiffusion

 demo = None
@ -214,7 +211,7 @@ class OptionInfo:


 def options_section(section_identifier, options_dict):
-    for k, v in options_dict.items():
+    for v in options_dict.values():
         v.section = section_identifier

     return options_dict
@ -384,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
     "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
     "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
-    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
+    "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks),
 }))

 options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
|
||||||
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
|
||||||
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
|
||||||
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
|
||||||
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
|
||||||
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
|
||||||
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||||
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('infotext', "Infotext"), {
|
||||||
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
|
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
|
||||||
|
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
|
||||||
|
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
|
||||||
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
|
||||||
|
"live_previews_format": OptionInfo("auto", "Live preview file format", gr.Radio, {"choices": ["auto", "jpeg", "png", "webp"]}),
|
||||||
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
|
||||||
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
|
||||||
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
"show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
|
||||||
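A note on the `keyedit_delimiters` fix above: `\/` is not a recognized Python escape, so the old literal produced the intended characters only alongside an invalid-escape warning in newer Pythons; doubling the backslash spells the same two characters explicitly. A quick check in plain Python, independent of the webui code:

```python
# The escaped literal is the same 18-character delimiter set the old
# string produced; it just avoids the invalid-escape warning.
delims = ".,\\/!?%^*;:{}=`~()"
print(len(delims), delims)  # 18 .,\/!?%^*;:{}=`~()
```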
@@ -588,6 +592,10 @@ class Options:
         with open(filename, "r", encoding="utf8") as file:
             self.data = json.load(file)

+        # 1.1.1 quicksettings list migration
+        if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
+            self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
+
         bad_settings = 0
         for k, v in self.data.items():
             info = self.data_labels.get(k, None)
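The migration above converts the pre-1.2 `quicksettings` comma-separated string into the new `quicksettings_list` list the first time an old config is loaded. A minimal standalone sketch of the same logic, with `data` standing in for the loaded `Options.data` dict:

```python
# Hypothetical config dict loaded from an older install.
data = {"quicksettings": "sd_model_checkpoint, sd_vae"}

if data.get('quicksettings') is not None and data.get('quicksettings_list') is None:
    # "sd_model_checkpoint, sd_vae" -> ["sd_model_checkpoint", "sd_vae"]
    data['quicksettings_list'] = [i.strip() for i in data.get('quicksettings').split(',')]

print(data['quicksettings_list'])  # ['sd_model_checkpoint', 'sd_vae']
```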
@@ -617,11 +625,11 @@ class Options:

         section_ids = {}
         settings_items = self.data_labels.items()
-        for k, item in settings_items:
+        for _, item in settings_items:
             if item.section not in section_ids:
                 section_ids[item.section] = len(section_ids)

-        self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
+        self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))

     def cast_value(self, key, value):
         """casts an arbitrary to the same type as this setting's value with key
@@ -707,8 +715,8 @@ def reload_gradio_theme(theme_name=None):
     else:
         try:
             gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
-        except requests.exceptions.ConnectionError:
-            print("Can't access HuggingFace Hub, falling back to default Gradio theme")
+        except Exception as e:
+            errors.display(e, "changing gradio theme")
             gradio_theme = gr.themes.Default()

@@ -769,3 +777,20 @@ def html(filename):
             return file.read()

     return ""
+
+
+def walk_files(path, allowed_extensions=None):
+    if not os.path.exists(path):
+        return
+
+    if allowed_extensions is not None:
+        allowed_extensions = set(allowed_extensions)
+
+    for root, _, files in os.walk(path, followlinks=True):
+        for filename in files:
+            if allowed_extensions is not None:
+                _, ext = os.path.splitext(filename)
+                if ext not in allowed_extensions:
+                    continue
+
+            yield os.path.join(root, filename)
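`walk_files()` is a generator: it yields full paths lazily, follows symlinks, and filters on the raw `os.path.splitext` extension, so entries must include the leading dot. A usage sketch with made-up arguments:

```python
# Hypothetical call; the directory and extension list are illustrative only.
for path in walk_files("embeddings", allowed_extensions=[".pt", ".safetensors"]):
    print(path)
```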
@@ -1,18 +1,9 @@
-# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
-from __future__ import annotations
-
 import csv
 import os
 import os.path
 import typing
-import collections.abc as abc
-import tempfile
 import shutil

-if typing.TYPE_CHECKING:
-    # Only import this when code is being type-checked, it doesn't have any effect at runtime
-    from .processing import StableDiffusionProcessing
-

 class PromptStyle(typing.NamedTuple):
     name: str
@@ -74,7 +65,7 @@ class StyleDatabase:
     def save_styles(self, path: str) -> None:
         # Always keep a backup file around
         if os.path.exists(path):
-            shutil.copy(path, path + ".bak")
+            shutil.copy(path, f"{path}.bak")

         fd = os.open(path, os.O_RDWR|os.O_CREAT)
         with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
             chunk_idx,
             min(query_chunk_size, q_tokens)
         )

     summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
     summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
     compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
         key=key,
         value=value,
     )

-    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
-    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
-    res = torch.cat([
-        compute_query_chunk_attn(
+    res = torch.zeros_like(query)
+    for i in range(math.ceil(q_tokens / query_chunk_size)):
+        attn_scores = compute_query_chunk_attn(
             query=get_query_chunk(i * query_chunk_size),
             key=key,
             value=value,
-        ) for i in range(math.ceil(q_tokens / query_chunk_size))
-    ], dim=1)
+        )
+
+        res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
+
     return res
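The rewrite replaces `torch.cat` over a list of per-chunk results, which keeps every chunk alive until the final concatenation, with a preallocated output that each chunk is written into, so only the output tensor and one chunk need to be resident at a time. The `attn_scores.shape[1]` bound also handles a ragged final chunk. A self-contained sketch of the same pattern, with made-up shapes and a stand-in for `compute_query_chunk_attn`:

```python
import math
import torch

batch_x_heads, q_tokens, head_dim = 2, 10, 8   # illustrative sizes
query_chunk_size = 4
query = torch.randn(batch_x_heads, q_tokens, head_dim)

def compute_chunk(start):
    # stand-in for compute_query_chunk_attn; may return a short final chunk
    return query[:, start:start + query_chunk_size, :] * 2

res = torch.zeros_like(query)                   # allocate the output once
for i in range(math.ceil(q_tokens / query_chunk_size)):
    chunk = compute_chunk(i * query_chunk_size)
    # write the chunk into its slice; shape[1] covers the short last chunk
    res[:, i * query_chunk_size:i * query_chunk_size + chunk.shape[1], :] = chunk

assert torch.equal(res, query * 2)
```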
@@ -1,10 +1,8 @@
 import cv2
 import requests
 import os
-from collections import defaultdict
-from math import log, sqrt
 import numpy as np
-from PIL import Image, ImageDraw
+from PIL import ImageDraw

 GREEN = "#0F0"
 BLUE = "#00F"
@@ -12,63 +10,64 @@ RED = "#F00"


 def crop_image(im, settings):
     """ Intelligently crop an image to the subject matter """

     scale_by = 1
     if is_landscape(im.width, im.height):
         scale_by = settings.crop_height / im.height
     elif is_portrait(im.width, im.height):
         scale_by = settings.crop_width / im.width
     elif is_square(im.width, im.height):
         if is_square(settings.crop_width, settings.crop_height):
             scale_by = settings.crop_width / im.width
         elif is_landscape(settings.crop_width, settings.crop_height):
             scale_by = settings.crop_width / im.width
         elif is_portrait(settings.crop_width, settings.crop_height):
             scale_by = settings.crop_height / im.height

-    im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
-    im_debug = im.copy()
-
-    focus = focal_point(im_debug, settings)
-
-    # take the focal point and turn it into crop coordinates that try to center over the focal
-    # point but then get adjusted back into the frame
-    y_half = int(settings.crop_height / 2)
-    x_half = int(settings.crop_width / 2)
-
-    x1 = focus.x - x_half
-    if x1 < 0:
-        x1 = 0
-    elif x1 + settings.crop_width > im.width:
-        x1 = im.width - settings.crop_width
-
-    y1 = focus.y - y_half
-    if y1 < 0:
-        y1 = 0
-    elif y1 + settings.crop_height > im.height:
-        y1 = im.height - settings.crop_height
-
-    x2 = x1 + settings.crop_width
-    y2 = y1 + settings.crop_height
-
-    crop = [x1, y1, x2, y2]
-
-    results = []
-
-    results.append(im.crop(tuple(crop)))
-
-    if settings.annotate_image:
-        d = ImageDraw.Draw(im_debug)
-        rect = list(crop)
-        rect[2] -= 1
-        rect[3] -= 1
-        d.rectangle(rect, outline=GREEN)
-        results.append(im_debug)
-        if settings.destop_view_image:
-            im_debug.show()
-
-    return results
+    im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
+    im_debug = im.copy()
+
+    focus = focal_point(im_debug, settings)
+
+    # take the focal point and turn it into crop coordinates that try to center over the focal
+    # point but then get adjusted back into the frame
+    y_half = int(settings.crop_height / 2)
+    x_half = int(settings.crop_width / 2)
+
+    x1 = focus.x - x_half
+    if x1 < 0:
+        x1 = 0
+    elif x1 + settings.crop_width > im.width:
+        x1 = im.width - settings.crop_width
+
+    y1 = focus.y - y_half
+    if y1 < 0:
+        y1 = 0
+    elif y1 + settings.crop_height > im.height:
+        y1 = im.height - settings.crop_height
+
+    x2 = x1 + settings.crop_width
+    y2 = y1 + settings.crop_height
+
+    crop = [x1, y1, x2, y2]
+
+    results = []
+
+    results.append(im.crop(tuple(crop)))
+
+    if settings.annotate_image:
+        d = ImageDraw.Draw(im_debug)
+        rect = list(crop)
+        rect[2] -= 1
+        rect[3] -= 1
+        d.rectangle(rect, outline=GREEN)
+        results.append(im_debug)
+        if settings.destop_view_image:
+            im_debug.show()
+
+    return results


 def focal_point(im, settings):
     corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
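The block rewritten above implements center-then-clamp cropping: a crop_width x crop_height box is centered on the focal point, then pushed back inside the image when it overhangs an edge. The same arithmetic as a compact, runnable sketch (names and numbers are illustrative):

```python
def clamp_crop(focus_x, focus_y, crop_w, crop_h, im_w, im_h):
    # center the box on the focal point, then clamp it into the frame
    x1 = min(max(focus_x - crop_w // 2, 0), im_w - crop_w)
    y1 = min(max(focus_y - crop_h // 2, 0), im_h - crop_h)
    return (x1, y1, x1 + crop_w, y1 + crop_h)

# focal point near the top edge of a 768x512 image
print(clamp_crop(500, 20, 256, 256, 768, 512))  # (372, 0, 628, 256)
```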
@@ -88,7 +87,7 @@ def focal_point(im, settings):
     corner_centroid = None
     if len(corner_points) > 0:
         corner_centroid = centroid(corner_points)
         corner_centroid.weight = settings.corner_points_weight / weight_pref_total
         pois.append(corner_centroid)

     entropy_centroid = None
@@ -100,7 +99,7 @@ def focal_point(im, settings):
     face_centroid = None
     if len(face_points) > 0:
         face_centroid = centroid(face_points)
         face_centroid.weight = settings.face_points_weight / weight_pref_total
         pois.append(face_centroid)

     average_point = poi_average(pois, settings)
@@ -111,7 +110,7 @@ def focal_point(im, settings):
         if corner_centroid is not None:
             color = BLUE
             box = corner_centroid.bounding(max_size * corner_centroid.weight)
-            d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
+            d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
             d.ellipse(box, outline=color)
             if len(corner_points) > 1:
                 for f in corner_points:
@@ -119,7 +118,7 @@ def focal_point(im, settings):
         if entropy_centroid is not None:
             color = "#ff0"
             box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
-            d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
+            d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
             d.ellipse(box, outline=color)
             if len(entropy_points) > 1:
                 for f in entropy_points:
@@ -127,14 +126,14 @@ def focal_point(im, settings):
         if face_centroid is not None:
             color = RED
             box = face_centroid.bounding(max_size * face_centroid.weight)
-            d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
+            d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
             d.ellipse(box, outline=color)
             if len(face_points) > 1:
                 for f in face_points:
                     d.rectangle(f.bounding(4), outline=color)

         d.ellipse(average_point.bounding(max_size), outline=GREEN)

     return average_point


@@ -185,7 +184,7 @@ def image_face_points(im, settings):
        try:
            faces = classifier.detectMultiScale(gray, scaleFactor=1.1,
                                                minNeighbors=7, minSize=(minsize, minsize), flags=cv2.CASCADE_SCALE_IMAGE)
-        except:
+        except Exception:
            continue

        if len(faces) > 0:
@@ -262,10 +261,11 @@ def image_entropy(im):
     hist = hist[hist > 0]
     return -np.log2(hist / hist.sum()).sum()


 def centroid(pois):
     x = [poi.x for poi in pois]
     y = [poi.y for poi in pois]
-    return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+    return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))


 def poi_average(pois, settings):
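For reference, the `image_entropy` lines shown as context compute a spread score over the non-empty histogram bins. Note that the source sums `-log2(p)` directly rather than weighting by `p` as textbook Shannon entropy would, so it is an entropy-like score rather than the classic quantity. A worked example with a made-up histogram:

```python
import numpy as np

hist = np.array([8.0, 0.0, 4.0, 4.0])
hist = hist[hist > 0]                      # drop empty bins -> [8, 4, 4]
score = -np.log2(hist / hist.sum()).sum()  # -(log2(.5) + log2(.25) + log2(.25))
print(score)                               # 5.0
```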
@@ -283,59 +283,59 @@ def poi_average(pois, settings):


 def is_landscape(w, h):
     return w > h


 def is_portrait(w, h):
     return h > w


 def is_square(w, h):
     return w == h


 def download_and_cache_models(dirname):
     download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
     model_file_name = 'face_detection_yunet.onnx'

     if not os.path.exists(dirname):
         os.makedirs(dirname)

     cache_file = os.path.join(dirname, model_file_name)
     if not os.path.exists(cache_file):
         print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
         response = requests.get(download_url)
         with open(cache_file, "wb") as f:
             f.write(response.content)

     if os.path.exists(cache_file):
         return cache_file
     return None


 class PointOfInterest:
     def __init__(self, x, y, weight=1.0, size=10):
         self.x = x
         self.y = y
         self.weight = weight
         self.size = size

     def bounding(self, size):
         return [
-            self.x - size//2,
-            self.y - size//2,
-            self.x + size//2,
-            self.y + size//2
+            self.x - size // 2,
+            self.y - size // 2,
+            self.x + size // 2,
+            self.y + size // 2
         ]


 class Settings:
     def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
         self.crop_width = crop_width
         self.crop_height = crop_height
         self.corner_points_weight = corner_points_weight
         self.entropy_points_weight = entropy_points_weight
         self.face_points_weight = face_points_weight
         self.annotate_image = annotate_image
         self.destop_view_image = False
         self.dnn_model_path = dnn_model_path
@@ -72,7 +72,7 @@ class PersonalizedBase(Dataset):
             except Exception:
                 continue

-            text_filename = os.path.splitext(path)[0] + ".txt"
+            text_filename = f"{os.path.splitext(path)[0]}.txt"
             filename = os.path.basename(path)

             if os.path.exists(text_filename):
@@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
                 weight = torch.ones(latent_sample.shape)
             else:
                 weight = None

             if latent_sampling_method == "random":
                 entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
             else:
@@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
         return self

 def collate_wrapper_random(batch):
     return BatchLoaderRandom(batch)
@@ -2,10 +2,8 @@ import base64
 import json
 import numpy as np
 import zlib
-from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
-from fonts.ttf import Roboto
+from PIL import Image, ImageDraw, ImageFont
 import torch
-from modules.shared import opts


 class EmbeddingEncoder(json.JSONEncoder):
@@ -17,7 +15,7 @@ class EmbeddingEncoder(json.JSONEncoder):

 class EmbeddingDecoder(json.JSONDecoder):
     def __init__(self, *args, **kwargs):
-        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
+        json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

     def object_hook(self, d):
         if 'TORCHTENSOR' in d:
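The `__init__` fix reorders the call so positional arguments are unpacked before the explicit keyword; unpacking `*args` after a keyword argument is legal but error-prone (flake8-bugbear flags it as B026, and a positional `object_hook` inside `args` would collide with the keyword). A runnable sketch of the corrected pattern:

```python
import json

class Decoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        # *args unpacked first, keyword after -- the corrected ordering
        json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, d):
        return d

print(json.loads('{"a": 1}', cls=Decoder))  # {'a': 1}
```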
@@ -136,11 +134,8 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     image = srcimage.copy()
     fontsize = 32
     if textfont is None:
-        try:
-            textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
-            textfont = opts.font or Roboto
-        except Exception:
-            textfont = Roboto
+        from modules.images import get_font
+        textfont = get_font(fontsize)

     factor = 1.5
     gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
@@ -12,7 +12,7 @@ class LearnScheduleIterator:
         self.it = 0
         self.maxit = 0
         try:
-            for i, pair in enumerate(pairs):
+            for pair in pairs:
                 if not pair.strip():
                     continue
                 tmp = pair.split(':')
@@ -32,8 +32,8 @@ class LearnScheduleIterator:
                     self.maxit += 1
                 return
             assert self.rates
-        except (ValueError, AssertionError):
-            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+        except (ValueError, AssertionError) as e:
+            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e


     def __iter__(self):
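`raise ... from e` chains the original `ValueError` or `AssertionError` onto the friendlier message, so the underlying cause still appears in the traceback instead of being swallowed. A minimal illustration:

```python
try:
    try:
        int("abc")                          # stand-in for a malformed schedule entry
    except ValueError as e:
        raise Exception("Invalid learning rate schedule.") from e
except Exception as err:
    print(err, "| caused by:", err.__cause__)
```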
@@ -1,13 +1,9 @@
 import os
 from PIL import Image, ImageOps
 import math
-import platform
-import sys
 import tqdm
-import time

 from modules import paths, shared, images, deepbooru
-from modules.shared import opts, cmd_opts
 from modules.textual_inversion import autocrop


@@ -63,9 +59,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
         image.save(os.path.join(params.dstdir, f"{basename}.png"))

     if params.preprocess_txt_action == 'prepend' and existing_caption:
-        caption = existing_caption + ' ' + caption
+        caption = f"{existing_caption} {caption}"
     elif params.preprocess_txt_action == 'append' and existing_caption:
-        caption = caption + ' ' + existing_caption
+        caption = f"{caption} {existing_caption}"
     elif params.preprocess_txt_action == 'copy' and existing_caption:
         caption = existing_caption

@@ -129,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
         default=None
     )
     return wh and center_crop(image, *wh)


 def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
     width = process_width
@@ -174,7 +170,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
         params.src = filename

         existing_caption = None
-        existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+        existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
         if os.path.exists(existing_caption_filename):
             with open(existing_caption_filename, 'r', encoding="utf8") as file:
                 existing_caption = file.read()
@@ -1,7 +1,6 @@
 import os
 import sys
 import traceback
-import inspect
 from collections import namedtuple

 import torch
@@ -30,7 +29,7 @@ textual_inversion_templates = {}
 def list_textual_inversion_templates():
     textual_inversion_templates.clear()

-    for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+    for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
         for fn in fns:
             path = os.path.join(root, fn)

@@ -69,7 +68,7 @@ class Embedding:
             'hash': self.checksum(),
             'optimizer_state_dict': self.optimizer_state_dict,
         }
-        torch.save(optimizer_saved_dict, filename + '.optim')
+        torch.save(optimizer_saved_dict, f"{filename}.optim")

     def checksum(self):
         if self.cached_checksum is not None:
@@ -167,8 +166,7 @@ class EmbeddingDatabase:
         # textual inversion embeddings
         if 'string_to_param' in data:
             param_dict = data['string_to_param']
-            if hasattr(param_dict, '_parameters'):
-                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             emb = next(iter(param_dict.items()))[1]
         # diffuser concepts
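The three-argument `getattr` collapses the `hasattr` branch: it returns `._parameters` when the torch container has one and falls back to the object itself otherwise. A sketch with a hypothetical stand-in class:

```python
class FakeTorchDict:                 # hypothetical stand-in for the torch 1.11 container
    _parameters = {'emb': 'tensor'}

plain = {'emb': 'tensor'}

for param_dict in (FakeTorchDict(), plain):
    # returns ._parameters when present, the object itself otherwise
    param_dict = getattr(param_dict, '_parameters', param_dict)
    print(param_dict)                # {'emb': 'tensor'} in both cases
```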
@@ -199,7 +197,7 @@ class EmbeddingDatabase:
         if not os.path.isdir(embdir.path):
             return

-        for root, dirs, fns in os.walk(embdir.path, followlinks=True):
+        for root, _, fns in os.walk(embdir.path, followlinks=True):
             for fn in fns:
                 try:
                     fullfn = os.path.join(root, fn)
@@ -216,7 +214,7 @@
     def load_textual_inversion_embeddings(self, force_reload=False):
         if not force_reload:
             need_reload = False
-            for path, embdir in self.embedding_dirs.items():
+            for embdir in self.embedding_dirs.values():
                 if embdir.has_changed():
                     need_reload = True
                     break
@@ -229,7 +227,7 @@
         self.skipped_embeddings.clear()
         self.expected_shape = self.get_expected_shape()

-        for path, embdir in self.embedding_dirs.items():
+        for embdir in self.embedding_dirs.values():
             self.load_from_dir(embdir)
             embdir.update()

@@ -325,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
         tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

 def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
     tensorboard_writer.add_scalar(tag=tag,
                                   scalar_value=value, global_step=step)

 def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
     # Convert a pil image to a torch tensor
     img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
     img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
                                  len(pil_image.getbands()))
     img_tensor = img_tensor.permute((2, 0, 1))

     tensorboard_writer.add_image(tag, img_tensor, global_step=step)

 def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@@ -404,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     if initial_step >= steps:
         shared.state.textinfo = "Model has already been trained beyond specified max steps"
         return embedding, filename

     scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
     clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
         torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@@ -414,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     old_parallel_processing_allowed = shared.parallel_processing_allowed

     if shared.opts.training_enable_tensorboard:
         tensorboard_writer = tensorboard_setup(log_directory)

@@ -437,11 +435,11 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
     if shared.opts.save_optimizer_state:
         optimizer_state_dict = None
-        if os.path.exists(filename + '.optim'):
-            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+        if os.path.exists(f"{filename}.optim"):
+            optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
             if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                 optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

         if optimizer_state_dict is not None:
             optimizer.load_state_dict(optimizer_state_dict)
             print("Loaded existing optimizer from checkpoint")
@@ -470,7 +468,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     try:
         sd_hijack_checkpoint.add()

-        for i in range((steps-initial_step) * gradient_step):
+        for _ in range((steps-initial_step) * gradient_step):
             if scheduler.finished:
                 break
             if shared.state.interrupted:
@@ -487,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st

             if clip_grad:
                 clip_grad_sched.step(embedding.step)

             with devices.autocast():
                 x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                 if use_weight:
@@ -515,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
             # go back until we reach gradient accumulation steps
             if (j + 1) % gradient_step != 0:
                 continue

             if clip_grad:
                 clip_grad(embedding.vec, clip_grad_sched.learn_rate)

@@ -599,17 +597,17 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 data = torch.load(last_saved_file)
                 info.add_text("sd-ti-embedding", embedding_to_b64(data))

-                title = "<{}>".format(data.get('name', '???'))
+                title = f"<{data.get('name', '???')}>"

                 try:
                     vectorSize = list(data['string_to_param'].values())[0].shape[0]
-                except Exception as e:
+                except Exception:
                     vectorSize = '?'

                 checkpoint = sd_models.select_checkpoint()
                 footer_left = checkpoint.model_name
-                footer_mid = '[{}]'.format(checkpoint.shorthash)
-                footer_right = '{}v {}s'.format(vectorSize, steps_done)
+                footer_mid = f'[{checkpoint.shorthash}]'
+                footer_right = f'{vectorSize}v {steps_done}s'

                 captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                 captioned_image = insert_image_data_embed(captioned_image, data)
@@ -1,18 +1,15 @@
 import modules.scripts
-from modules import sd_samplers
+from modules import sd_samplers, processing
 from modules.generation_parameters_copypaste import create_override_settings_dict
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
-    StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
-import modules.processing as processing
 from modules.ui import plaintext_to_html


 def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, override_settings_texts, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

-    p = StableDiffusionProcessingTxt2Img(
+    p = processing.StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
         outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
@@ -53,7 +50,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
     processed = modules.scripts.scripts_txt2img.run(p, *args)

     if processed is None:
-        processed = process_images(p)
+        processed = processing.process_images(p)

     p.close()

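Both txt2img changes swap from-imports for module-namespace access (`processing.process_images` instead of an imported `process_images`). One practical effect, shown here with a stdlib module since the webui modules are not importable standalone: attribute access resolves at call time, so a later monkeypatch is visible, whereas a from-import binds the name once at import time.

```python
import json

def render(obj):
    return json.dumps(obj)           # looked up on the module at each call

original = json.dumps
json.dumps = lambda o: "patched"
print(render({"a": 1}))              # patched
json.dumps = original
```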
246
modules/ui.py
@@ -1,29 +1,23 @@
-import html
 import json
-import math
 import mimetypes
 import os
-import platform
-import random
 import sys
-import tempfile
-import time
 import traceback
-from functools import partial, reduce
+from functools import reduce
 import warnings

 import gradio as gr
 import gradio.routes
 import gradio.utils
 import numpy as np
-from PIL import Image, PngImagePlugin
+from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
-from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path, data_path

-from modules.shared import opts, cmd_opts, restricted_opts
+from modules.shared import opts, cmd_opts

 import modules.codeformer_model
 import modules.generation_parameters_copypaste as parameters_copypaste
@@ -34,7 +28,6 @@ import modules.shared as shared
 import modules.styles
 import modules.textual_inversion.ui
 from modules import prompt_parser
-from modules.images import save_image
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.textual_inversion import textual_inversion
@@ -93,16 +86,6 @@ def send_gradio_gallery_to_image(x):
         return None
     return image_from_url_text(x[0])

-def visit(x, func, path=""):
-    if hasattr(x, 'children'):
-        if isinstance(x, gr.Tabs) and x.elem_id is not None:
-            # Tabs element can't have a label, have to use elem_id instead
-            func(f"{path}/Tabs@{x.elem_id}", x)
-        for c in x.children:
-            visit(c, func, path)
-    elif x.label is not None:
-        func(path + "/" + str(x.label), x)
-

 def add_style(name: str, prompt: str, negative_prompt: str):
     if name is None:
@@ -166,7 +149,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
         img = Image.open(image)
         filename = os.path.basename(image)
         left, _ = os.path.splitext(filename)
-        print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
+        print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a'))

     return [gr.update(), None]

@@ -182,29 +165,29 @@ def interrogate_deepbooru(image):


 def create_seed_inputs(target_interface):
-    with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
-        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
+    with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
+        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
         seed.style(container=False)
-        random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
-        reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
+        random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
+        reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')

-        seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
+        seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)

     # Components to show/hide based on the 'Extra' checkbox
     seed_extras = []

-    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
+    with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
-        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
+        subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
         subseed.style(container=False)
-        random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
-        reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
-        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
+        random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
+        reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
+        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")

     with FormRow(visible=False) as seed_extra_row_2:
         seed_extras.append(seed_extra_row_2)
-        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
-        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
+        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")

     random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
     random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@@ -246,7 +229,7 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
             all_seeds = gen_info.get('all_seeds', [-1])
             res = all_seeds[index if 0 <= index < len(all_seeds) else 0]

-        except json.decoder.JSONDecodeError as e:
+        except json.decoder.JSONDecodeError:
             if gen_info_string != '':
                 print("Error parsing JSON generation info:", file=sys.stderr)
                 print(gen_info_string, file=sys.stderr)
@@ -423,7 +406,7 @@ def create_sampler_and_steps_selection(choices, tabname):
 def ordered_ui_categories():
     user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))}

-    for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
+    for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):
         yield category


@@ -736,8 +719,8 @@ def create_ui():
                 with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
                     hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
                     gr.HTML(
-                        f"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
-                        f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
+                        "<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
+                        "<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
                         f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
                         f"{hidden}</p>"
                     )
@@ -746,7 +729,6 @@ def create_ui():
                     img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")

             img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
-            img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch]

             for i, tab in enumerate(img2img_tabs):
                 tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
@@ -765,7 +747,7 @@
             )
             button.click(
                 fn=lambda: None,
-                _js="switch_to_"+name.replace(" ", "_"),
+                _js=f"switch_to_{name.replace(' ', '_')}",
                 inputs=[],
                 outputs=[],
             )
@@ -1189,7 +1171,7 @@ def create_ui():
                     process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
                     process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
                     process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")

                 with gr.Column(visible=False) as process_multicrop_col:
                     gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
                     with gr.Row():
@ -1201,7 +1183,7 @@ def create_ui():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
|
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
|
||||||
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
|
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
gr.HTML(value="")
|
gr.HTML(value="")
|
||||||
@@ -1230,7 +1212,7 @@ def create_ui():
)

def get_textual_inversion_template_names():
- return sorted([x for x in textual_inversion.textual_inversion_templates])
+ return sorted(textual_inversion.textual_inversion_templates)

with gr.Tab(label="Train", id="train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")

@@ -1238,13 +1220,13 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")

- train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
+ train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
- create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
+ create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")

with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")

with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)

@@ -1290,8 +1272,8 @@ def create_ui():

with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
- ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+ gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
- ti_progress = gr.HTML(elem_id="ti_progress", value="")
+ gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")

create_embedding.click(

@@ -1462,23 +1444,25 @@ def create_ui():
elif t == bool:
comp = gr.Checkbox
else:
- raise Exception(f'bad options item type: {str(t)} for key {key}')
+ raise Exception(f'bad options item type: {t} for key {key}')

- elem_id = "setting_"+key
+ elem_id = f"setting_{key}"

if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
with FormRow():
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
- create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
+ create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
else:
res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))

return res

+ loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)

components = []
component_dict = {}
shared.settings_components = component_dict
@@ -1525,7 +1509,7 @@ def create_ui():

result = gr.HTML(elem_id="settings_result")

- quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
+ quicksettings_names = opts.quicksettings_list
quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}

quicksettings_list = []

@@ -1545,7 +1529,7 @@ def create_ui():
current_tab.__exit__()

gr.Group()
- current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text)
+ current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
current_tab.__enter__()
current_row = gr.Column(variant='compact')
current_row.__enter__()

@@ -1566,7 +1550,10 @@ def create_ui():
current_row.__exit__()
current_tab.__exit__()

- with gr.TabItem("Actions", id="actions"):
+ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+ loadsave.create_ui()
+
+ with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")

@@ -1574,11 +1561,11 @@ def create_ui():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")

- with gr.TabItem("Licenses", id="licenses"):
+ with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")

gr.Button(value="Show all pages", elem_id="settings_show_all_pages")


def unload_sd_weights():
modules.sd_models.unload_model_weights()

@@ -1639,7 +1626,7 @@ def create_ui():
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
- (train_interface, "Train", "ti"),
+ (train_interface, "Train", "train"),
]

interfaces += script_callbacks.ui_tabs_callback()
@@ -1654,7 +1641,7 @@ def create_ui():

with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
- for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
+ for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component

@@ -1664,11 +1651,21 @@ def create_ui():
for interface, label, ifid in interfaces:
if label in shared.opts.hidden_tabs:
continue
- with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
+ with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
interface.render()

+ for interface, _label, ifid in interfaces:
+ if ifid in ["extensions", "settings"]:
+ continue
+
+ loadsave.add_block(interface, ifid)
+
+ loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)
+
+ loadsave.setup_ui()

if os.path.exists(os.path.join(script_path, "notification.mp3")):
- audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
+ gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)

footer = shared.html("footer.html")
footer = footer.format(versions=versions_html())

@@ -1681,7 +1678,7 @@ def create_ui():
outputs=[text_settings, result],
)

- for i, k, item in quicksettings_list:
+ for _i, k, _item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]
@@ -1755,97 +1752,8 @@ def create_ui():
]
)

- ui_config_file = cmd_opts.ui_config_file
- ui_settings = {}
- settings_count = len(ui_settings)
- error_loading = False
-
- try:
- if os.path.exists(ui_config_file):
- with open(ui_config_file, "r", encoding="utf8") as file:
- ui_settings = json.load(file)
- except Exception:
- error_loading = True
- print("Error loading settings:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- def loadsave(path, x):
- def apply_field(obj, field, condition=None, init_field=None):
- key = path + "/" + field
-
- if getattr(obj, 'custom_script_source', None) is not None:
- key = 'customscript/' + obj.custom_script_source + '/' + key
-
- if getattr(obj, 'do_not_save_to_config', False):
- return
-
- saved_value = ui_settings.get(key, None)
- if saved_value is None:
- ui_settings[key] = getattr(obj, field)
- elif condition and not condition(saved_value):
- pass
-
- # this warning is generally not useful;
- # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
- else:
- setattr(obj, field, saved_value)
- if init_field is not None:
- init_field(saved_value)
-
- if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
- apply_field(x, 'visible')
-
- if type(x) == gr.Slider:
- apply_field(x, 'value')
- apply_field(x, 'minimum')
- apply_field(x, 'maximum')
- apply_field(x, 'step')
-
- if type(x) == gr.Radio:
- apply_field(x, 'value', lambda val: val in x.choices)
-
- if type(x) == gr.Checkbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Textbox:
- apply_field(x, 'value')
-
- if type(x) == gr.Number:
- apply_field(x, 'value')
-
- if type(x) == gr.Dropdown:
- def check_dropdown(val):
- if getattr(x, 'multiselect', False):
- return all([value in x.choices for value in val])
- else:
- return val in x.choices
-
- apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
-
- def check_tab_id(tab_id):
- tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
- if type(tab_id) == str:
- tab_ids = [t.id for t in tab_items]
- return tab_id in tab_ids
- elif type(tab_id) == int:
- return tab_id >= 0 and tab_id < len(tab_items)
- else:
- return False
-
- if type(x) == gr.Tabs:
- apply_field(x, 'selected', check_tab_id)
-
- visit(txt2img_interface, loadsave, "txt2img")
- visit(img2img_interface, loadsave, "img2img")
- visit(extras_interface, loadsave, "extras")
- visit(modelmerger_interface, loadsave, "modelmerger")
- visit(train_interface, loadsave, "train")
-
- loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
-
- if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
- with open(ui_config_file, "w", encoding="utf8") as file:
- json.dump(ui_settings, file, indent=4)
+ loadsave.dump_defaults()
+ demo.ui_loadsave = loadsave

# Required as a workaround for change() event not triggering when loading values from ui-config.json
interp_description.value = update_interp_description(interp_method.value)
@@ -1923,7 +1831,7 @@ def versions_html():

python_version = ".".join([str(x) for x in sys.version_info[0:3]])
commit = launch.commit_hash()
- short_commit = commit[0:8]
+ tag = launch.git_tag()

if shared.xformers_available:
import xformers

@@ -1932,15 +1840,31 @@ def versions_html():
xformers_version = "N/A"

return f"""
+ version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
+ •
python: <span title="{sys.version}">{python_version}</span>
•
torch: {getattr(torch, '__long_version__',torch.__version__)}
•
xformers: {xformers_version}
•
gradio: {gr.__version__}
•
- commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
- •
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""


+ def setup_ui_api(app):
+ from pydantic import BaseModel, Field
+ from typing import List
+
+ class QuicksettingsHint(BaseModel):
+ name: str = Field(title="Name of the quicksettings field")
+ label: str = Field(title="Label of the quicksettings field")
+
+ def quicksettings_hint():
+ return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
+
+ app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+
+ app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
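The two `/internal` routes added above are plain GET endpoints, so once the server is running they can be exercised with any HTTP client. A minimal sketch using only the standard library (the host and port are assumptions; match them to your launch flags):

    import json
    import urllib.request

    BASE = "http://127.0.0.1:7860"  # assumed default webui address

    with urllib.request.urlopen(f"{BASE}/internal/ping") as resp:
        print(resp.status, json.load(resp))  # expect: 200 {}

    with urllib.request.urlopen(f"{BASE}/internal/quicksettings-hint") as resp:
        hints = json.load(resp)  # list of {"name": ..., "label": ...} objects
        print(hints[:3])
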
@@ -61,7 +61,8 @@ def save_config_state(name):
if not name:
name = "Config"
current_config_state["name"] = name
- filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json")
+ timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
+ filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f)

@@ -466,7 +467,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td>
</tr>

"""

for tag in [x for x in extension_tags if x not in tags]:

@@ -489,7 +490,7 @@ def create_ui():
config_states.list_config_states()

with gr.Blocks(analytics_enabled=False) as ui:
- with gr.Tabs(elem_id="tabs_extensions") as tabs:
+ with gr.Tabs(elem_id="tabs_extensions"):
with gr.TabItem("Installed", id="installed"):

with gr.Row(elem_id="extensions_installed_top"):

@@ -534,9 +535,9 @@ def create_ui():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")

with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)

install_result = gr.HTML()
available_extensions_table = gr.HTML()
@@ -1,4 +1,3 @@
- import glob
import os.path
import urllib.parse
from pathlib import Path

@@ -27,7 +26,7 @@ def register_page(page):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse

- if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
+ if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

ext = os.path.splitext(filename)[1].lower()

@@ -69,7 +68,9 @@ class ExtraNetworksPage:
pass

def link_preview(self, filename):
- return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
+ quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
+ mtime = os.path.getmtime(filename)
+ return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
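The `mtime` query parameter in `link_preview` is a cache-busting measure: browsers cache thumbnails per URL, and embedding the file's modification time changes the URL whenever the preview image changes on disk, forcing a refetch. The same idea in isolation (the path is a hypothetical example):

    import os
    import urllib.parse

    def cache_busted_url(path):
        # the URL changes whenever the file on disk changes, invalidating the browser cache
        quoted = urllib.parse.quote(path.replace('\\', '/'))
        return f"./thumb?filename={quoted}&mtime={os.path.getmtime(path)}"

    # cache_busted_url("previews/example.png")  # hypothetical file
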
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)

@@ -89,19 +90,22 @@ class ExtraNetworksPage:

subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
- for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
- if not os.path.isdir(x):
- continue
-
- subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
- while subdir.startswith("/"):
- subdir = subdir[1:]
-
- is_empty = len(os.listdir(x)) == 0
- if not is_empty and not subdir.endswith("/"):
- subdir = subdir + "/"
-
- subdirs[subdir] = 1
+ for root, dirs, _ in os.walk(parentdir, followlinks=True):
+ for dirname in dirs:
+ x = os.path.join(root, dirname)
+
+ if not os.path.isdir(x):
+ continue
+
+ subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
+ while subdir.startswith("/"):
+ subdir = subdir[1:]
+
+ is_empty = len(os.listdir(x)) == 0
+ if not is_empty and not subdir.endswith("/"):
+ subdir = subdir + "/"
+
+ subdirs[subdir] = 1

if subdirs:
subdirs = {"": 1, **subdirs}
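The rewritten traversal above replaces a recursive `glob` with `os.walk`; passing `followlinks=True` makes `os.walk` descend into directories that are symbolic links, which it does not do by default. A standalone sketch of the same directory walk (the root path is an assumption):

    import os

    def list_subdirs(parentdir):
        """Collect every subdirectory under parentdir, following symlinked directories."""
        found = []
        for root, dirs, _files in os.walk(parentdir, followlinks=True):
            for dirname in dirs:
                found.append(os.path.join(root, dirname))
        return found

    # list_subdirs("models/Lora")  # hypothetical root directory
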
@@ -157,8 +161,20 @@ class ExtraNetworksPage:
if metadata:
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"

+ local_path = ""
+ filename = item.get("filename", "")
+ for reldir in self.allowed_directories_for_previews():
+ absdir = os.path.abspath(reldir)
+
+ if filename.startswith(absdir):
+ local_path = filename[len(absdir):]
+
+ # if this is true, the item must not be show in the default view, and must instead only be
+ # shown when searching for it
+ serach_only = "/." in local_path or "\\." in local_path
+
args = {
- "style": f"'{height}{width}{background_image}'",
+ "style": f"'display: none; {height}{width}{background_image}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),

@@ -168,6 +184,7 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
+ "serach_only": " search_only" if serach_only else "",
}

return self.card_page.format(**args)

@@ -209,6 +226,11 @@ def intialize():
class ExtraNetworksUi:
def __init__(self):
self.pages = None
+ """gradio HTML components related to extra networks' pages"""
+
+ self.page_contents = None
+ """HTML content of the above; empty initially, filled when extra pages have to be shown"""
+
self.stored_extra_pages = None

self.button_save_preview = None

@@ -236,17 +258,22 @@ def pages_in_preferred_order(pages):
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
+ ui.pages_contents = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname

- with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+ with gr.Tabs(elem_id=tabname+"_extra_tabs"):
for page in ui.stored_extra_pages:
- with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
- page_elem = gr.HTML(page.create_html(ui.tabname))
+ page_id = page.title.lower().replace(" ", "_")
+
+ with gr.Tab(page.title, id=page_id):
+ elem_id = f"{tabname}_{page_id}_cards_html"
+ page_elem = gr.HTML('', elem_id=elem_id)
ui.pages.append(page_elem)

- filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
+ page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
+
+ gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")

ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)

@@ -254,19 +281,22 @@ def create_ui(container, button, tabname):

def toggle_visibility(is_visible):
is_visible = not is_visible
- return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
+
+ if is_visible and not ui.pages_contents:
+ refresh()
+
+ return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents

state_visible = gr.State(value=False)
- button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
+ button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])

def refresh():
- res = []
-
for pg in ui.stored_extra_pages:
pg.refresh()
- res.append(pg.create_html(ui.tabname))

- return res
+ ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
+
+ return ui.pages_contents

button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
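The reworked `toggle_visibility`/`refresh` pair above is a lazy-initialization pattern: the card HTML is no longer generated while the UI is being built, but the first time the panel is actually opened, and the cached result is reused on every later toggle. A stripped-down sketch of the pattern outside of gradio (`build_pages` stands in for the expensive `create_html` calls):

    pages_contents = []

    def build_pages():
        print("building page HTML...")  # the expensive step, runs only once
        return ["<div>cards for page 1</div>", "<div>cards for page 2</div>"]

    def show_panel():
        global pages_contents
        if not pages_contents:           # only on the first open
            pages_contents = build_pages()
        return pages_contents

    show_panel()  # builds the HTML
    show_panel()  # served from the cache
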
@@ -296,7 +326,7 @@ def setup_ui(ui, gallery):

is_allowed = False
for extra_page in ui.stored_extra_pages:
- if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+ if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):
is_allowed = True
break
208 modules/ui_loadsave.py Normal file

@@ -0,0 +1,208 @@
+ import json
+ import os
+
+ import gradio as gr
+
+ from modules import errors
+ from modules.ui_components import ToolButton
+
+
+ class UiLoadsave:
+     """allows saving and restorig default values for gradio components"""
+
+     def __init__(self, filename):
+         self.filename = filename
+         self.ui_settings = {}
+         self.component_mapping = {}
+         self.error_loading = False
+         self.finalized_ui = False
+
+         self.ui_defaults_view = None
+         self.ui_defaults_apply = None
+         self.ui_defaults_review = None
+
+         try:
+             if os.path.exists(self.filename):
+                 self.ui_settings = self.read_from_file()
+         except Exception as e:
+             self.error_loading = True
+             errors.display(e, "loading settings")
+
+     def add_component(self, path, x):
+         """adds component to the registry of tracked components"""
+
+         assert not self.finalized_ui
+
+         def apply_field(obj, field, condition=None, init_field=None):
+             key = f"{path}/{field}"
+
+             if getattr(obj, 'custom_script_source', None) is not None:
+                 key = f"customscript/{obj.custom_script_source}/{key}"
+
+             if getattr(obj, 'do_not_save_to_config', False):
+                 return
+
+             saved_value = self.ui_settings.get(key, None)
+             if saved_value is None:
+                 self.ui_settings[key] = getattr(obj, field)
+             elif condition and not condition(saved_value):
+                 pass
+             else:
+                 setattr(obj, field, saved_value)
+                 if init_field is not None:
+                     init_field(saved_value)
+
+             if field == 'value' and key not in self.component_mapping:
+                 self.component_mapping[key] = x
+
+         if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
+             apply_field(x, 'visible')
+
+         if type(x) == gr.Slider:
+             apply_field(x, 'value')
+             apply_field(x, 'minimum')
+             apply_field(x, 'maximum')
+             apply_field(x, 'step')
+
+         if type(x) == gr.Radio:
+             apply_field(x, 'value', lambda val: val in x.choices)
+
+         if type(x) == gr.Checkbox:
+             apply_field(x, 'value')
+
+         if type(x) == gr.Textbox:
+             apply_field(x, 'value')
+
+         if type(x) == gr.Number:
+             apply_field(x, 'value')
+
+         if type(x) == gr.Dropdown:
+             def check_dropdown(val):
+                 if getattr(x, 'multiselect', False):
+                     return all(value in x.choices for value in val)
+                 else:
+                     return val in x.choices
+
+             apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
+
+         def check_tab_id(tab_id):
+             tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
+             if type(tab_id) == str:
+                 tab_ids = [t.id for t in tab_items]
+                 return tab_id in tab_ids
+             elif type(tab_id) == int:
+                 return 0 <= tab_id < len(tab_items)
+             else:
+                 return False
+
+         if type(x) == gr.Tabs:
+             apply_field(x, 'selected', check_tab_id)
+
+     def add_block(self, x, path=""):
+         """adds all components inside a gradio block x to the registry of tracked components"""
+
+         if hasattr(x, 'children'):
+             if isinstance(x, gr.Tabs) and x.elem_id is not None:
+                 # Tabs element can't have a label, have to use elem_id instead
+                 self.add_component(f"{path}/Tabs@{x.elem_id}", x)
+             for c in x.children:
+                 self.add_block(c, path)
+         elif x.label is not None:
+             self.add_component(f"{path}/{x.label}", x)
+
+     def read_from_file(self):
+         with open(self.filename, "r", encoding="utf8") as file:
+             return json.load(file)
+
+     def write_to_file(self, current_ui_settings):
+         with open(self.filename, "w", encoding="utf8") as file:
+             json.dump(current_ui_settings, file, indent=4)
+
+     def dump_defaults(self):
+         """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""
+
+         if self.error_loading and os.path.exists(self.filename):
+             return
+
+         self.write_to_file(self.ui_settings)
+
+     def iter_changes(self, current_ui_settings, values):
+         """
+         given a dictionary with defaults from a file and current values from gradio elements, returns
+         an iterator over tuples of values that are not the same between the file and the current;
+         tuple contents are: path, old value, new value
+         """
+
+         for (path, component), new_value in zip(self.component_mapping.items(), values):
+             old_value = current_ui_settings.get(path)
+
+             choices = getattr(component, 'choices', None)
+             if isinstance(new_value, int) and choices:
+                 if new_value >= len(choices):
+                     continue
+
+                 new_value = choices[new_value]
+
+             if new_value == old_value:
+                 continue
+
+             if old_value is None and new_value == '' or new_value == []:
+                 continue
+
+             yield path, old_value, new_value
+
+     def ui_view(self, *values):
+         text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]
+
+         for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
+             if old_value is None:
+                 old_value = "<span class='ui-defaults-none'>None</span>"
+
+             text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")
+
+         if len(text) == 1:
+             text.append("<tr><td colspan=3>No changes</td></tr>")
+
+         text.append("</tbody>")
+         return "".join(text)
+
+     def ui_apply(self, *values):
+         num_changed = 0
+
+         current_ui_settings = self.read_from_file()
+
+         for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
+             num_changed += 1
+             current_ui_settings[path] = new_value
+
+         if num_changed == 0:
+             return "No changes."
+
+         self.write_to_file(current_ui_settings)
+
+         return f"Wrote {num_changed} changes."
+
+     def create_ui(self):
+         """creates ui elements for editing defaults UI, without adding any logic to them"""
+
+         gr.HTML(
+             f"This page allows you to change default values in UI elements on other tabs.<br />"
+             f"Make your changes, press 'View changes' to review the changed default values,<br />"
+             f"then press 'Apply' to write them to {self.filename}.<br />"
+             f"New defaults will apply after you restart the UI.<br />"
+         )
+
+         with gr.Row():
+             self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
+             self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")
+
+         self.ui_defaults_review = gr.HTML("")
+
+     def setup_ui(self):
+         """adds logic to elements created with create_ui; all add_block class must be made before this"""
+
+         assert not self.finalized_ui
+         self.finalized_ui = True
+
+         self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+         self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
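Read together with the ui.py hunks earlier in this diff, the intended life cycle of UiLoadsave is: construct it with the config path, build the UI, register the rendered blocks, create and wire up the Defaults tab, then persist defaults. A condensed sketch of that flow, not the exact wiring (the components are placeholders, and it assumes running inside the webui's module path):

    import gradio as gr
    from modules import ui_loadsave

    loadsave = ui_loadsave.UiLoadsave("ui-config.json")

    with gr.Blocks() as demo:
        gr.Slider(label="Example slider", minimum=0, maximum=10, value=5)
        loadsave.add_block(demo, "example")  # track every labeled component under demo
        loadsave.create_ui()                 # the widgets of the Defaults tab
        loadsave.setup_ui()                  # hook up the View changes / Apply buttons

    loadsave.dump_defaults()                 # write defaults unless loading failed earlier
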
@@ -1,5 +1,5 @@
import gradio as gr
- from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+ from modules import scripts, shared, ui_common, postprocessing, call_queue
import modules.generation_parameters_copypaste as parameters_copypaste
Some files were not shown because too many files have changed in this diff.