Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2024-12-29 19:05:05 +08:00)

Commit 74ff85a1a1: Merge branch 'dev' into npu_support

.github/workflows/run_tests.yaml (vendored, 10 lines changed)
@@ -20,6 +20,12 @@ jobs:
           cache-dependency-path: |
             **/requirements*txt
             launch.py
+      - name: Cache models
+        id: cache-models
+        uses: actions/cache@v3
+        with:
+          path: models
+          key: "2023-12-30"
       - name: Install test dependencies
         run: pip install wait-for-it -r requirements-test.txt
         env:
@@ -33,6 +39,8 @@ jobs:
          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
          WEBUI_LAUNCH_LIVE_OUTPUT: "1"
          PYTHONUNBUFFERED: "1"
+      - name: Print installed packages
+        run: pip freeze
       - name: Start test server
         run: >
           python -m coverage run
@@ -49,7 +57,7 @@ jobs:
          2>&1 | tee output.txt &
       - name: Run tests
         run: |
-          wait-for-it --service 127.0.0.1:7860 -t 600
+          wait-for-it --service 127.0.0.1:7860 -t 20
           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
       - name: Kill test server
         if: always()
.gitignore (vendored, 1 line changed)

@@ -37,3 +37,4 @@ notification.mp3
 /node_modules
 /package-lock.json
 /.coverage*
+/test/test_outputs
README.md (13 lines changed)

@@ -1,5 +1,5 @@
 # Stable Diffusion web UI
-A browser interface based on Gradio library for Stable Diffusion.
+A web interface for Stable Diffusion, implemented using Gradio library.
 
 ![](screenshot.png)
 
@@ -151,11 +151,12 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 
 - Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
-- GFPGAN - https://github.com/TencentARC/GFPGAN.git
-- CodeFormer - https://github.com/sczhou/CodeFormer
-- ESRGAN - https://github.com/xinntao/ESRGAN
-- SwinIR - https://github.com/JingyunLiang/SwinIR
-- Swin2SR - https://github.com/mv-lab/swin2sr
+- Spandrel - https://github.com/chaiNNer-org/spandrel implementing
+  - GFPGAN - https://github.com/TencentARC/GFPGAN.git
+  - CodeFormer - https://github.com/sczhou/CodeFormer
+  - ESRGAN - https://github.com/xinntao/ESRGAN
+  - SwinIR - https://github.com/JingyunLiang/SwinIR
+  - Swin2SR - https://github.com/mv-lab/swin2sr
 - LDSR - https://github.com/Hafiidz/latent-diffusion
 - MiDaS - https://github.com/isl-org/MiDaS
 - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
configs/sd_xl_inpaint.yaml (new file, 98 lines)

@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 9
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
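Aside (not part of the commit): a minimal sketch of how a config in this format is usually consumed, following the common target/params convention of the ldm/sgm codebases. The helper names below are illustrative assumptions, not code from this repository.

import importlib
from omegaconf import OmegaConf

def get_obj_from_str(name: str):
    # "sgm.models.diffusion.DiffusionEngine" -> the class object
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # build the object named by `target`, passing `params` as keyword arguments
    return get_obj_from_str(config["target"])(**config.get("params", {}))

config = OmegaConf.load("configs/sd_xl_inpaint.yaml")
# model = instantiate_from_config(config.model)  # requires the sgm package and SDXL inpaint weights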
extensions-builtin/Lora/network.py

@@ -3,6 +3,9 @@ import os
 from collections import namedtuple
 import enum
 
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule:
     def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
         if self.bias is not None:
             updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
             updown = updown.reshape(output_shape)
 
         if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
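Aside (not part of the commit): the ops/extra_kwargs table together with the general forward() above lets any supported layer type re-run its own functional op with the network's delta as the weight, added on top of the original output y. A toy sketch with a hypothetical Linear-backed module; the calc_updown body here is invented for illustration only.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyLinearNetworkModule:
    """Hypothetical stand-in for a NetworkModule wrapping an nn.Linear layer."""
    def __init__(self, sd_module: nn.Linear, up: torch.Tensor, down: torch.Tensor):
        self.sd_module = sd_module
        self.up, self.down = up, down
        self.ops, self.extra_kwargs = F.linear, {}   # same dispatch idea as above

    def calc_updown(self, orig_weight):
        # low-rank delta; returns (updown, ex_bias)
        return (self.up @ self.down).to(orig_weight.device), None

    def forward(self, x, y):
        # y is the original layer's output; add the correction computed with the same op
        updown, ex_bias = self.calc_updown(self.sd_module.weight)
        return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)

layer = nn.Linear(8, 8)
mod = ToyLinearNetworkModule(layer, torch.randn(8, 2), torch.randn(2, 8))
x = torch.randn(1, 8)
print(mod.forward(x, layer(x)).shape)  # torch.Size([1, 8])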
extensions-builtin/Lora/network_full.py

@@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         output_shape = self.weight.shape
-        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.weight.to(orig_weight.device)
         if self.ex_bias is not None:
-            ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.ex_bias.to(orig_weight.device)
         else:
             ex_bias = None
 
extensions-builtin/Lora/network_glora.py

@@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule):
         self.w2b = weights.w["b2.weight"]
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
-        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+        updown = ((w2b @ w1b) + ((orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a))
 
         return self.finalize_updown(updown, orig_weight, output_shape)
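Aside (not part of the commit): the calc_updown changes in this commit stop force-casting the network factors to the checkpoint dtype and instead cast orig_weight at the point where the tensors actually meet, since torch matrix multiplication requires both operands to share a dtype. A small illustration with made-up shapes:

import torch

orig_weight = torch.randn(8, 8, dtype=torch.float16)  # model weight, e.g. stored in fp16
w2a = torch.randn(8, 4)                                # network factors kept in fp32
w1a = torch.randn(4, 8)

# orig_weight @ w2a would fail with a dtype mismatch error (float16 vs float32),
# so the weight is cast once and the product is computed in the factors' dtype:
updown = (orig_weight.to(dtype=w1a.dtype) @ w2a) @ w1a
print(updown.dtype)  # torch.float32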
extensions-builtin/Lora/network_hada.py

@@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule):
         self.t2 = weights.w.get("hada_t2")
 
     def calc_updown(self, orig_weight):
-        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1a = self.w1a.to(orig_weight.device)
+        w1b = self.w1b.to(orig_weight.device)
+        w2a = self.w2a.to(orig_weight.device)
+        w2b = self.w2b.to(orig_weight.device)
 
         output_shape = [w1a.size(0), w1b.size(1)]
 
         if self.t1 is not None:
             output_shape = [w1a.size(1), w1b.size(1)]
-            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+            t1 = self.t1.to(orig_weight.device)
             updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
             output_shape += t1.shape[2:]
         else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule):
             updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
 
         if self.t2 is not None:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
             updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
         else:
             updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
extensions-builtin/Lora/network_ia3.py

@@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule):
         self.on_input = weights.w["on_input"].item()
 
     def calc_updown(self, orig_weight):
-        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+        w = self.w.to(orig_weight.device)
 
         output_shape = [w.size(0), orig_weight.size(1)]
         if self.on_input:
extensions-builtin/Lora/network_lokr.py

@@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         if self.w1 is not None:
-            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1 = self.w1.to(orig_weight.device)
         else:
-            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w1a = self.w1a.to(orig_weight.device)
+            w1b = self.w1b.to(orig_weight.device)
             w1 = w1a @ w1b
 
         if self.w2 is not None:
-            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2 = self.w2.to(orig_weight.device)
         elif self.t2 is None:
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = w2a @ w2b
         else:
-            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
-            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+            t2 = self.t2.to(orig_weight.device)
+            w2a = self.w2a.to(orig_weight.device)
+            w2b = self.w2b.to(orig_weight.device)
             w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
 
         output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
extensions-builtin/Lora/network_lora.py

@@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule):
         return module
 
     def calc_updown(self, orig_weight):
-        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+        up = self.up_model.weight.to(orig_weight.device)
+        down = self.down_model.weight.to(orig_weight.device)
 
         output_shape = [up.size(0), down.size(1)]
         if self.mid_model is not None:
             # cp-decomposition
-            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+            mid = self.mid_model.weight.to(orig_weight.device)
             updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
             output_shape += mid.shape[2:]
         else:
extensions-builtin/Lora/network_norm.py

@@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         output_shape = self.w_norm.shape
-        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        updown = self.w_norm.to(orig_weight.device)
 
         if self.b_norm is not None:
-            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+            ex_bias = self.b_norm.to(orig_weight.device)
         else:
             ex_bias = None
 
extensions-builtin/Lora/network_oft.py

@@ -56,17 +56,17 @@ class NetworkModuleOFT(network.NetworkModule):
         self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
 
     def calc_updown(self, orig_weight):
-        oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        eye = torch.eye(self.block_size, device=self.oft_blocks.device)
+        oft_blocks = self.oft_blocks.to(orig_weight.device)
+        eye = torch.eye(self.block_size, device=oft_blocks.device)
 
         if self.is_kohya:
             block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
             norm_Q = torch.norm(block_Q.flatten())
-            new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+            new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
             block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())
 
-        R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+        R = oft_blocks.to(orig_weight.device)
 
         # This errors out for MultiheadAttention, might need to be handled up-stream
         merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
@@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule):
         )
         merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
 
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
extensions-builtin/Lora/networks.py

@@ -1,3 +1,4 @@
+import gradio as gr
 import logging
 import os
 import re
@@ -259,11 +260,11 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
 
     loaded_networks.clear()
 
-    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
     if any(x is None for x in networks_on_disk):
         list_available_networks()
 
-        networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
 
     failed_to_load_networks = []
 
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
             emb_db.skipped_embeddings[name] = embedding
 
     if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+        sd_hijack.model_hijack.comments.append(lora_not_found_message)
+        if shared.opts.lora_not_found_warning_console:
+            print(f'\n{lora_not_found_message}\n')
+        if shared.opts.lora_not_found_gradio_warning:
+            gr.Warning(lora_not_found_message)
 
     purge_networks_from_memory()
 
|
|||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
updown, ex_bias = module.calc_updown(self.weight)
|
if getattr(self, 'fp16_weight', None) is None:
|
||||||
|
weight = self.weight
|
||||||
|
bias = self.bias
|
||||||
|
else:
|
||||||
|
weight = self.fp16_weight.clone().to(self.weight.device)
|
||||||
|
bias = getattr(self, 'fp16_bias', None)
|
||||||
|
if bias is not None:
|
||||||
|
bias = bias.clone().to(self.bias.device)
|
||||||
|
updown, ex_bias = module.calc_updown(weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(weight.shape) == 4 and weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
|
||||||
if ex_bias is not None and hasattr(self, 'bias'):
|
if ex_bias is not None and hasattr(self, 'bias'):
|
||||||
if self.bias is None:
|
if self.bias is None:
|
||||||
self.bias = torch.nn.Parameter(ex_bias)
|
self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
|
||||||
else:
|
else:
|
||||||
self.bias += ex_bias
|
self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
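Aside (not part of the commit): the fp16_weight branch computes the update against a higher-precision backup of the weight and writes the result back with copy_(), so the live parameter keeps its storage dtype. A rough standalone illustration of why a plain in-place add is not enough when the delta's dtype differs from the weight's:

import torch

# stand-ins: a weight stored in half precision and a LoRA delta computed in fp32
weight = torch.zeros(4, dtype=torch.float16)
updown = torch.randn(4, dtype=torch.float32)

try:
    weight += updown  # in-place ops may refuse to silently downcast the fp32 result to fp16
except RuntimeError as e:
    print("in-place add refused:", e)

# the pattern used above: compute in the delta's dtype, write back in the weight's own dtype
weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=weight.dtype))
print(weight.dtype)  # torch.float16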
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
 
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
extensions-builtin/Lora/scripts/lora_script.py

@@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
     "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
     "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
     "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+    "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+    "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
 }))
extensions-builtin/Lora/ui_edit_user_metadata.py

@@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         self.slider_preferred_weight = None
         self.edit_notes = None
 
-    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
         user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
+        user_metadata["negative text"] = negative_text
         user_metadata["notes"] = notes
 
         self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
+            user_metadata.get('negative text', ''),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.taginfo = gr.HighlightedText(label="Training dataset tags")
             self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
             self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
+            self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
             with gr.Row() as row_random_prompt:
                 with gr.Column(scale=8):
                     random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             row_random_prompt,
             random_prompt,
         ]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
+            self.edit_negative_text,
             self.edit_notes,
         ]
 
 
         self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -24,13 +24,16 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
 
         alias = lora_on_disk.get_alias()
 
+        search_terms = [self.search_terms_from_path(lora_on_disk.filename)]
+        if lora_on_disk.hash:
+            search_terms.append(lora_on_disk.hash)
         item = {
             "name": name,
             "filename": lora_on_disk.filename,
             "shorthash": lora_on_disk.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
+            "search_terms": search_terms,
             "local_preview": f"{path}.{shared.opts.samples_format}",
             "metadata": lora_on_disk.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -45,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)
 
+        negative_prompt = item["user_metadata"].get("negative text")
+        item["negative_prompt"] = quote_js("")
+        if negative_prompt:
+            item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
         sd_version = item["user_metadata"].get("sd version")
         if sd_version in network.SdVersion.__members__:
             item["sd_version"] = sd_version
extensions-builtin/ScuNET/scripts/scunet_model.py

@@ -1,16 +1,9 @@
 import sys
 
 import PIL.Image
-import numpy as np
-import torch
-from tqdm import tqdm
 
 import modules.upscaler
-from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet
-
-from modules.modelloader import load_file_from_url
-from modules.shared import opts
+from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils
 
 
 class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         scalers.append(scaler_data2)
         self.scalers = scalers
 
-    @staticmethod
-    @torch.no_grad()
-    def tiled_inference(img, model):
-        # test the image tile by tile
-        h, w = img.shape[2:]
-        tile = opts.SCUNET_tile
-        tile_overlap = opts.SCUNET_tile_overlap
-        if tile == 0:
-            return model(img)
-
-        device = devices.get_device_for('scunet')
-        assert tile % 8 == 0, "tile size should be a multiple of window_size"
-        sf = 1
-
-        stride = tile - tile_overlap
-        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-        E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
-        W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
-        with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
-            for h_idx in h_idx_list:
-                for w_idx in w_idx_list:
-                    in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                    out_patch = model(in_patch)
-                    out_patch_mask = torch.ones_like(out_patch)
-                    E[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch)
-                    W[
-                        ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                    ].add_(out_patch_mask)
-                    pbar.update(1)
-        output = E.div_(W)
-
-        return output
-
     def do_upscale(self, img: PIL.Image.Image, selected_file):
-
         devices.torch_gc()
-
         try:
             model = self.load_model(selected_file)
         except Exception as e:
             print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
 
-        device = devices.get_device_for('scunet')
-        tile = opts.SCUNET_tile
-        h, w = img.height, img.width
-        np_img = np.array(img)
-        np_img = np_img[:, :, ::-1]  # RGB to BGR
-        np_img = np_img.transpose((2, 0, 1)) / 255  # HWC to CHW
-        torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device)  # type: ignore
-
-        if tile > h or tile > w:
-            _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
-            _img[:, :, :h, :w] = torch_img  # pad image
-            torch_img = _img
-
-        torch_output = self.tiled_inference(torch_img, model).squeeze(0)
-        torch_output = torch_output[:, :h * 1, :w * 1]  # remove padding, if any
-        np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
-        del torch_img, torch_output
+        img = upscaler_utils.upscale_2(
+            img,
+            model,
+            tile_size=shared.opts.SCUNET_tile,
+            tile_overlap=shared.opts.SCUNET_tile_overlap,
+            scale=1,  # ScuNET is a denoising model, not an upscaler
+            desc='ScuNET',
+        )
         devices.torch_gc()
-
-        output = np_output.transpose((1, 2, 0))  # CHW to HWC
-        output = output[:, :, ::-1]  # BGR to RGB
-        return PIL.Image.fromarray((output * 255).astype(np.uint8))
+        return img
 
     def load_model(self, path: str):
         device = devices.get_device_for('scunet')
         if path.startswith("http"):
             # TODO: this doesn't use `path` at all?
-            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
+            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
-        model.load_state_dict(torch.load(filename), strict=True)
-        model.eval()
-        for _, v in model.named_parameters():
-            v.requires_grad = False
-        model = model.to(device)
-
-        return model
+        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
 
 
 def on_ui_settings():
     import gradio as gr
-    from modules import shared
 
     shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
     shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
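Aside (not part of the commit): the removed tiled_inference shows the tile-and-blend approach that the shared upscaler_utils.upscale_2 helper is now expected to provide centrally. A stripped-down version of the same idea, with illustrative tile sizes and no ScuNET specifics:

import torch

def tiled_apply(img: torch.Tensor, model, tile: int = 256, overlap: int = 8) -> torch.Tensor:
    """Run `model` over overlapping tiles of a (1, C, H, W) tensor and average the overlaps."""
    _, _, h, w = img.shape
    if tile == 0 or (tile >= h and tile >= w):
        return model(img)
    stride = tile - overlap
    out = torch.zeros_like(img)
    weight = torch.zeros_like(img)
    h_starts = list(range(0, h - tile, stride)) + [h - tile]
    w_starts = list(range(0, w - tile, stride)) + [w - tile]
    for hs in h_starts:
        for ws in w_starts:
            patch = img[..., hs:hs + tile, ws:ws + tile]
            out[..., hs:hs + tile, ws:ws + tile] += model(patch)
            weight[..., hs:hs + tile, ws:ws + tile] += 1
    return out / weight

# usage sketch: with an identity "model" the blended result equals the input
x = torch.rand(1, 3, 300, 300)
assert torch.allclose(tiled_apply(x, lambda t: t, tile=128, overlap=16), x)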
extensions-builtin/ScuNET/scunet_model_arch.py (file deleted, 268 lines)

@@ -1,268 +0,0 @@
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath


class WMSA(nn.Module):
    """ Self-attention module in Swin Transformer
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim ** -0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)

        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))

        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=.02)
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, 2).transpose(0, 1))

    def generate_mask(self, h, w, p, shift):
        """ generating the mask of SW-MSA
        Args:
            shift: shift parameters in CyclicShift.
        Returns:
            attn_mask: should be (1 1 w p p),
        """
        # supporting square.
        attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
        if self.type == 'W':
            return attn_mask

        s = p - shift
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
        return attn_mask

    def forward(self, x):
        """ Forward pass of Window Multi-head Self-attention module.
        Args:
            x: input tensor with shape of [b h w c];
            attn_mask: attention mask, fill -inf where the value is True;
        Returns:
            output: tensor shape [b h w c]
        """
        if self.type != 'W':
            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))

        x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
        h_windows = x.size(1)
        w_windows = x.size(2)
        # square validation
        # assert h_windows == w_windows

        x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        qkv = self.embedding_layer(x)
        q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
        sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
        # Using Attn Mask to distinguish different subwindows.
        if self.type != 'W':
            attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
        output = rearrange(output, 'h b w p c -> b w p (h c)')
        output = self.linear(output)
        output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)

        if self.type != 'W':
            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))

        return output

    def relative_embedding(self):
        cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        # negative is allowed
        return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]


class Block(nn.Module):
    def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer Block
        """
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ['W', 'SW']
        self.type = type
        if input_resolution <= window_size:
            self.type = 'W'

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x


class ConvTransBlock(nn.Module):
    def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer and Conv Block
        """
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ['W', 'SW']
        if self.input_resolution <= self.window_size:
            self.type = 'W'

        self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
                                 self.type, self.input_resolution)
        self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
        )

    def forward(self, x):
        conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
        conv_x = self.conv_block(conv_x) + conv_x
        trans_x = Rearrange('b c h w -> b h w c')(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange('b h w c -> b c h w')(trans_x)
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        x = x + res

        return x


class SCUNet(nn.Module):
    # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
    def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
        super(SCUNet, self).__init__()
        if config is None:
            config = [2, 2, 2, 2, 2, 2, 2]
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        begin = 0
        self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution)
                        for i in range(config[0])] + \
                       [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 2)
                        for i in range(config[1])] + \
                       [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 4)
                        for i in range(config[2])] + \
                       [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]

        begin += config[2]
        self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                      'W' if not i % 2 else 'SW', input_resolution // 8)
                       for i in range(config[3])]

        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 4)
                      for i in range(config[4])]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 2)
                      for i in range(config[5])]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution)
                      for i in range(config[6])]

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        # self.apply(self._init_weights)

    def forward(self, x0):

        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h / 64) * 64 - h)
        paddingRight = int(np.ceil(w / 64) * 64 - w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        x = x[..., :h, :w]

        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
@ -1,20 +1,15 @@
|
|||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import platform
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import devices, modelloader, script_callbacks, shared, upscaler_utils
|
||||||
from modules.shared import opts, state
|
|
||||||
from swinir_model_arch import SwinIR
|
|
||||||
from swinir_model_arch_v2 import Swin2SR
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||||
|
|
||||||
device_swinir = devices.get_device_for('swinir')
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UpscalerSwinIR(Upscaler):
|
class UpscalerSwinIR(Upscaler):
|
||||||
@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
scalers.append(model_data)
|
scalers.append(model_data)
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img, model_file):
|
def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
|
||||||
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
|
current_config = (model_file, shared.opts.SWIN_tile)
|
||||||
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
|
|
||||||
current_config = (model_file, opts.SWIN_tile)
|
|
||||||
|
|
||||||
if use_compile and self._cached_model_config == current_config:
|
if self._cached_model_config == current_config:
|
||||||
model = self._cached_model
|
model = self._cached_model
|
||||||
else:
|
else:
|
||||||
self._cached_model = None
|
|
||||||
try:
|
try:
|
||||||
model = self.load_model(model_file)
|
model = self.load_model(model_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
model = model.to(device_swinir, dtype=devices.dtype)
|
|
||||||
if use_compile:
|
|
||||||
model = torch.compile(model)
|
|
||||||
self._cached_model = model
|
self._cached_model = model
|
||||||
self._cached_model_config = current_config
|
self._cached_model_config = current_config
|
||||||
img = upscale(img, model)
|
|
||||||
|
img = upscaler_utils.upscale_2(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
tile_size=shared.opts.SWIN_tile,
|
||||||
|
tile_overlap=shared.opts.SWIN_tile_overlap,
|
||||||
|
scale=model.scale,
|
||||||
|
desc="SwinIR",
|
||||||
|
)
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler):
             )
         else:
             filename = path
-        if filename.endswith(".v2.pth"):
-            model = Swin2SR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6],
-                embed_dim=180,
-                num_heads=[6, 6, 6, 6, 6, 6],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="1conv",
-            )
-            params = None
-        else:
-            model = SwinIR(
-                upscale=scale,
-                in_chans=3,
-                img_size=64,
-                window_size=8,
-                img_range=1.0,
-                depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
-                embed_dim=240,
-                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
-                mlp_ratio=2,
-                upsampler="nearest+conv",
-                resi_connection="3conv",
-            )
-            params = "params_ema"
-
-        pretrained_model = torch.load(filename)
-        if params is not None:
-            model.load_state_dict(pretrained_model[params], strict=True)
-        else:
-            model.load_state_dict(pretrained_model, strict=True)
-        return model
+        model_descriptor = modelloader.load_spandrel_model(
+            filename,
+            device=self._get_device(),
+            prefer_half=(devices.dtype == torch.float16),
+            expected_architecture="SwinIR",
+        )
+        if getattr(shared.opts, 'SWIN_torch_compile', False):
+            try:
+                model_descriptor.model.compile()
+            except Exception:
+                logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True)
+        return model_descriptor
+
+    def _get_device(self):
+        return devices.get_device_for('swinir')
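Aside (not part of the diff): the compile step in load_model is opt-in and guarded, so a failing torch.compile leaves the eager module in use. A hedged sketch of that guard in isolation, assuming a PyTorch 2.x nn.Module (names are illustrative):

import logging

import torch.nn as nn

log = logging.getLogger(__name__)


def maybe_compile(module: nn.Module, enabled: bool) -> nn.Module:
    """Compile the module in place when requested, keeping the eager module on failure."""
    if not enabled:
        return module
    try:
        module.compile()  # in-place compilation, available on nn.Module in recent PyTorch releases
    except Exception:
        log.warning("torch.compile failed, keeping the eager module", exc_info=True)
    return module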
-
-
-def upscale(
-    img,
-    model,
-    tile=None,
-    tile_overlap=None,
-    window_size=8,
-    scale=4,
-):
-    tile = tile or opts.SWIN_tile
-    tile_overlap = tile_overlap or opts.SWIN_tile_overlap
-
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.moveaxis(img, 2, 0) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
-    with torch.no_grad(), devices.autocast():
-        _, _, h_old, w_old = img.size()
-        h_pad = (h_old // window_size + 1) * window_size - h_old
-        w_pad = (w_old // window_size + 1) * window_size - w_old
-        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
-        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
-        output = inference(img, model, tile, tile_overlap, window_size, scale)
-        output = output[..., : h_old * scale, : w_old * scale]
-        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
-        if output.ndim == 3:
-            output = np.transpose(
-                output[[2, 1, 0], :, :], (1, 2, 0)
-            )  # CHW-RGB to HCW-BGR
-        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
-        return Image.fromarray(output, "RGB")
-
-
-def inference(img, model, tile, tile_overlap, window_size, scale):
-    # test the image tile by tile
-    b, c, h, w = img.size()
-    tile = min(tile, h, w)
-    assert tile % window_size == 0, "tile size should be a multiple of window_size"
-    sf = scale
-
-    stride = tile - tile_overlap
-    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
-    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
-    E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
-    W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
-
-    with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
-        for h_idx in h_idx_list:
-            if state.interrupted or state.skipped:
-                break
-
-            for w_idx in w_idx_list:
-                if state.interrupted or state.skipped:
-                    break
-
-                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-                out_patch = model(in_patch)
-                out_patch_mask = torch.ones_like(out_patch)
-
-                E[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch)
-                W[
-                    ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
-                ].add_(out_patch_mask)
-                pbar.update(1)
-    output = E.div_(W)
-
-    return output
-
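Aside (not part of the diff): the removed inference() implemented overlap-averaged tiling; every tile's output is accumulated into a buffer E while a weight map W counts how often each pixel was covered, and E / W averages the overlapping regions so no seams remain. A compact NumPy sketch of the same idea for a 2-D array (scale factor omitted, names illustrative):

import numpy as np


def tiled_apply(img, fn, tile, overlap):
    """Apply fn to overlapping tiles and average the overlaps (assumes tile <= both image sides)."""
    h, w = img.shape
    stride = tile - overlap
    h_idx = list(range(0, h - tile, stride)) + [h - tile]
    w_idx = list(range(0, w - tile, stride)) + [w - tile]
    out = np.zeros((h, w), dtype=np.float64)     # E: accumulated tile outputs
    weight = np.zeros((h, w), dtype=np.float64)  # W: how many tiles covered each pixel
    for i in h_idx:
        for j in w_idx:
            out[i:i + tile, j:j + tile] += fn(img[i:i + tile, j:j + tile])
            weight[i:i + tile, j:j + tile] += 1.0
    return out / weight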

 def on_ui_settings():
@ -185,7 +89,6 @@ def on_ui_settings():

     shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
     shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
-    if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
-        shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
+    shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))

@ -1,867 +0,0 @@
|
|||||||
# -----------------------------------------------------------------------------------
|
|
||||||
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
|
||||||
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
|
||||||
# -----------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.utils.checkpoint as checkpoint
|
|
||||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
|
||||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
|
||||||
super().__init__()
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = hidden_features or in_features
|
|
||||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
|
||||||
self.act = act_layer()
|
|
||||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
|
||||||
self.drop = nn.Dropout(drop)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def window_partition(x, window_size):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: (B, H, W, C)
|
|
||||||
window_size (int): window size
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
windows: (num_windows*B, window_size, window_size, C)
|
|
||||||
"""
|
|
||||||
B, H, W, C = x.shape
|
|
||||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
|
||||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
||||||
return windows
|
|
||||||
|
|
||||||
|
|
||||||
def window_reverse(windows, window_size, H, W):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
windows: (num_windows*B, window_size, window_size, C)
|
|
||||||
window_size (int): Window size
|
|
||||||
H (int): Height of image
|
|
||||||
W (int): Width of image
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
x: (B, H, W, C)
|
|
||||||
"""
|
|
||||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
|
||||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
|
||||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class WindowAttention(nn.Module):
|
|
||||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
|
||||||
It supports both of shifted and non-shifted window.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Number of input channels.
|
|
||||||
window_size (tuple[int]): The height and width of the window.
|
|
||||||
num_heads (int): Number of attention heads.
|
|
||||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
||||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
|
||||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
|
||||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.window_size = window_size # Wh, Ww
|
|
||||||
self.num_heads = num_heads
|
|
||||||
head_dim = dim // num_heads
|
|
||||||
self.scale = qk_scale or head_dim ** -0.5
|
|
||||||
|
|
||||||
# define a parameter table of relative position bias
|
|
||||||
self.relative_position_bias_table = nn.Parameter(
|
|
||||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
|
||||||
|
|
||||||
# get pair-wise relative position index for each token inside the window
|
|
||||||
coords_h = torch.arange(self.window_size[0])
|
|
||||||
coords_w = torch.arange(self.window_size[1])
|
|
||||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
|
||||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
||||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
||||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
|
||||||
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
|
||||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
|
||||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
|
||||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
|
||||||
self.register_buffer("relative_position_index", relative_position_index)
|
|
||||||
|
|
||||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop)
|
|
||||||
self.proj = nn.Linear(dim, dim)
|
|
||||||
|
|
||||||
self.proj_drop = nn.Dropout(proj_drop)
|
|
||||||
|
|
||||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
|
||||||
self.softmax = nn.Softmax(dim=-1)
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: input features with shape of (num_windows*B, N, C)
|
|
||||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
|
||||||
"""
|
|
||||||
B_, N, C = x.shape
|
|
||||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
||||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
|
||||||
|
|
||||||
q = q * self.scale
|
|
||||||
attn = (q @ k.transpose(-2, -1))
|
|
||||||
|
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
|
||||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
|
||||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
|
||||||
attn = attn + relative_position_bias.unsqueeze(0)
|
|
||||||
|
|
||||||
if mask is not None:
|
|
||||||
nW = mask.shape[0]
|
|
||||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
|
||||||
attn = attn.view(-1, self.num_heads, N, N)
|
|
||||||
attn = self.softmax(attn)
|
|
||||||
else:
|
|
||||||
attn = self.softmax(attn)
|
|
||||||
|
|
||||||
attn = self.attn_drop(attn)
|
|
||||||
|
|
||||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
|
||||||
x = self.proj(x)
|
|
||||||
x = self.proj_drop(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
|
||||||
|
|
||||||
def flops(self, N):
|
|
||||||
# calculate flops for 1 window with token length of N
|
|
||||||
flops = 0
|
|
||||||
# qkv = self.qkv(x)
|
|
||||||
flops += N * self.dim * 3 * self.dim
|
|
||||||
# attn = (q @ k.transpose(-2, -1))
|
|
||||||
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
|
||||||
# x = (attn @ v)
|
|
||||||
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
|
||||||
# x = self.proj(x)
|
|
||||||
flops += N * self.dim * self.dim
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class SwinTransformerBlock(nn.Module):
|
|
||||||
r""" Swin Transformer Block.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Number of input channels.
|
|
||||||
input_resolution (tuple[int]): Input resolution.
|
|
||||||
num_heads (int): Number of attention heads.
|
|
||||||
window_size (int): Window size.
|
|
||||||
shift_size (int): Shift size for SW-MSA.
|
|
||||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
||||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
||||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
|
||||||
drop (float, optional): Dropout rate. Default: 0.0
|
|
||||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
|
||||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
|
||||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
|
||||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
|
||||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.window_size = window_size
|
|
||||||
self.shift_size = shift_size
|
|
||||||
self.mlp_ratio = mlp_ratio
|
|
||||||
if min(self.input_resolution) <= self.window_size:
|
|
||||||
# if window size is larger than input resolution, we don't partition windows
|
|
||||||
self.shift_size = 0
|
|
||||||
self.window_size = min(self.input_resolution)
|
|
||||||
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
|
||||||
|
|
||||||
self.norm1 = norm_layer(dim)
|
|
||||||
self.attn = WindowAttention(
|
|
||||||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
|
||||||
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
|
||||||
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
||||||
self.norm2 = norm_layer(dim)
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
if self.shift_size > 0:
|
|
||||||
attn_mask = self.calculate_mask(self.input_resolution)
|
|
||||||
else:
|
|
||||||
attn_mask = None
|
|
||||||
|
|
||||||
self.register_buffer("attn_mask", attn_mask)
|
|
||||||
|
|
||||||
def calculate_mask(self, x_size):
|
|
||||||
# calculate attention mask for SW-MSA
|
|
||||||
H, W = x_size
|
|
||||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
|
||||||
h_slices = (slice(0, -self.window_size),
|
|
||||||
slice(-self.window_size, -self.shift_size),
|
|
||||||
slice(-self.shift_size, None))
|
|
||||||
w_slices = (slice(0, -self.window_size),
|
|
||||||
slice(-self.window_size, -self.shift_size),
|
|
||||||
slice(-self.shift_size, None))
|
|
||||||
cnt = 0
|
|
||||||
for h in h_slices:
|
|
||||||
for w in w_slices:
|
|
||||||
img_mask[:, h, w, :] = cnt
|
|
||||||
cnt += 1
|
|
||||||
|
|
||||||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
|
||||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
|
||||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
|
||||||
|
|
||||||
return attn_mask
|
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
|
||||||
H, W = x_size
|
|
||||||
B, L, C = x.shape
|
|
||||||
# assert L == H * W, "input feature has wrong size"
|
|
||||||
|
|
||||||
shortcut = x
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = x.view(B, H, W, C)
|
|
||||||
|
|
||||||
# cyclic shift
|
|
||||||
if self.shift_size > 0:
|
|
||||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
|
||||||
else:
|
|
||||||
shifted_x = x
|
|
||||||
|
|
||||||
# partition windows
|
|
||||||
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
|
||||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
|
||||||
|
|
||||||
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
|
||||||
if self.input_resolution == x_size:
|
|
||||||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
|
||||||
else:
|
|
||||||
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
|
||||||
|
|
||||||
# merge windows
|
|
||||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
|
||||||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
|
||||||
|
|
||||||
# reverse cyclic shift
|
|
||||||
if self.shift_size > 0:
|
|
||||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
|
||||||
else:
|
|
||||||
x = shifted_x
|
|
||||||
x = x.view(B, H * W, C)
|
|
||||||
|
|
||||||
# FFN
|
|
||||||
x = shortcut + self.drop_path(x)
|
|
||||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
|
||||||
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
H, W = self.input_resolution
|
|
||||||
# norm1
|
|
||||||
flops += self.dim * H * W
|
|
||||||
# W-MSA/SW-MSA
|
|
||||||
nW = H * W / self.window_size / self.window_size
|
|
||||||
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
|
||||||
# mlp
|
|
||||||
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
|
||||||
# norm2
|
|
||||||
flops += self.dim * H * W
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class PatchMerging(nn.Module):
|
|
||||||
r""" Patch Merging Layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_resolution (tuple[int]): Resolution of input feature.
|
|
||||||
dim (int): Number of input channels.
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
|
||||||
super().__init__()
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
self.dim = dim
|
|
||||||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
|
||||||
self.norm = norm_layer(4 * dim)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
x: B, H*W, C
|
|
||||||
"""
|
|
||||||
H, W = self.input_resolution
|
|
||||||
B, L, C = x.shape
|
|
||||||
assert L == H * W, "input feature has wrong size"
|
|
||||||
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
|
||||||
|
|
||||||
x = x.view(B, H, W, C)
|
|
||||||
|
|
||||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
|
||||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
|
||||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
|
||||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
|
||||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
|
||||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
|
||||||
|
|
||||||
x = self.norm(x)
|
|
||||||
x = self.reduction(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
H, W = self.input_resolution
|
|
||||||
flops = H * W * self.dim
|
|
||||||
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class BasicLayer(nn.Module):
|
|
||||||
""" A basic Swin Transformer layer for one stage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Number of input channels.
|
|
||||||
input_resolution (tuple[int]): Input resolution.
|
|
||||||
depth (int): Number of blocks.
|
|
||||||
num_heads (int): Number of attention heads.
|
|
||||||
window_size (int): Local window size.
|
|
||||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
||||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
||||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
|
||||||
drop (float, optional): Dropout rate. Default: 0.0
|
|
||||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
|
||||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
||||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
|
||||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
|
||||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
|
||||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
self.depth = depth
|
|
||||||
self.use_checkpoint = use_checkpoint
|
|
||||||
|
|
||||||
# build blocks
|
|
||||||
self.blocks = nn.ModuleList([
|
|
||||||
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
|
||||||
num_heads=num_heads, window_size=window_size,
|
|
||||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop, attn_drop=attn_drop,
|
|
||||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
|
||||||
norm_layer=norm_layer)
|
|
||||||
for i in range(depth)])
|
|
||||||
|
|
||||||
# patch merging layer
|
|
||||||
if downsample is not None:
|
|
||||||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
|
||||||
else:
|
|
||||||
self.downsample = None
|
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
|
||||||
for blk in self.blocks:
|
|
||||||
if self.use_checkpoint:
|
|
||||||
x = checkpoint.checkpoint(blk, x, x_size)
|
|
||||||
else:
|
|
||||||
x = blk(x, x_size)
|
|
||||||
if self.downsample is not None:
|
|
||||||
x = self.downsample(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
for blk in self.blocks:
|
|
||||||
flops += blk.flops()
|
|
||||||
if self.downsample is not None:
|
|
||||||
flops += self.downsample.flops()
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class RSTB(nn.Module):
|
|
||||||
"""Residual Swin Transformer Block (RSTB).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dim (int): Number of input channels.
|
|
||||||
input_resolution (tuple[int]): Input resolution.
|
|
||||||
depth (int): Number of blocks.
|
|
||||||
num_heads (int): Number of attention heads.
|
|
||||||
window_size (int): Local window size.
|
|
||||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
||||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
||||||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
|
||||||
drop (float, optional): Dropout rate. Default: 0.0
|
|
||||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
|
||||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
||||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
|
||||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
|
||||||
img_size: Input image size.
|
|
||||||
patch_size: Patch size.
|
|
||||||
resi_connection: The convolutional block before residual connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
|
||||||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
|
||||||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
|
||||||
img_size=224, patch_size=4, resi_connection='1conv'):
|
|
||||||
super(RSTB, self).__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
|
|
||||||
self.residual_group = BasicLayer(dim=dim,
|
|
||||||
input_resolution=input_resolution,
|
|
||||||
depth=depth,
|
|
||||||
num_heads=num_heads,
|
|
||||||
window_size=window_size,
|
|
||||||
mlp_ratio=mlp_ratio,
|
|
||||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop, attn_drop=attn_drop,
|
|
||||||
drop_path=drop_path,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
downsample=downsample,
|
|
||||||
use_checkpoint=use_checkpoint)
|
|
||||||
|
|
||||||
if resi_connection == '1conv':
|
|
||||||
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
|
||||||
elif resi_connection == '3conv':
|
|
||||||
# to save parameters and memory
|
|
||||||
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
|
||||||
|
|
||||||
self.patch_embed = PatchEmbed(
|
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
|
||||||
norm_layer=None)
|
|
||||||
|
|
||||||
self.patch_unembed = PatchUnEmbed(
|
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
|
||||||
norm_layer=None)
|
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
|
||||||
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
flops += self.residual_group.flops()
|
|
||||||
H, W = self.input_resolution
|
|
||||||
flops += H * W * self.dim * self.dim * 9
|
|
||||||
flops += self.patch_embed.flops()
|
|
||||||
flops += self.patch_unembed.flops()
|
|
||||||
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
|
||||||
r""" Image to Patch Embedding
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img_size (int): Image size. Default: 224.
|
|
||||||
patch_size (int): Patch token size. Default: 4.
|
|
||||||
in_chans (int): Number of input image channels. Default: 3.
|
|
||||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
|
||||||
super().__init__()
|
|
||||||
img_size = to_2tuple(img_size)
|
|
||||||
patch_size = to_2tuple(patch_size)
|
|
||||||
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
|
||||||
self.img_size = img_size
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.patches_resolution = patches_resolution
|
|
||||||
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
|
||||||
|
|
||||||
self.in_chans = in_chans
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
|
|
||||||
if norm_layer is not None:
|
|
||||||
self.norm = norm_layer(embed_dim)
|
|
||||||
else:
|
|
||||||
self.norm = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
|
||||||
if self.norm is not None:
|
|
||||||
x = self.norm(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
H, W = self.img_size
|
|
||||||
if self.norm is not None:
|
|
||||||
flops += H * W * self.embed_dim
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class PatchUnEmbed(nn.Module):
|
|
||||||
r""" Image to Patch Unembedding
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img_size (int): Image size. Default: 224.
|
|
||||||
patch_size (int): Patch token size. Default: 4.
|
|
||||||
in_chans (int): Number of input image channels. Default: 3.
|
|
||||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
|
||||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
|
||||||
super().__init__()
|
|
||||||
img_size = to_2tuple(img_size)
|
|
||||||
patch_size = to_2tuple(patch_size)
|
|
||||||
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
|
||||||
self.img_size = img_size
|
|
||||||
self.patch_size = patch_size
|
|
||||||
self.patches_resolution = patches_resolution
|
|
||||||
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
|
||||||
|
|
||||||
self.in_chans = in_chans
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
|
||||||
B, HW, C = x.shape
|
|
||||||
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
|
||||||
return x
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Sequential):
|
|
||||||
"""Upsample module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scale (int): Scale factor. Supported scales: 2^n and 3.
|
|
||||||
num_feat (int): Channel number of intermediate features.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, scale, num_feat):
|
|
||||||
m = []
|
|
||||||
if (scale & (scale - 1)) == 0: # scale = 2^n
|
|
||||||
for _ in range(int(math.log(scale, 2))):
|
|
||||||
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
|
||||||
m.append(nn.PixelShuffle(2))
|
|
||||||
elif scale == 3:
|
|
||||||
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
|
||||||
m.append(nn.PixelShuffle(3))
|
|
||||||
else:
|
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
|
||||||
super(Upsample, self).__init__(*m)
|
|
||||||
|
|
||||||
|
|
||||||
class UpsampleOneStep(nn.Sequential):
|
|
||||||
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
|
||||||
Used in lightweight SR to save parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scale (int): Scale factor. Supported scales: 2^n and 3.
|
|
||||||
num_feat (int): Channel number of intermediate features.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
|
||||||
self.num_feat = num_feat
|
|
||||||
self.input_resolution = input_resolution
|
|
||||||
m = []
|
|
||||||
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
|
||||||
m.append(nn.PixelShuffle(scale))
|
|
||||||
super(UpsampleOneStep, self).__init__(*m)
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
H, W = self.input_resolution
|
|
||||||
flops = H * W * self.num_feat * 3 * 9
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
class SwinIR(nn.Module):
|
|
||||||
r""" SwinIR
|
|
||||||
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
img_size (int | tuple(int)): Input image size. Default 64
|
|
||||||
patch_size (int | tuple(int)): Patch size. Default: 1
|
|
||||||
in_chans (int): Number of input image channels. Default: 3
|
|
||||||
embed_dim (int): Patch embedding dimension. Default: 96
|
|
||||||
depths (tuple(int)): Depth of each Swin Transformer layer.
|
|
||||||
num_heads (tuple(int)): Number of attention heads in different layers.
|
|
||||||
window_size (int): Window size. Default: 7
|
|
||||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
|
||||||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
|
||||||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
|
||||||
drop_rate (float): Dropout rate. Default: 0
|
|
||||||
attn_drop_rate (float): Attention dropout rate. Default: 0
|
|
||||||
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
|
||||||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
|
||||||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
|
||||||
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
|
||||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
|
||||||
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
|
||||||
img_range: Image range. 1. or 255.
|
|
||||||
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
|
||||||
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
|
||||||
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
|
||||||
**kwargs):
|
|
||||||
super(SwinIR, self).__init__()
|
|
||||||
num_in_ch = in_chans
|
|
||||||
num_out_ch = in_chans
|
|
||||||
num_feat = 64
|
|
||||||
self.img_range = img_range
|
|
||||||
if in_chans == 3:
|
|
||||||
rgb_mean = (0.4488, 0.4371, 0.4040)
|
|
||||||
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
|
||||||
else:
|
|
||||||
self.mean = torch.zeros(1, 1, 1, 1)
|
|
||||||
self.upscale = upscale
|
|
||||||
self.upsampler = upsampler
|
|
||||||
self.window_size = window_size
|
|
||||||
|
|
||||||
#####################################################################################################
|
|
||||||
################################### 1, shallow feature extraction ###################################
|
|
||||||
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
|
||||||
|
|
||||||
#####################################################################################################
|
|
||||||
################################### 2, deep feature extraction ######################################
|
|
||||||
self.num_layers = len(depths)
|
|
||||||
self.embed_dim = embed_dim
|
|
||||||
self.ape = ape
|
|
||||||
self.patch_norm = patch_norm
|
|
||||||
self.num_features = embed_dim
|
|
||||||
self.mlp_ratio = mlp_ratio
|
|
||||||
|
|
||||||
# split image into non-overlapping patches
|
|
||||||
self.patch_embed = PatchEmbed(
|
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
|
||||||
norm_layer=norm_layer if self.patch_norm else None)
|
|
||||||
num_patches = self.patch_embed.num_patches
|
|
||||||
patches_resolution = self.patch_embed.patches_resolution
|
|
||||||
self.patches_resolution = patches_resolution
|
|
||||||
|
|
||||||
# merge non-overlapping patches into image
|
|
||||||
self.patch_unembed = PatchUnEmbed(
|
|
||||||
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
|
||||||
norm_layer=norm_layer if self.patch_norm else None)
|
|
||||||
|
|
||||||
# absolute position embedding
|
|
||||||
if self.ape:
|
|
||||||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
|
||||||
trunc_normal_(self.absolute_pos_embed, std=.02)
|
|
||||||
|
|
||||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
||||||
|
|
||||||
# stochastic depth
|
|
||||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|
||||||
|
|
||||||
# build Residual Swin Transformer blocks (RSTB)
|
|
||||||
self.layers = nn.ModuleList()
|
|
||||||
for i_layer in range(self.num_layers):
|
|
||||||
layer = RSTB(dim=embed_dim,
|
|
||||||
input_resolution=(patches_resolution[0],
|
|
||||||
patches_resolution[1]),
|
|
||||||
depth=depths[i_layer],
|
|
||||||
num_heads=num_heads[i_layer],
|
|
||||||
window_size=window_size,
|
|
||||||
mlp_ratio=self.mlp_ratio,
|
|
||||||
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
downsample=None,
|
|
||||||
use_checkpoint=use_checkpoint,
|
|
||||||
img_size=img_size,
|
|
||||||
patch_size=patch_size,
|
|
||||||
resi_connection=resi_connection
|
|
||||||
|
|
||||||
)
|
|
||||||
self.layers.append(layer)
|
|
||||||
self.norm = norm_layer(self.num_features)
|
|
||||||
|
|
||||||
# build the last conv layer in deep feature extraction
|
|
||||||
if resi_connection == '1conv':
|
|
||||||
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
|
||||||
elif resi_connection == '3conv':
|
|
||||||
# to save parameters and memory
|
|
||||||
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
|
||||||
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
|
||||||
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
|
|
||||||
|
|
||||||
#####################################################################################################
|
|
||||||
################################ 3, high quality image reconstruction ################################
|
|
||||||
if self.upsampler == 'pixelshuffle':
|
|
||||||
# for classical SR
|
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
|
||||||
nn.LeakyReLU(inplace=True))
|
|
||||||
self.upsample = Upsample(upscale, num_feat)
|
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
|
||||||
# for lightweight SR (to save parameters)
|
|
||||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
|
||||||
(patches_resolution[0], patches_resolution[1]))
|
|
||||||
elif self.upsampler == 'nearest+conv':
|
|
||||||
# for real-world SR (less artifacts)
|
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
|
||||||
nn.LeakyReLU(inplace=True))
|
|
||||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
if self.upscale == 4:
|
|
||||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
|
||||||
else:
|
|
||||||
# for image denoising and JPEG compression artifact reduction
|
|
||||||
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _init_weights(self, m):
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def no_weight_decay(self):
|
|
||||||
return {'absolute_pos_embed'}
|
|
||||||
|
|
||||||
@torch.jit.ignore
|
|
||||||
def no_weight_decay_keywords(self):
|
|
||||||
return {'relative_position_bias_table'}
|
|
||||||
|
|
||||||
def check_image_size(self, x):
|
|
||||||
_, _, h, w = x.size()
|
|
||||||
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
|
||||||
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
|
||||||
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward_features(self, x):
|
|
||||||
x_size = (x.shape[2], x.shape[3])
|
|
||||||
x = self.patch_embed(x)
|
|
||||||
if self.ape:
|
|
||||||
x = x + self.absolute_pos_embed
|
|
||||||
x = self.pos_drop(x)
|
|
||||||
|
|
||||||
for layer in self.layers:
|
|
||||||
x = layer(x, x_size)
|
|
||||||
|
|
||||||
x = self.norm(x) # B L C
|
|
||||||
x = self.patch_unembed(x, x_size)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
H, W = x.shape[2:]
|
|
||||||
x = self.check_image_size(x)
|
|
||||||
|
|
||||||
self.mean = self.mean.type_as(x)
|
|
||||||
x = (x - self.mean) * self.img_range
|
|
||||||
|
|
||||||
if self.upsampler == 'pixelshuffle':
|
|
||||||
# for classical SR
|
|
||||||
x = self.conv_first(x)
|
|
||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
|
||||||
x = self.conv_before_upsample(x)
|
|
||||||
x = self.conv_last(self.upsample(x))
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
|
||||||
# for lightweight SR
|
|
||||||
x = self.conv_first(x)
|
|
||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
|
||||||
x = self.upsample(x)
|
|
||||||
elif self.upsampler == 'nearest+conv':
|
|
||||||
# for real-world SR
|
|
||||||
x = self.conv_first(x)
|
|
||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
|
||||||
x = self.conv_before_upsample(x)
|
|
||||||
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
|
||||||
if self.upscale == 4:
|
|
||||||
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
|
||||||
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
|
||||||
else:
|
|
||||||
# for image denoising and JPEG compression artifact reduction
|
|
||||||
x_first = self.conv_first(x)
|
|
||||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
|
||||||
x = x + self.conv_last(res)
|
|
||||||
|
|
||||||
x = x / self.img_range + self.mean
|
|
||||||
|
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
|
||||||
|
|
||||||
def flops(self):
|
|
||||||
flops = 0
|
|
||||||
H, W = self.patches_resolution
|
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
|
||||||
flops += self.patch_embed.flops()
|
|
||||||
for layer in self.layers:
|
|
||||||
flops += layer.flops()
|
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
|
||||||
flops += self.upsample.flops()
|
|
||||||
return flops
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
upscale = 4
|
|
||||||
window_size = 8
|
|
||||||
height = (1024 // upscale // window_size + 1) * window_size
|
|
||||||
width = (720 // upscale // window_size + 1) * window_size
|
|
||||||
model = SwinIR(upscale=2, img_size=(height, width),
|
|
||||||
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
|
|
||||||
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
|
|
||||||
print(model)
|
|
||||||
print(height, width, model.flops() / 1e9)
|
|
||||||
|
|
||||||
x = torch.randn((1, 3, height, width))
|
|
||||||
x = model(x)
|
|
||||||
print(x.shape)
|
File diff suppressed because it is too large
@ -218,6 +218,8 @@ onUiLoaded(async() => {
         canvas_hotkey_fullscreen: "KeyS",
         canvas_hotkey_move: "KeyF",
         canvas_hotkey_overlap: "KeyO",
+        canvas_hotkey_shrink_brush: "KeyQ",
+        canvas_hotkey_grow_brush: "KeyW",
         canvas_disabled_functions: [],
         canvas_show_tooltip: true,
         canvas_auto_expand: true,
@ -227,6 +229,8 @@ onUiLoaded(async() => {
     const functionMap = {
         "Zoom": "canvas_hotkey_zoom",
         "Adjust brush size": "canvas_hotkey_adjust",
+        "Hotkey shrink brush": "canvas_hotkey_shrink_brush",
+        "Hotkey enlarge brush": "canvas_hotkey_grow_brush",
         "Moving canvas": "canvas_hotkey_move",
         "Fullscreen": "canvas_hotkey_fullscreen",
         "Reset Zoom": "canvas_hotkey_reset",
@ -686,7 +690,9 @@ onUiLoaded(async() => {
         const hotkeyActions = {
             [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
             [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
-            [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen
+            [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,
+            [hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),
+            [hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)
         };

         const action = hotkeyActions[event.code];
@ -4,6 +4,8 @@ from modules import shared

 shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), {
     "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
     "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"),
+    "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"),
+    "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"),
     "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings"),
     "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "),
     "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
@ -11,5 +13,5 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
     "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
     "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
     "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
-    "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
+    "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
 }))
@ -1,7 +1,7 @@
 import math

 import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
 from modules.ui_components import FormColumn


@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script):
         extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
         elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")

-        mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+        mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}

         with gr.Blocks() as interface:
             with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
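Aside (not part of the diff): the mapping change above only swaps which module supplies the pairs; the inversion itself is unchanged. A tiny illustration of what the comprehension does, with made-up pairs:

pairs = [("sd_model_checkpoint", "Model hash"), ("CLIP_stop_at_last_layers", "Clip skip")]
mapping = {k: v for v, k in pairs}  # maps each pair's second element to its first
assert mapping == {"Model hash": "sd_model_checkpoint", "Clip skip": "CLIP_stop_at_last_layers"}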
||||||
|
747
extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
Normal file
747
extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
Normal file
@ -0,0 +1,747 @@
|
|||||||
|
import numpy as np
|
||||||
|
import gradio as gr
|
||||||
|
import math
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
import modules.scripts as scripts
|
||||||
|
|
||||||
|
|
||||||
|
class SoftInpaintingSettings:
|
||||||
|
def __init__(self,
|
||||||
|
mask_blend_power,
|
||||||
|
mask_blend_scale,
|
||||||
|
inpaint_detail_preservation,
|
||||||
|
composite_mask_influence,
|
||||||
|
composite_difference_threshold,
|
||||||
|
composite_difference_contrast):
|
||||||
|
self.mask_blend_power = mask_blend_power
|
||||||
|
self.mask_blend_scale = mask_blend_scale
|
||||||
|
self.inpaint_detail_preservation = inpaint_detail_preservation
|
||||||
|
self.composite_mask_influence = composite_mask_influence
|
||||||
|
self.composite_difference_threshold = composite_difference_threshold
|
||||||
|
self.composite_difference_contrast = composite_difference_contrast
|
||||||
|
|
||||||
|
def add_generation_params(self, dest):
|
||||||
|
dest[enabled_gen_param_label] = True
|
||||||
|
dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
|
||||||
|
dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
|
||||||
|
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
|
||||||
|
dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
|
||||||
|
dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
|
||||||
|
dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------- Methods -------------------
|
||||||
|
|
||||||
|
def processing_uses_inpainting(p):
|
||||||
|
# TODO: Figure out a better way to determine if inpainting is being used by p
|
||||||
|
if getattr(p, "image_mask", None) is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if getattr(p, "mask", None) is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if getattr(p, "nmask", None) is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def latent_blend(settings, a, b, t):
    """
    Interpolates two latent image representations according to the parameter t,
    where the interpolated vectors' magnitudes are also interpolated separately.
    The "detail_preservation" factor biases the magnitude interpolation towards
    the larger of the two magnitudes.
    """
    import torch

    # NOTE: We use inplace operations wherever possible.

    # [4][w][h] to [1][4][w][h]
    t2 = t.unsqueeze(0)
    # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
    t3 = t[0].unsqueeze(0).unsqueeze(0)

    one_minus_t2 = 1 - t2
    one_minus_t3 = 1 - t3

    # Linearly interpolate the image vectors.
    a_scaled = a * one_minus_t2
    b_scaled = b * t2
    image_interp = a_scaled
    image_interp.add_(b_scaled)
    result_type = image_interp.dtype
    del a_scaled, b_scaled, t2, one_minus_t2

    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
    # 64-bit operations are used here to allow large exponents.
    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        settings.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        settings.inpaint_detail_preservation) * t3
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, t3, one_minus_t3

    # Change the linearly interpolated image vectors' magnitudes to the value we want.
    # This is the last 64-bit operation.
    image_interp_scaling_factor = desired_magnitude
    image_interp_scaling_factor.div_(current_magnitude)
    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
    image_interp_scaled = image_interp
    image_interp_scaled.mul_(image_interp_scaling_factor)
    del current_magnitude
    del desired_magnitude
    del image_interp
    del image_interp_scaling_factor
    del result_type

    return image_interp_scaled


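# Illustrative sketch (editor-added, not part of the original script): how latent_blend
# is typically driven, with random tensors standing in for real latents. Shapes follow
# the [batch, 4, h, w] convention used by on_mask_blend below; the names here are
# assumptions for illustration only.
#
#   import torch
#   s = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
#   a = torch.randn(1, 4, 64, 64)        # original (init) latent
#   b = torch.randn(1, 4, 64, 64)        # current latent from the sampler
#   t = torch.rand(4, 64, 64)            # per-pixel blend weight towards b, in 0..1
#   blended = latent_blend(s, a, b, t)   # same shape as a and b

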
def get_modified_nmask(settings, nmask, sigma):
    """
    Converts a negative mask, representing the transparency of the original latent vectors being overlaid,
    to a mask that is scaled according to the denoising strength for this step.

    Where:
        0 = fully opaque, infinite density, fully masked
        1 = fully transparent, zero density, fully unmasked

    Raising this transparency to a power allows one to simulate N blending operations, where N can be
    any positive real value. This gives control over the balance of influence between the denoiser and
    the original latents according to the sigma value.

    NOTE: "mask" is not used
    """
    import torch
    return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)


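# Illustrative numbers (editor-added, not part of the original script): with the script's
# default settings (mask_blend_power=1, mask_blend_scale=0.5) and a half-transparent
# region (nmask = 0.5), the blend weight handed to latent_blend as `t` becomes:
#
#   sigma = 8   ->  0.5 ** (8 * 0.5)   = 0.5 ** 4    = 0.0625  (mostly the original latents are kept)
#   sigma = 0.5 ->  0.5 ** (0.5 * 0.5) = 0.5 ** 0.25 ~= 0.84   (mostly the sampler's current latent is kept)
#
# i.e. early, high-sigma steps preserve the original content strongly and late steps let
# the denoised result through; mask_blend_power shifts where that crossover happens.

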
def apply_adaptive_masks(
        settings: SoftInpaintingSettings,
        nmask,
        latent_orig,
        latent_processed,
        overlay_images,
        width, height,
        paste_to):
    import torch
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
    latent_mask = nmask[0].float()

    # Convert the original mask into a form we use to scale distances for thresholding.
    mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
    mask_scalar = (0.5 * (1 - settings.composite_mask_influence)
                   + mask_scalar * settings.composite_mask_influence)
    mask_scalar = mask_scalar / (1.00001 - mask_scalar)
    mask_scalar = mask_scalar.cpu().numpy()

    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)

    kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

    masks_for_overlay = []

    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
        converted_mask = distance_map.float().cpu().numpy()
        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                   percentile_min=0.9, percentile_max=1, min_width=1)
        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                   percentile_min=0.25, percentile_max=0.75, min_width=1)

        # The distance at which opacity of original decreases to 50%
        half_weighted_distance = settings.composite_difference_threshold * mask_scalar
        converted_mask = converted_mask / half_weighted_distance

        converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
        converted_mask = smootherstep(converted_mask)
        converted_mask = 1 - converted_mask
        converted_mask = 255. * converted_mask
        converted_mask = converted_mask.astype(np.uint8)
        converted_mask = Image.fromarray(converted_mask)
        converted_mask = images.resize_image(2, converted_mask, width, height)
        converted_mask = proc.create_binary_mask(converted_mask, round=False)

        # Remove aliasing artifacts using a gaussian blur.
        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

        # Expand the mask to fit the whole image if needed.
        if paste_to is not None:
            converted_mask = proc.uncrop(converted_mask,
                                         (overlay_image.width, overlay_image.height),
                                         paste_to)

        masks_for_overlay.append(converted_mask)

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


def apply_masks(
        settings,
        nmask,
        overlay_images,
        width, height,
        paste_to):
    import torch
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    converted_mask = nmask[0].float()
    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
    converted_mask = 255. * converted_mask
    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
    converted_mask = Image.fromarray(converted_mask)
    converted_mask = images.resize_image(2, converted_mask, width, height)
    converted_mask = proc.create_binary_mask(converted_mask, round=False)

    # Remove aliasing artifacts using a gaussian blur.
    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

    # Expand the mask to fit the whole image if needed.
    if paste_to is not None:
        converted_mask = proc.uncrop(converted_mask,
                                     (width, height),
                                     paste_to)

    masks_for_overlay = []

    for i, overlay_image in enumerate(overlay_images):
        # Every overlay image shares the same mask; append it rather than indexing
        # into the (initially empty) list.
        masks_for_overlay.append(converted_mask)

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):
    """
    Generalized convolution filter capable of applying
    weighted mean, median, maximum, and minimum filters
    parametrically using an arbitrary kernel.

    Args:
        img (nparray):
            The image, a 2-D array of floats, to which the filter is being applied.
        kernel (nparray):
            The kernel, a 2-D array of floats.
        kernel_center (nparray):
            The kernel center coordinate, a 1-D array with two elements.
        percentile_min (float):
            The lower bound of the histogram window used by the filter,
            from 0 to 1.
        percentile_max (float):
            The upper bound of the histogram window used by the filter,
            from 0 to 1.
        min_width (float):
            The minimum size of the histogram window bounds, in weight units.
            Must be greater than 0.

    Returns:
        (nparray): A filtered copy of the input image "img", a 2-D array of floats.
    """

    # Converts an index tuple into a vector.
    def vec(x):
        return np.array(x)

    kernel_min = -kernel_center
    kernel_max = vec(kernel.shape) - kernel_center

    def weighted_histogram_filter_single(idx):
        idx = vec(idx)
        min_index = np.maximum(0, idx + kernel_min)
        max_index = np.minimum(vec(img.shape), idx + kernel_max)
        window_shape = max_index - min_index

        class WeightedElement:
            """
            An element of the histogram, its weight
            and bounds.
            """

            def __init__(self, value, weight):
                self.value: float = value
                self.weight: float = weight
                self.window_min: float = 0.0
                self.window_max: float = 1.0

        # Collect the values in the image as WeightedElements,
        # weighted by their corresponding kernel values.
        values = []
        for window_tup in np.ndindex(tuple(window_shape)):
            window_index = vec(window_tup)
            image_index = window_index + min_index
            centered_kernel_index = image_index - idx
            kernel_index = centered_kernel_index + kernel_center
            element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])
            values.append(element)

        def sort_key(x: WeightedElement):
            return x.value

        values.sort(key=sort_key)

        # Calculate the height of the stack (sum)
        # and each sample's range they occupy in the stack
        sum = 0
        for i in range(len(values)):
            values[i].window_min = sum
            sum += values[i].weight
            values[i].window_max = sum

        # Calculate what range of this stack ("window")
        # we want to get the weighted average across.
        window_min = sum * percentile_min
        window_max = sum * percentile_max
        window_width = window_max - window_min

        # Ensure the window is within the stack and at least a certain size.
        if window_width < min_width:
            window_center = (window_min + window_max) / 2
            window_min = window_center - min_width / 2
            window_max = window_center + min_width / 2

            if window_max > sum:
                window_max = sum
                window_min = sum - min_width

            if window_min < 0:
                window_min = 0
                window_max = min_width

        value = 0
        value_weight = 0

        # Get the weighted average of all the samples
        # that overlap with the window, weighted
        # by the size of their overlap.
        for i in range(len(values)):
            if window_min >= values[i].window_max:
                continue
            if window_max <= values[i].window_min:
                break

            s = max(window_min, values[i].window_min)
            e = min(window_max, values[i].window_max)
            w = e - s

            value += values[i].value * w
            value_weight += w

        return value / value_weight if value_weight != 0 else 0

    img_out = img.copy()

    # Apply the kernel operation over each pixel.
    for index in np.ndindex(img.shape):
        img_out[index] = weighted_histogram_filter_single(index)

    return img_out


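# Illustrative check (editor-added, not part of the original script): with a uniform
# 3x3 kernel the filter above reduces to familiar cases; the percentile window picks
# between a weighted mean and a weighted median.
#
#   img = np.array([[0., 0., 0.],
#                   [0., 9., 0.],
#                   [0., 0., 0.]])
#   k = np.ones((3, 3))
#   center = np.array([1, 1])
#   mean_like = weighted_histogram_filter(img, k, center, percentile_min=0.0, percentile_max=1.0)
#   median_like = weighted_histogram_filter(img, k, center, percentile_min=0.5, percentile_max=0.5)
#   # mean_like[1, 1] == 1.0   (the outlier is averaged over the 9 samples)
#   # median_like[1, 1] == 0.0 (the outlier is rejected entirely)

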
def smoothstep(x):
    """
    The smoothstep function; input should be clamped to the 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * (3 - 2 * x)


def smootherstep(x):
    """
    The smootherstep function; input should be clamped to the 0-1 range.
    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.
    """
    return x * x * x * (x * (6 * x - 15) + 10)


def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
    """
    Creates a Gaussian kernel with thresholded edges.

    Args:
        stddev_radius (float):
            Standard deviation of the gaussian kernel, in pixels.
        max_radius (int):
            The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.
            The kernel is thresholded so that any values one pixel beyond this radius
            are weighted at 0.

    Returns:
        (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))
    """

    # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.
    def gaussian(sqr_mag):
        return math.exp(-sqr_mag / (stddev_radius * stddev_radius))

    # Helper function for converting a tuple to an array.
    def vec(x):
        return np.array(x)

    """
    Since a gaussian is unbounded, we need to limit ourselves
    to a finite range.
    We taper the ends off at the end of that range so they equal zero
    while preserving the maximum value of 1 at the mean.
    """
    zero_radius = max_radius + 1.0
    gauss_zero = gaussian(zero_radius * zero_radius)
    gauss_kernel_scale = 1 / (1 - gauss_zero)

    def gaussian_kernel_func(coordinate):
        x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0
        x = gaussian(x)
        x -= gauss_zero
        x *= gauss_kernel_scale
        x = max(0.0, x)
        return x

    size = max_radius * 2 + 1
    kernel_center = max_radius
    kernel = np.zeros((size, size))

    for index in np.ndindex(kernel.shape):
        kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)

    return kernel, kernel_center


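# Illustrative note (editor-added, not part of the original script): the kernel built
# for apply_adaptive_masks() above is a 5x5 thresholded gaussian.
#
#   kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
#   # kernel.shape == (5, 5), kernel_center == 2, kernel[2, 2] == 1.0 (peak at the mean),
#   # and weights one pixel beyond max_radius are clamped to 0.
#
# Passed to weighted_histogram_filter with a 0.9..1.0 window it acts like a soft local
# maximum, and with a 0.25..0.75 window like a soft median - the two passes used earlier
# to clean up the latent distance map before thresholding.

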
# ------------------- Constants -------------------


default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)

enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"

ui_labels = SoftInpaintingSettings(
    "Schedule bias",
    "Preservation strength",
    "Transition contrast boost",
    "Mask influence",
    "Difference threshold",
    "Difference contrast")

ui_info = SoftInpaintingSettings(
    "Shifts when preservation of original content occurs during denoising.",
    "How strongly partially masked content should be preserved.",
    "Amplifies the contrast that may be lost in partially masked regions.",
    "How strongly the original mask should bias the difference threshold.",
    "How much an image region can change before the original pixels are not blended in anymore.",
    "How sharp the transition should be between blended and not blended.")

gen_param_labels = SoftInpaintingSettings(
    "Soft inpainting schedule bias",
    "Soft inpainting preservation strength",
    "Soft inpainting transition contrast boost",
    "Soft inpainting mask influence",
    "Soft inpainting difference threshold",
    "Soft inpainting difference contrast")

el_ids = SoftInpaintingSettings(
    "mask_blend_power",
    "mask_blend_scale",
    "inpaint_detail_preservation",
    "composite_mask_influence",
    "composite_difference_threshold",
    "composite_difference_contrast")


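# Illustrative sketch (editor-added, not part of the original script): how the defaults
# above end up in the generation info. Keys come from gen_param_labels, values from the
# SoftInpaintingSettings instance; the exact set of keys is an assumption based on
# add_generation_params() above.
#
#   params = {}
#   default.add_generation_params(params)
#   # params should now map, roughly:
#   #   "Soft inpainting schedule bias"              -> 1
#   #   "Soft inpainting preservation strength"      -> 0.5
#   #   "Soft inpainting transition contrast boost"  -> 4
#   #   "Soft inpainting mask influence"             -> 0
#   #   "Soft inpainting difference threshold"       -> 0.5
#   #   "Soft inpainting difference contrast"        -> 2

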
# ------------------- Script -------------------


class Script(scripts.Script):
    def __init__(self):
        self.section = "inpaint"
        self.masks_for_overlay = None
        self.overlay_images = None

    def title(self):
        return "Soft Inpainting"

    def show(self, is_img2img):
        return scripts.AlwaysVisible if is_img2img else False

    def ui(self, is_img2img):
        if not is_img2img:
            return

        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
            with gr.Group():
                gr.Markdown(
                    """
                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.
                    **High _Mask blur_** values are recommended!
                    """)

                power = \
                    gr.Slider(label=ui_labels.mask_blend_power,
                              info=ui_info.mask_blend_power,
                              minimum=0,
                              maximum=8,
                              step=0.1,
                              value=default.mask_blend_power,
                              elem_id=el_ids.mask_blend_power)
                scale = \
                    gr.Slider(label=ui_labels.mask_blend_scale,
                              info=ui_info.mask_blend_scale,
                              minimum=0,
                              maximum=8,
                              step=0.05,
                              value=default.mask_blend_scale,
                              elem_id=el_ids.mask_blend_scale)
                detail = \
                    gr.Slider(label=ui_labels.inpaint_detail_preservation,
                              info=ui_info.inpaint_detail_preservation,
                              minimum=1,
                              maximum=32,
                              step=0.5,
                              value=default.inpaint_detail_preservation,
                              elem_id=el_ids.inpaint_detail_preservation)

                gr.Markdown(
                    """
                    ### Pixel Composite Settings
                    """)

                mask_inf = \
                    gr.Slider(label=ui_labels.composite_mask_influence,
                              info=ui_info.composite_mask_influence,
                              minimum=0,
                              maximum=1,
                              step=0.05,
                              value=default.composite_mask_influence,
                              elem_id=el_ids.composite_mask_influence)

                dif_thresh = \
                    gr.Slider(label=ui_labels.composite_difference_threshold,
                              info=ui_info.composite_difference_threshold,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_threshold,
                              elem_id=el_ids.composite_difference_threshold)

                dif_contr = \
                    gr.Slider(label=ui_labels.composite_difference_contrast,
                              info=ui_info.composite_difference_contrast,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_contrast,
                              elem_id=el_ids.composite_difference_contrast)

            with gr.Accordion("Help", open=False):
                gr.Markdown(
                    f"""
                    ### {ui_labels.mask_blend_power}

                    The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).
                    This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step.
                    This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.

                    - **Below 1**: Stronger preservation near the end (with low sigma)
                    - **1**: Balanced (proportional to sigma)
                    - **Above 1**: Stronger preservation in the beginning (with high sigma)
                    """)
                gr.Markdown(
                    f"""
                    ### {ui_labels.mask_blend_scale}

                    Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.
                    This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.

                    - **Low values**: Favors generated content.
                    - **High values**: Favors original content.
                    """)
                gr.Markdown(
                    f"""
                    ### {ui_labels.inpaint_detail_preservation}

                    This parameter controls how the original latent vectors and denoised latent vectors are interpolated.
                    With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.
                    This can prevent the loss of contrast that occurs with linear interpolation.

                    - **Low values**: Softer blending, details may fade.
                    - **High values**: Stronger contrast, may over-saturate colors.
                    """)

                gr.Markdown(
                    """
                    ## Pixel Composite Settings

                    Masks are generated based on how much a part of the image changed after denoising.
                    These masks are used to blend the original and final images together.
                    If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_mask_influence}

                    This parameter controls how much the mask should bias this sensitivity to difference.

                    - **0**: Ignore the mask, only consider differences in image content.
                    - **1**: Follow the mask closely despite image content changes.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_difference_threshold}

                    This value represents the difference at which the original pixels will have less than 50% opacity.

                    - **Low values**: Two image patches must be almost the same in order to retain original pixels.
                    - **High values**: Two image patches can be very different and still retain original pixels.
                    """)

                gr.Markdown(
                    f"""
                    ### {ui_labels.composite_difference_contrast}

                    This value represents the contrast between the opacity of the original and inpainted content.

                    - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.
                    - **High values**: Ghosting will be less common, but transitions may be very sudden.
                    """)

        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
                                (power, gen_param_labels.mask_blend_power),
                                (scale, gen_param_labels.mask_blend_scale),
                                (detail, gen_param_labels.inpaint_detail_preservation),
                                (mask_inf, gen_param_labels.composite_mask_influence),
                                (dif_thresh, gen_param_labels.composite_difference_threshold),
                                (dif_contr, gen_param_labels.composite_difference_contrast)]

        self.paste_field_names = []
        for _, field_name in self.infotext_fields:
            self.paste_field_names.append(field_name)

        return [soft_inpainting_enabled,
                power,
                scale,
                detail,
                mask_inf,
                dif_thresh,
                dif_contr]

    def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        # Shut off the rounding it normally does.
        p.mask_round = False

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # p.extra_generation_params["Mask rounding"] = False
        settings.add_generation_params(p.extra_generation_params)

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,
                      dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        if mba.is_final_blend:
            mba.blended_latent = mba.current_latent
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # todo: Why is sigma 2D? Both values are the same.
        mba.blended_latent = latent_blend(settings,
                                          mba.init_latent,
                                          mba.current_latent,
                                          get_modified_nmask(settings, mba.nmask, mba.sigma[0]))

    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,
                    dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        nmask = getattr(p, "nmask", None)
        if nmask is None:
            return

        from modules import images
        from modules.shared import opts

        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # Since the original code puts holes in the existing overlay images,
        # we have to rebuild them.
        self.overlay_images = []
        for img in p.init_images:

            image = images.flatten(img, opts.img2img_background_color)

            if p.paste_to is None and p.resize_mode != 3:
                image = images.resize_image(p.resize_mode, image, p.width, p.height)

            self.overlay_images.append(image.convert('RGBA'))

        if len(p.init_images) == 1:
            self.overlay_images = self.overlay_images * p.batch_size

        if getattr(ps.samples, 'already_decoded', False):
            self.masks_for_overlay = apply_masks(settings=settings,
                                                 nmask=nmask,
                                                 overlay_images=self.overlay_images,
                                                 width=p.width,
                                                 height=p.height,
                                                 paste_to=p.paste_to)
        else:
            self.masks_for_overlay = apply_adaptive_masks(settings=settings,
                                                          nmask=nmask,
                                                          latent_orig=p.init_latent,
                                                          latent_processed=ps.samples,
                                                          overlay_images=self.overlay_images,
                                                          width=p.width,
                                                          height=p.height,
                                                          paste_to=p.paste_to)

    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,
                                detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if not processing_uses_inpainting(p):
            return

        if self.masks_for_overlay is None:
            return

        if self.overlay_images is None:
            return

        ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
        ppmo.overlay_image = self.overlay_images[ppmo.index]
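# Editor-added summary (not part of the original script): rough order in which the hooks
# above run for an enabled img2img inpainting job.
#
#   1. process()                 - turns off mask rounding and records the infotext parameters
#   2. on_mask_blend()           - every non-final blend: latent_blend() of init vs. current
#                                  latents, weighted by get_modified_nmask(settings, nmask, sigma)
#   3. post_sample()             - rebuilds the overlay images, then apply_masks() (already
#                                  decoded samples) or apply_adaptive_masks() (latent samples)
#   4. postprocess_maskoverlay() - hands each image's mask and overlay back to processing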
html/extra-networks-card.html
@@ -1,14 +1,9 @@
-<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
+<div class="card" style="{style}" onclick="{card_clicked}" data-name="{name}" {sort_keys}>
     {background_image}
-    <div class="button-row">
-        {metadata_button}
-        {edit_button}
-    </div>
-    <div class='actions'>
-        <div class='additional'>
-            <span style="display:none" class='search_term{search_only}'>{search_term}</span>
-        </div>
-        <span class='name'>{name}</span>
-        <span class='description'>{description}</span>
-    </div>
-</div>
+    <div class="button-row">{copy_path_button}{metadata_button}{edit_button}</div>
+    <div class="actions">
+        <div class="additional">{search_terms}</div>
+        <span class="name">{name}</span>
+        <span class="description">{description}</span>
+    </div>
+</div>
5 html/extra-networks-copy-path-button.html Normal file
@@ -0,0 +1,5 @@
<div class="copy-path-button card-button"
     title="Copy path to clipboard"
     onclick="extraNetworksCopyCardPath(event, '{filename}')"
     data-clipboard-text="{filename}">
</div>

4 html/extra-networks-edit-item-button.html Normal file
@@ -0,0 +1,4 @@
<div class="edit-button card-button"
     title="Edit metadata"
     onclick="extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}', '{name}')">
</div>

4 html/extra-networks-metadata-button.html Normal file
@@ -0,0 +1,4 @@
<div class="metadata-button card-button"
     title="Show internal metadata"
     onclick="extraNetworksRequestMetadata(event, '{extra_networks_tabname}', '{name}')">
</div>

55 html/extra-networks-pane.html Normal file
@@ -0,0 +1,55 @@
<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane'>
    <div class="extra-network-control" id="{tabname}_{extra_networks_tabname}_controls" style="display:none" >
        <div class="extra-network-control--search">
            <input
                id="{tabname}_{extra_networks_tabname}_extra_search"
                class="extra-network-control--search-text"
                type="search"
                placeholder="Filter files"
            >
        </div>
        <div
            id="{tabname}_{extra_networks_tabname}_extra_sort"
            class="extra-network-control--sort"
            data-sortmode="{data_sortmode}"
            data-sortkey="{data_sortkey}"
            title="Sort by path"
            onclick="extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');"
        >
            <i class="extra-network-control--sort-icon"></i>
        </div>
        <div
            id="{tabname}_{extra_networks_tabname}_extra_sort_dir"
            class="extra-network-control--sort-dir"
            data-sortdir="{data_sortdir}"
            title="Sort ascending"
            onclick="extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');"
        >
            <i class="extra-network-control--sort-dir-icon"></i>
        </div>
        <div
            id="{tabname}_{extra_networks_tabname}_extra_tree_view"
            class="extra-network-control--tree-view {tree_view_btn_extra_class}"
            title="Enable Tree View"
            onclick="extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');"
        >
            <i class="extra-network-control--tree-view-icon"></i>
        </div>
        <div
            id="{tabname}_{extra_networks_tabname}_extra_refresh"
            class="extra-network-control--refresh"
            title="Refresh page"
            onclick="extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');"
        >
            <i class="extra-network-control--refresh-icon"></i>
        </div>
    </div>
    <div class="extra-network-pane-content">
        <div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree {tree_view_div_extra_class}'>
            {tree_html}
        </div>
        <div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards'>
            {items_html}
        </div>
    </div>
</div>

23 html/extra-networks-tree-button.html Normal file
@@ -0,0 +1,23 @@
<span data-filterable-item-text hidden>{search_terms}</span>
<div class="tree-list-content {subclass}"
    type="button"
    onclick="extraNetworksTreeOnClick(event, '{tabname}', '{extra_networks_tabname}');{onclick_extra}"
    data-path="{data_path}"
    data-hash="{data_hash}"
    >
    <span class='tree-list-item-action tree-list-item-action--leading'>
        {action_list_item_action_leading}
    </span>
    <span class="tree-list-item-visual tree-list-item-visual--leading">
        {action_list_item_visual_leading}
    </span>
    <span class="tree-list-item-label tree-list-item-label--truncate">
        {action_list_item_label}
    </span>
    <span class="tree-list-item-visual tree-list-item-visual--trailing">
        {action_list_item_visual_trailing}
    </span>
    <span class="tree-list-item-action tree-list-item-action--trailing">
        {action_list_item_action_trailing}
    </span>
</div>
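[Editor's note] The five templates above are plain HTML fragments whose `{...}` fields are
filled in on the Python side before being sent to the browser. A minimal sketch of that
substitution (str.format-style replacement is an assumption here; the actual call site is
not part of this diff):

    # Hypothetical illustration only.
    template = open("html/extra-networks-copy-path-button.html", encoding="utf8").read()
    button_html = template.format(filename="/models/Lora/example.safetensors")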
html/licenses.html
@@ -4,107 +4,6 @@
 #licenses pre { margin: 1em 0 2em 0;}
 </style>
 
-<h2><a href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">CodeFormer</a></h2>
-<small>Parts of CodeFormer code had to be copied to be compatible with GFPGAN.</small>
-<pre>
-S-Lab License 1.0
-
-Copyright 2022 S-Lab
-
-Redistribution and use for non-commercial purpose in source and
-binary forms, with or without modification, are permitted provided
-that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright
-   notice, this list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright
-   notice, this list of conditions and the following disclaimer in
-   the documentation and/or other materials provided with the
-   distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived
-   from this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-In the event that redistribution and/or use for commercial purpose in
-source or binary forms, with or without modification is required,
-please contact the contributor(s) of the work.
-</pre>
-
-
-<h2><a href="https://github.com/victorca25/iNNfer/blob/main/LICENSE">ESRGAN</a></h2>
-<small>Code for architecture and reading models copied.</small>
-<pre>
-MIT License
-
-Copyright (c) 2021 victorca25
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-</pre>
-
-<h2><a href="https://github.com/xinntao/Real-ESRGAN/blob/master/LICENSE">Real-ESRGAN</a></h2>
-<small>Some code is copied to support ESRGAN models.</small>
-<pre>
-BSD 3-Clause License
-
-Copyright (c) 2021, Xintao Wang
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
-   list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
-   this list of conditions and the following disclaimer in the documentation
-   and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
-   contributors may be used to endorse or promote products derived from
-   this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</pre>
-
 <h2><a href="https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE">InvokeAI</a></h2>
 <small>Some code for compatibility with OSX is taken from lstein's repository.</small>
 <pre>
@@ -183,213 +82,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 </pre>
 
-<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
-<small>Code added by contributors, most likely copied from this repository.</small>
-
-<pre>
-Apache License
-Version 2.0, January 2004
-http://www.apache.org/licenses/
-
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-1. Definitions.
-
-"License" shall mean the terms and conditions for use, reproduction,
-and distribution as defined by Sections 1 through 9 of this document.
-
-"Licensor" shall mean the copyright owner or entity authorized by
-the copyright owner that is granting the License.
-
-"Legal Entity" shall mean the union of the acting entity and all
-other entities that control, are controlled by, or are under common
-control with that entity. For the purposes of this definition,
-"control" means (i) the power, direct or indirect, to cause the
-direction or management of such entity, whether by contract or
-otherwise, or (ii) ownership of fifty percent (50%) or more of the
-outstanding shares, or (iii) beneficial ownership of such entity.
-
-"You" (or "Your") shall mean an individual or Legal Entity
-exercising permissions granted by this License.
-
-"Source" form shall mean the preferred form for making modifications,
-including but not limited to software source code, documentation
-source, and configuration files.
-
-"Object" form shall mean any form resulting from mechanical
-transformation or translation of a Source form, including but
-not limited to compiled object code, generated documentation,
-and conversions to other media types.
-
-"Work" shall mean the work of authorship, whether in Source or
-Object form, made available under the License, as indicated by a
-copyright notice that is included in or attached to the work
-(an example is provided in the Appendix below).
-
-"Derivative Works" shall mean any work, whether in Source or Object
-form, that is based on (or derived from) the Work and for which the
-editorial revisions, annotations, elaborations, or other modifications
-represent, as a whole, an original work of authorship. For the purposes
-of this License, Derivative Works shall not include works that remain
-separable from, or merely link (or bind by name) to the interfaces of,
-the Work and Derivative Works thereof.
-
-"Contribution" shall mean any work of authorship, including
-the original version of the Work and any modifications or additions
-to that Work or Derivative Works thereof, that is intentionally
-submitted to Licensor for inclusion in the Work by the copyright owner
-or by an individual or Legal Entity authorized to submit on behalf of
-the copyright owner. For the purposes of this definition, "submitted"
-means any form of electronic, verbal, or written communication sent
-to the Licensor or its representatives, including but not limited to
-communication on electronic mailing lists, source code control systems,
-and issue tracking systems that are managed by, or on behalf of, the
-Licensor for the purpose of discussing and improving the Work, but
-excluding communication that is conspicuously marked or otherwise
-designated in writing by the copyright owner as "Not a Contribution."
-
-"Contributor" shall mean Licensor and any individual or Legal Entity
-on behalf of whom a Contribution has been received by Licensor and
-subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of
-this License, each Contributor hereby grants to You a perpetual,
-worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-copyright license to reproduce, prepare Derivative Works of,
-publicly display, publicly perform, sublicense, and distribute the
-Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of
-this License, each Contributor hereby grants to You a perpetual,
-worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-(except as stated in this section) patent license to make, have made,
-use, offer to sell, sell, import, and otherwise transfer the Work,
-where such license applies only to those patent claims licensable
-by such Contributor that are necessarily infringed by their
-Contribution(s) alone or by combination of their Contribution(s)
-with the Work to which such Contribution(s) was submitted. If You
-institute patent litigation against any entity (including a
-cross-claim or counterclaim in a lawsuit) alleging that the Work
-or a Contribution incorporated within the Work constitutes direct
-or contributory patent infringement, then any patent licenses
-granted to You under this License for that Work shall terminate
-as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the
-Work or Derivative Works thereof in any medium, with or without
-modifications, and in Source or Object form, provided that You
-meet the following conditions:
-
-(a) You must give any other recipients of the Work or
-Derivative Works a copy of this License; and
-
-(b) You must cause any modified files to carry prominent notices
-stating that You changed the files; and
-
-(c) You must retain, in the Source form of any Derivative Works
-that You distribute, all copyright, patent, trademark, and
-attribution notices from the Source form of the Work,
-excluding those notices that do not pertain to any part of
-the Derivative Works; and
-
-(d) If the Work includes a "NOTICE" text file as part of its
-distribution, then any Derivative Works that You distribute must
-include a readable copy of the attribution notices contained
-within such NOTICE file, excluding those notices that do not
-pertain to any part of the Derivative Works, in at least one
-of the following places: within a NOTICE text file distributed
-as part of the Derivative Works; within the Source form or
-documentation, if provided along with the Derivative Works; or,
-within a display generated by the Derivative Works, if and
-wherever such third-party notices normally appear. The contents
-of the NOTICE file are for informational purposes only and
-do not modify the License. You may add Your own attribution
-notices within Derivative Works that You distribute, alongside
-or as an addendum to the NOTICE text from the Work, provided
-that such additional attribution notices cannot be construed
-as modifying the License.
-
-You may add Your own copyright statement to Your modifications and
-may provide additional or different license terms and conditions
-for use, reproduction, or distribution of Your modifications, or
-for any such Derivative Works as a whole, provided Your use,
-reproduction, and distribution of the Work otherwise complies with
-the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise,
-any Contribution intentionally submitted for inclusion in the Work
-by You to the Licensor shall be under the terms and conditions of
-this License, without any additional terms or conditions.
-Notwithstanding the above, nothing herein shall supersede or modify
-the terms of any separate license agreement you may have executed
-with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade
-names, trademarks, service marks, or product names of the Licensor,
-except as required for reasonable and customary use in describing the
-origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or
-agreed to in writing, Licensor provides the Work (and each
-Contributor provides its Contributions) on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-implied, including, without limitation, any warranties or conditions
-of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-PARTICULAR PURPOSE. You are solely responsible for determining the
-appropriateness of using or redistributing the Work and assume any
-risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory,
-whether in tort (including negligence), contract, or otherwise,
-unless required by applicable law (such as deliberate and grossly
-negligent acts) or agreed to in writing, shall any Contributor be
-liable to You for damages, including any direct, indirect, special,
-incidental, or consequential damages of any character arising as a
-result of this License or out of the use or inability to use the
-Work (including but not limited to damages for loss of goodwill,
-work stoppage, computer failure or malfunction, or any and all
-other commercial damages or losses), even if such Contributor
-has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing
-the Work or Derivative Works thereof, You may choose to offer,
-and charge a fee for, acceptance of support, warranty, indemnity,
-or other liability obligations and/or rights consistent with this
-License. However, in accepting such obligations, You may act only
-on Your own behalf and on Your sole responsibility, not on behalf
-of any other Contributor, and only if You agree to indemnify,
-defend, and hold each Contributor harmless for any liability
-incurred by, or claims asserted against, such Contributor by reason
-of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-APPENDIX: How to apply the Apache License to your work.
-
-To apply the Apache License to your work, attach the following
-boilerplate notice, with the fields enclosed by brackets "[]"
-replaced with your own identifying information. (Don't include
-the brackets!) The text should be enclosed in the appropriate
-comment syntax for the file format. We also recommend that a
-file or class name and description of purpose be included on the
-same "printed page" as the copyright notice for easier
-identification within third-party archives.
-
-Copyright [2021] [SwinIR Authors]
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-</pre>
-
 <h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
 <small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
 <pre>
@ -16,39 +16,55 @@ function toggleCss(key, css, enable) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function setupExtraNetworksForTab(tabname) {
|
function setupExtraNetworksForTab(tabname) {
|
||||||
gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks');
|
function registerPrompt(tabname, id) {
|
||||||
|
var textarea = gradioApp().querySelector("#" + id + " > label > textarea");
|
||||||
|
|
||||||
var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div');
|
if (!activePromptTextarea[tabname]) {
|
||||||
var searchDiv = gradioApp().getElementById(tabname + '_extra_search');
|
activePromptTextarea[tabname] = textarea;
|
||||||
var search = searchDiv.querySelector('textarea');
|
}
|
||||||
var sort = gradioApp().getElementById(tabname + '_extra_sort');
|
|
||||||
var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder');
|
|
||||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
|
||||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
|
||||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
|
||||||
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
|
|
||||||
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
|
|
||||||
|
|
||||||
tabs.appendChild(searchDiv);
|
textarea.addEventListener("focus", function() {
|
||||||
tabs.appendChild(sort);
|
activePromptTextarea[tabname] = textarea;
|
||||||
tabs.appendChild(sortOrder);
|
});
|
||||||
tabs.appendChild(refresh);
|
}
|
||||||
tabs.appendChild(showDirsDiv);
|
|
||||||
|
var tabnav = gradioApp().querySelector('#' + tabname + '_extra_tabs > div.tab-nav');
|
||||||
|
var controlsDiv = document.createElement('DIV');
|
||||||
|
controlsDiv.classList.add('extra-networks-controls-div');
|
||||||
|
tabnav.appendChild(controlsDiv);
|
||||||
|
tabnav.insertBefore(controlsDiv, null);
|
||||||
|
|
||||||
|
var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs');
|
||||||
|
this_tab.querySelectorAll(":scope > [id^='" + tabname + "_']").forEach(function(elem) {
|
||||||
|
// tabname_full = {tabname}_{extra_networks_tabname}
|
||||||
|
var tabname_full = elem.id;
|
||||||
|
var search = gradioApp().querySelector("#" + tabname_full + "_extra_search");
|
||||||
|
var sort_mode = gradioApp().querySelector("#" + tabname_full + "_extra_sort");
|
||||||
|
var sort_dir = gradioApp().querySelector("#" + tabname_full + "_extra_sort_dir");
|
||||||
|
var refresh = gradioApp().querySelector("#" + tabname_full + "_extra_refresh");
|
||||||
|
|
||||||
|
// If any of the buttons above don't exist, we want to skip this iteration of the loop.
|
||||||
|
if (!search || !sort_mode || !sort_dir || !refresh) {
|
||||||
|
return; // `return` is equivalent of `continue` but for forEach loops.
|
||||||
|
}
|
||||||
|
|
||||||
var applyFilter = function() {
|
var applyFilter = function() {
|
||||||
var searchTerm = search.value.toLowerCase();
|
var searchTerm = search.value.toLowerCase();
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {
|
||||||
var searchOnly = elem.querySelector('.search_only');
|
var searchOnly = elem.querySelector('.search_only');
|
||||||
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase();
|
var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms'), function(t) {
|
||||||
|
return t.textContent.toLowerCase();
|
||||||
|
}).join(" ");
|
||||||
|
|
||||||
var visible = text.indexOf(searchTerm) != -1;
|
var visible = text.indexOf(searchTerm) != -1;
|
||||||
|
|
||||||
if (searchOnly && searchTerm.length < 4) {
|
if (searchOnly && searchTerm.length < 4) {
|
||||||
visible = false;
|
visible = false;
|
||||||
}
|
}
|
||||||
|
if (visible) {
|
||||||
elem.style.display = visible ? "" : "none";
|
elem.classList.remove("hidden");
|
||||||
|
} else {
|
||||||
|
elem.classList.add("hidden");
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
applySort();
|
applySort();
|
||||||
@ -56,16 +72,15 @@ function setupExtraNetworksForTab(tabname) {
        var applySort = function() {
            var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');

            var reverse = sortOrder.classList.contains("sortReverse");
            var reverse = sort_dir.dataset.sortdir == "Descending";
            var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
            var sortKey = sort_mode.dataset.sortmode.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
            sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
            var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;

            if (sortKeyStore == sort.dataset.sortkey) {
            if (sortKeyStore == sort_mode.dataset.sortkey) {
                return;
            }
            sort.dataset.sortkey = sortKeyStore;
            sort_mode.dataset.sortkey = sortKeyStore;

            cards.forEach(function(card) {
                card.originalParentElement = card.parentElement;
@ -92,23 +107,17 @@ function setupExtraNetworksForTab(tabname) {
        };

        search.addEventListener("input", applyFilter);

        sortOrder.addEventListener("click", function() {
            sortOrder.classList.toggle("sortReverse");
            applySort();
        });
        applySort();
        applyFilter();

        extraNetworksApplySort[tabname] = applySort;
        extraNetworksApplyFilter[tabname] = applyFilter;
        extraNetworksApplySort[tabname_full] = applySort;
        extraNetworksApplyFilter[tabname_full] = applyFilter;

        var controls = gradioApp().querySelector("#" + tabname_full + "_controls");
        controlsDiv.insertBefore(controls, null);
    });

    var showDirsUpdate = function() {
        var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }';
        toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked);
        localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0);
    };
    showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1;
    showDirs.addEventListener("change", showDirsUpdate);
    showDirsUpdate();

    registerPrompt(tabname, tabname + "_prompt");
    registerPrompt(tabname, tabname + "_neg_prompt");
}
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
@ -137,21 +146,32 @@ function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePromp
}

function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
    extraNetworksMovePromptToTab(tabname, '', false, false);
}

function extraNetworksShowControlsForPage(tabname, tabname_full) {
    gradioApp().querySelectorAll('#' + tabname + '_extra_tabs .extra-networks-controls-div > div').forEach(function(elem) {
        var targetId = tabname_full + "_controls";
        elem.style.display = elem.id == targetId ? "" : "none";
    });
}

function extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
    extraNetworksMovePromptToTab(tabname, '', false, false);

    extraNetworksShowControlsForPage(tabname, null);
}

function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt, tabname_full) { // called from python when user selects an extra networks tab
    extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);

    extraNetworksShowControlsForPage(tabname, tabname_full);
}

function applyExtraNetworkFilter(tabname) {
function applyExtraNetworkFilter(tabname_full) {
    setTimeout(extraNetworksApplyFilter[tabname], 1);
    setTimeout(extraNetworksApplyFilter[tabname_full], 1);
}

function applyExtraNetworkSort(tabname) {
function applyExtraNetworkSort(tabname_full) {
    setTimeout(extraNetworksApplySort[tabname], 1);
    setTimeout(extraNetworksApplySort[tabname_full], 1);
}
var extraNetworksApplyFilter = {};
@ -161,41 +181,24 @@ var activePromptTextarea = {};

function setupExtraNetworks() {
    setupExtraNetworksForTab('txt2img');
    setupExtraNetworksForTab('img2img');

    function registerPrompt(tabname, id) {
        var textarea = gradioApp().querySelector("#" + id + " > label > textarea");

        if (!activePromptTextarea[tabname]) {
            activePromptTextarea[tabname] = textarea;
        }

        textarea.addEventListener("focus", function() {
            activePromptTextarea[tabname] = textarea;
        });
    }

    registerPrompt('txt2img', 'txt2img_prompt');
    registerPrompt('txt2img', 'txt2img_neg_prompt');
    registerPrompt('img2img', 'img2img_prompt');
    registerPrompt('img2img', 'img2img_neg_prompt');
}

onUiLoaded(setupExtraNetworks);
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;

var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/;
var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g;

function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {
    var m = text.match(re_extranet);
    var m = text.match(isNeg ? re_extranet_neg : re_extranet);
    var replaced = false;
    var newTextareaText;
    var extraTextBeforeNet = opts.extra_networks_add_text_separator;
    if (m) {
        var extraTextAfterNet = m[2];
        var partToSearch = m[1];
        var foundAtPosition = -1;
        newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) {
        newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) {
            m = found.match(re_extranet);
            m = found.match(isNeg ? re_extranet_neg : re_extranet);
            if (m[1] == partToSearch) {
                replaced = true;
                foundAtPosition = pos;
@ -203,9 +206,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
            }
            return found;
        });

        if (foundAtPosition >= 0) {
            if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
            if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
                newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
            }
            if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
@ -213,13 +215,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
            }
        }
    } else {
        newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
            if (found == text) {
                replaced = true;
                return "";
            }
            return found;
        });
        newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, "g"), "");
        replaced = (newTextareaText != textarea.value);
    }

    if (replaced) {
@ -230,14 +227,22 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
    return false;
}

function cardClicked(tabname, textToAdd, allowNegativePrompt) {
    var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");

    if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) {
        textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd;
    }

    updateInput(textarea);
}

function updatePromptArea(text, textArea, isNeg) {
    if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {
        textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;
    }

    updateInput(textArea);
}

function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {
    if (textToAddNegative.length > 0) {
        updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"));
        updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true);
    } else {
        var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea");
        updatePromptArea(textToAdd, textarea);
    }
}

function saveCardPreview(event, tabname, filename) {
@ -253,13 +258,200 @@ function saveCardPreview(event, tabname, filename) {
    event.preventDefault();
}

function extraNetworksSearchButton(tabs_id, event) {
    var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea');
    var button = event.target;
    var text = button.classList.contains("search-all") ? "" : button.textContent.trim();

    searchTextarea.value = text;
    updateInput(searchTextarea);
}

function extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {
    /**
     * Processes `onclick` events when user clicks on files in tree.
     *
     * @param event The generated event.
     * @param btn The clicked `tree-list-item` button.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    // NOTE: Currently unused.
    return;
}

function extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) {
    /**
     * Processes `onclick` events when user clicks on directories in tree.
     *
     * Here is how the tree reacts to clicks for various states:
     * unselected unopened directory: Directory is selected and expanded.
     * unselected opened directory: Directory is selected.
     * selected opened directory: Directory is collapsed and deselected.
     * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.
     *
     * @param event The generated event.
     * @param btn The clicked `tree-list-item` button.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    var ul = btn.nextElementSibling;
    // This is the actual target that the user clicked on within the target button.
    // We use this to detect if the chevron was clicked.
    var true_targ = event.target;

    function _expand_or_collapse(_ul, _btn) {
        // Expands <ul> if it is collapsed, collapses otherwise. Updates button attributes.
        if (_ul.hasAttribute("hidden")) {
            _ul.removeAttribute("hidden");
            _btn.dataset.expanded = "";
        } else {
            _ul.setAttribute("hidden", "");
            delete _btn.dataset.expanded;
        }
    }

    function _remove_selected_from_all() {
        // Removes the `selected` attribute from all buttons.
        var sels = document.querySelectorAll("div.tree-list-content");
        [...sels].forEach(el => {
            delete el.dataset.selected;
        });
    }

    function _select_button(_btn) {
        // Removes `data-selected` attribute from all buttons then adds to passed button.
        _remove_selected_from_all();
        _btn.dataset.selected = "";
    }

    function _update_search(_tabname, _extra_networks_tabname, _search_text) {
        // Update search input with select button's path.
        var search_input_elem = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_search");
        search_input_elem.value = _search_text;
        updateInput(search_input_elem);
    }

    // If user clicks on the chevron, then we do not select the folder.
    if (true_targ.matches(".tree-list-item-action--leading, .tree-list-item-action-chevron")) {
        _expand_or_collapse(ul, btn);
    } else {
        // User clicked anywhere else on the button.
        if ("selected" in btn.dataset && !(ul.hasAttribute("hidden"))) {
            // If folder is selected and open, collapse and deselect button.
            _expand_or_collapse(ul, btn);
            delete btn.dataset.selected;
            _update_search(tabname, extra_networks_tabname, "");
        } else if (!(!("selected" in btn.dataset) && !(ul.hasAttribute("hidden")))) {
            // If folder is open and not selected, then we don't collapse; just select.
            // NOTE: Double inversion sucks but it is the clearest way to show the branching here.
            _expand_or_collapse(ul, btn);
            _select_button(btn, tabname, extra_networks_tabname);
            _update_search(tabname, extra_networks_tabname, btn.dataset.path);
        } else {
            // All other cases, just select the button.
            _select_button(btn, tabname, extra_networks_tabname);
            _update_search(tabname, extra_networks_tabname, btn.dataset.path);
        }
    }
}
function extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {
    /**
     * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`.
     *
     * Determines whether the clicked button in the tree is for a file entry or a directory
     * then calls the appropriate function.
     *
     * @param event The generated event.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    var btn = event.currentTarget;
    var par = btn.parentElement;
    if (par.dataset.treeEntryType === "file") {
        extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname);
    } else {
        extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname);
    }
}

function extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {
    /**
     * Handles `onclick` events for the Sort Mode button.
     *
     * Modifies the data attributes of the Sort Mode button to cycle between
     * various sorting modes.
     *
     * @param event The generated event.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    var curr_mode = event.currentTarget.dataset.sortmode;
    var el_sort_dir = gradioApp().querySelector("#" + tabname + "_" + extra_networks_tabname + "_extra_sort_dir");
    var sort_dir = el_sort_dir.dataset.sortdir;
    if (curr_mode == "path") {
        event.currentTarget.dataset.sortmode = "name";
        event.currentTarget.dataset.sortkey = "sortName-" + sort_dir + "-640";
        event.currentTarget.setAttribute("title", "Sort by filename");
    } else if (curr_mode == "name") {
        event.currentTarget.dataset.sortmode = "date_created";
        event.currentTarget.dataset.sortkey = "sortDate_created-" + sort_dir + "-640";
        event.currentTarget.setAttribute("title", "Sort by date created");
    } else if (curr_mode == "date_created") {
        event.currentTarget.dataset.sortmode = "date_modified";
        event.currentTarget.dataset.sortkey = "sortDate_modified-" + sort_dir + "-640";
        event.currentTarget.setAttribute("title", "Sort by date modified");
    } else {
        event.currentTarget.dataset.sortmode = "path";
        event.currentTarget.dataset.sortkey = "sortPath-" + sort_dir + "-640";
        event.currentTarget.setAttribute("title", "Sort by path");
    }
    applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
}
function extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) {
    /**
     * Handles `onclick` events for the Sort Direction button.
     *
     * Modifies the data attributes of the Sort Direction button to cycle between
     * ascending and descending sort directions.
     *
     * @param event The generated event.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    if (event.currentTarget.dataset.sortdir == "Ascending") {
        event.currentTarget.dataset.sortdir = "Descending";
        event.currentTarget.setAttribute("title", "Sort descending");
    } else {
        event.currentTarget.dataset.sortdir = "Ascending";
        event.currentTarget.setAttribute("title", "Sort ascending");
    }
    applyExtraNetworkSort(tabname + "_" + extra_networks_tabname);
}

function extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) {
    /**
     * Handles `onclick` events for the Tree View button.
     *
     * Toggles the tree view in the extra networks pane.
     *
     * @param event The generated event.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    gradioApp().getElementById(tabname + "_" + extra_networks_tabname + "_tree").classList.toggle("hidden");
    event.currentTarget.classList.toggle("extra-network-control--enabled");
}

function extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {
    /**
     * Handles `onclick` events for the Refresh Page button.
     *
     * In order to actually call the python functions in `ui_extra_networks.py`
     * to refresh the page, we created an empty gradio button in that file with an
     * event handler that refreshes the page. So what this function here does
     * is it manually raises a `click` event on that button.
     *
     * @param event The generated event.
     * @param tabname The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.
     * @param extra_networks_tabname The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.
     */
    var btn_refresh_internal = gradioApp().getElementById(tabname + "_extra_refresh_internal");
    btn_refresh_internal.dispatchEvent(new Event("click"));
}
var globalPopup = null;
@ -337,6 +529,11 @@ function requestGet(url, data, handler, errorHandler) {
    xhr.send(js);
}

function extraNetworksCopyCardPath(event, path) {
    navigator.clipboard.writeText(path);
    event.stopPropagation();
}

function extraNetworksRequestMetadata(event, extraPage, cardName) {
    var showError = function() {
        extraNetworksShowMetadata("there was an error getting metadata");
@ -398,3 +595,36 @@ window.addEventListener("keydown", function(event) {
        closePopup();
    }
});

/**
 * Setup custom loading for this script.
 * We need to wait for all of our HTML to be generated in the extra networks tabs
 * before we can actually run the `setupExtraNetworks` function.
 * The `onUiLoaded` function actually runs before all of our extra network tabs are
 * finished generating. Thus we needed this new method.
 *
 */

var uiAfterScriptsCallbacks = [];
var uiAfterScriptsTimeout = null;
var executedAfterScripts = false;

function scheduleAfterScriptsCallbacks() {
    clearTimeout(uiAfterScriptsTimeout);
    uiAfterScriptsTimeout = setTimeout(function() {
        executeCallbacks(uiAfterScriptsCallbacks);
    }, 200);
}

document.addEventListener("DOMContentLoaded", function() {
    var mutationObserver = new MutationObserver(function(m) {
        if (!executedAfterScripts &&
            gradioApp().querySelectorAll("[id$='_extra_search']").length == 8) {
            executedAfterScripts = true;
            scheduleAfterScriptsCallbacks();
        }
    });
    mutationObserver.observe(gradioApp(), {childList: true, subtree: true});
});

uiAfterScriptsCallbacks.push(setupExtraNetworks);
@ -150,6 +150,14 @@ function submit() {
    return res;
}

function submit_txt2img_upscale() {
    var res = submit(...arguments);

    res[2] = selected_gallery_index();

    return res;
}

function submit_img2img() {
    showSubmitButtons('img2img', false);
@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@ -31,7 +31,7 @@ from typing import Any
import piexif
import piexif.helper
from contextlib import closing
from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task

def script_name_to_index(name, scripts):
    try:
@ -230,6 +230,7 @@ class Api:
        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
@ -251,6 +252,24 @@ class Api:
        self.default_script_arg_txt2img = []
        self.default_script_arg_img2img = []

        txt2img_script_runner = scripts.scripts_txt2img
        img2img_script_runner = scripts.scripts_img2img

        if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
            ui.create_ui()

        if not txt2img_script_runner.scripts:
            txt2img_script_runner.initialize_scripts(False)
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)

        if not img2img_script_runner.scripts:
            img2img_script_runner.initialize_scripts(True)
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)

    def add_api_route(self, path: str, endpoint, **kwargs):
        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@ -312,8 +331,13 @@ class Api:
                script_args[script.args_from:script.args_to] = ui_default_values
        return script_args

    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
        script_args = default_script_args.copy()

        if input_script_args is not None:
            for index, value in input_script_args.items():
                script_args[index] = value

        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
        if selectable_scripts:
            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@ -335,13 +359,83 @@ class Api:
                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
        return script_args

    def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
        """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.

        If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.

        Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
        """

        if not request.infotext:
            return {}

        possible_fields = infotext_utils.paste_fields[tabname]["fields"]
        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True)  # pydantic v1/v2 have different names for this
        params = infotext_utils.parse_generation_parameters(request.infotext)

        def get_field_value(field, params):
            value = field.function(params) if field.function else params.get(field.label)
            if value is None:
                return None

            if field.api in request.__fields__:
                target_type = request.__fields__[field.api].type_
            else:
                target_type = type(field.component.value)

            if target_type == type(None):
                return None

            if isinstance(value, dict) and value.get('__type__') == 'generic_update':  # this is a gradio.update rather than a value
                value = value.get('value')

            if value is not None and not isinstance(value, target_type):
                value = target_type(value)

            return value

        for field in possible_fields:
            if not field.api:
                continue

            if field.api in set_fields:
                continue

            value = get_field_value(field, params)
            if value is not None:
                setattr(request, field.api, value)

        if request.override_settings is None:
            request.override_settings = {}

        overriden_settings = infotext_utils.get_override_settings(params)
        for _, setting_name, value in overriden_settings:
            if setting_name not in request.override_settings:
                request.override_settings[setting_name] = value

        if script_runner is not None and mentioned_script_args is not None:
            indexes = {v: i for i, v in enumerate(script_runner.inputs)}
            script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)

            for field, index in script_fields:
                value = get_field_value(field, params)

                if value is None:
                    continue

                mentioned_script_args[index] = value

        return params
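As a rough illustration of how the new `infotext` request field can be used from a client: the sketch below posts an infotext string to the txt2img endpoint and lets `apply_infotext` fill in prompt, sampler, seed and override settings for any fields the request leaves unset. This is not part of the diff; the host/port, the `requests` dependency and the example field values are assumptions.

# Hypothetical client-side sketch; assumes a webui API reachable on the default local port.
import requests

infotext = (
    "masterpiece, best quality, a lighthouse at dusk\n"
    "Negative prompt: blurry, lowres\n"
    "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 12345, Size: 512x512"
)

payload = {
    "infotext": infotext,   # parsed server-side by Api.apply_infotext()
    "batch_size": 1,        # fields set explicitly in the request win over values found in infotext
    "send_images": True,
    "save_images": False,
}

r = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=600)
r.raise_for_status()
print(r.json()["info"])     # generation info for the produced images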
|
||||||
    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
        task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

        script_runner = scripts.scripts_txt2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(False)
            ui.create_ui()
        if not self.default_script_arg_txt2img:
            self.default_script_arg_txt2img = self.init_default_script_args(script_runner)

        infotext_script_args = {}
        self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

        populate = txt2imgreq.copy(update={  # Override __init__ params
@ -356,12 +450,15 @@ class Api:
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
                p.is_api = True
@ -371,12 +468,14 @@ class Api:
                try:
                    shared.state.begin(job="scripts_txt2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()
@ -386,6 +485,8 @@ class Api:
        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
        task_id = img2imgreq.force_task_id or create_task_id("img2img")

        init_images = img2imgreq.init_images
        if init_images is None:
            raise HTTPException(status_code=404, detail="Init image not found")
@ -395,11 +496,10 @@ class Api:
            mask = decode_base64_to_image(mask)

        script_runner = scripts.scripts_img2img
        if not script_runner.scripts:
            script_runner.initialize_scripts(True)
            ui.create_ui()
        if not self.default_script_arg_img2img:
            self.default_script_arg_img2img = self.init_default_script_args(script_runner)

        infotext_script_args = {}
        self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

        populate = img2imgreq.copy(update={  # Override __init__ params
@ -416,12 +516,15 @@ class Api:
        args.pop('script_name', None)
        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
        args.pop('alwayson_scripts', None)
        args.pop('infotext', None)

        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

        send_images = args.pop('send_images', True)
        args.pop('save_images', None)

        add_task_to_queue(task_id)

        with self.queue_lock:
            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
                p.init_images = [decode_base64_to_image(x) for x in init_images]
@ -432,12 +535,14 @@ class Api:
                try:
                    shared.state.begin(job="scripts_img2img")
                    start_task(task_id)
                    if selectable_scripts is not None:
                        p.script_args = script_args
                        processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
                    else:
                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
                        processed = process_images(p)
                    finish_task(task_id)
                finally:
                    shared.state.end()
                    shared.total_tqdm.clear()
@ -480,7 +585,7 @@ class Api:
        if geninfo is None:
            geninfo = ""

        params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
        params = infotext_utils.parse_generation_parameters(geninfo)
        script_callbacks.infotext_pasted_callback(geninfo, params)

        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
@ -511,7 +616,7 @@ class Api:
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)

    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
        image_b64 = interrogatereq.image
@ -643,6 +748,10 @@ class Api:
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_embeddings(self):
        with self.queue_lock:
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

    def refresh_checkpoints(self):
        with self.queue_lock:
            shared.refresh_checkpoints()
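For context, a minimal client-side sketch of the refresh endpoints (this usage example is not part of the diff; the host/port and the `requests` dependency are assumptions):

# Hypothetical sketch: ask a running webui instance to re-scan embeddings and checkpoints.
import requests

base = "http://127.0.0.1:7860"
requests.post(f"{base}/sdapi/v1/refresh-embeddings").raise_for_status()   # route added in this commit
requests.post(f"{base}/sdapi/v1/refresh-checkpoints").raise_for_status()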
@ -775,7 +884,15 @@ class Api:

    def launch(self, server_name, port, root_path):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
        uvicorn.run(
            self.app,
            host=server_name,
            port=port,
            timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
            root_path=root_path,
            ssl_keyfile=shared.cmd_opts.tls_keyfile,
            ssl_certfile=shared.cmd_opts.tls_certfile
        )

    def kill_webui(self):
        restart.stop_program()
@ -107,6 +107,8 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
        {"key": "send_images", "type": bool, "default": True},
        {"key": "save_images", "type": bool, "default": False},
        {"key": "alwayson_scripts", "type": dict, "default": {}},
        {"key": "force_task_id", "type": str, "default": None},
        {"key": "infotext", "type": str, "default": None},
    ]
).generate_model()

@ -124,6 +126,8 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
        {"key": "send_images", "type": bool, "default": True},
        {"key": "save_images", "type": bool, "default": False},
        {"key": "alwayson_scripts", "type": dict, "default": {}},
        {"key": "force_task_id", "type": str, "default": None},
        {"key": "infotext", "type": str, "default": None},
    ]
).generate_model()
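Together, `force_task_id`, the queue hooks (`add_task_to_queue`, `start_task`, `finish_task`) and the new `current_task` field on the progress response let a client correlate a generation request with progress polling. A hedged sketch of that flow follows; the endpoint paths, the threading approach and the task-id string are assumptions for illustration, not repository code.

# Hypothetical polling sketch: supply our own task id, then watch /sdapi/v1/progress for it.
import threading
import time
import requests

base = "http://127.0.0.1:7860"
task_id = "task(my-client-generated-id)"

def generate():
    requests.post(f"{base}/sdapi/v1/txt2img", json={
        "prompt": "a watercolor fox",
        "steps": 20,
        "force_task_id": task_id,   # new request field in this commit
    }, timeout=600)

t = threading.Thread(target=generate)
t.start()

while t.is_alive():
    prog = requests.get(f"{base}/sdapi/v1/progress").json()
    # `current_task` was added to ProgressResponse in this commit.
    if prog.get("current_task") == task_id:
        print(f"{prog['progress'] * 100:.0f}% done")
    time.sleep(1)
t.join()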
||||||
|
@ -62,12 +62,11 @@ def cache(subsection):
    if cache_data is None:
        with cache_lock:
            if cache_data is None:
                if not os.path.isfile(cache_filename):
                    cache_data = {}
                else:
                    try:
                        with open(cache_filename, "r", encoding="utf8") as file:
                            cache_data = json.load(file)
                    except Exception:
                        os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
                        print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
                try:
                    with open(cache_filename, "r", encoding="utf8") as file:
                        cache_data = json.load(file)
                except FileNotFoundError:
                    cache_data = {}
                except Exception:
                    os.replace(cache_filename, os.path.join(script_path, "tmp", "cache.json"))
                    print('[ERROR] issue occurred while trying to read cache.json, move current cache to tmp/cache.json and create new cache')
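The switch from an `os.path.isfile` check to catching `FileNotFoundError` removes the window between checking for the file and opening it. A small self-contained sketch of the same pattern (generic illustration, not code from the repository):

# Generic illustration of the "ask forgiveness" pattern used above: opening the file and
# handling the missing-file error directly avoids a race between a check and the open.
import json

def load_json_or_empty(path):
    try:
        with open(path, "r", encoding="utf8") as file:
            return json.load(file)
    except FileNotFoundError:
        return {}

print(load_json_or_empty("does-not-exist.json"))  # -> {}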
@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.stopping_generation = False
        shared.state.job_count = 0

        if not add_stats:
@ -77,7 +77,9 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing of all settings globally", default=False)
parser.add_argument("--freeze-settings-in-sections", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like "saving-images,upscaling". The list of setting names can be found in the modules/shared_options.py file', default=None)
parser.add_argument("--freeze-specific-settings", type=str, help='disable editing of individual settings by specifying a comma-delimited list like "samples_save,samples_format". The list of setting names can be found in the config.json file', default=None)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@ -86,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
@ -1,276 +0,0 @@
|
|||||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn, Tensor
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
|
||||||
"""Calculate mean and std for adaptive_instance_normalization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feat (Tensor): 4D tensor.
|
|
||||||
eps (float): A small value added to the variance to avoid
|
|
||||||
divide-by-zero. Default: 1e-5.
|
|
||||||
"""
|
|
||||||
size = feat.size()
|
|
||||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
|
||||||
b, c = size[:2]
|
|
||||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
|
||||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
|
||||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
|
||||||
return feat_mean, feat_std
|
|
||||||
|
|
||||||
|
|
||||||
def adaptive_instance_normalization(content_feat, style_feat):
|
|
||||||
"""Adaptive instance normalization.
|
|
||||||
|
|
||||||
Adjust the reference features to have the similar color and illuminations
|
|
||||||
as those in the degradate features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content_feat (Tensor): The reference feature.
|
|
||||||
style_feat (Tensor): The degradate features.
|
|
||||||
"""
|
|
||||||
size = content_feat.size()
|
|
||||||
style_mean, style_std = calc_mean_std(style_feat)
|
|
||||||
content_mean, content_std = calc_mean_std(content_feat)
|
|
||||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
|
||||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
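For reference, the operation implemented by adaptive_instance_normalization above is the usual AdaIN formula; written out (my summary of the code, not text from the repository):

    AdaIN(x, y) = σ(y) · (x − μ(x)) / σ(x) + μ(y)

where μ and σ are the per-channel mean and standard deviation computed by calc_mean_std, x is the content (reference) feature and y the style (degraded) feature.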
|
|
||||||
|
|
||||||
|
|
||||||
class PositionEmbeddingSine(nn.Module):
|
|
||||||
"""
|
|
||||||
This is a more standard version of the position embedding, very similar to the one
|
|
||||||
used by the Attention is all you need paper, generalized to work on images.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
|
||||||
super().__init__()
|
|
||||||
self.num_pos_feats = num_pos_feats
|
|
||||||
self.temperature = temperature
|
|
||||||
self.normalize = normalize
|
|
||||||
if scale is not None and normalize is False:
|
|
||||||
raise ValueError("normalize should be True if scale is passed")
|
|
||||||
if scale is None:
|
|
||||||
scale = 2 * math.pi
|
|
||||||
self.scale = scale
|
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
|
||||||
if mask is None:
|
|
||||||
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
|
||||||
not_mask = ~mask
|
|
||||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
|
||||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
|
||||||
if self.normalize:
|
|
||||||
eps = 1e-6
|
|
||||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
|
||||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
|
||||||
|
|
||||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
|
||||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
|
||||||
|
|
||||||
pos_x = x_embed[:, :, :, None] / dim_t
|
|
||||||
pos_y = y_embed[:, :, :, None] / dim_t
|
|
||||||
pos_x = torch.stack(
|
|
||||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos_y = torch.stack(
|
|
||||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
|
||||||
).flatten(3)
|
|
||||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
|
||||||
return pos
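The forward pass above computes a standard 2-D sinusoidal embedding; in formula form (my paraphrase of the code, with d = num_pos_feats and T = temperature):

    PE(p, 2i) = sin(p / T^(2i/d)),    PE(p, 2i+1) = cos(p / T^(2i/d))

applied separately to the (optionally normalized) cumulative x and y coordinates and concatenated along the channel dimension.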
|
|
||||||
|
|
||||||
def _get_activation_fn(activation):
|
|
||||||
"""Return an activation function given a string"""
|
|
||||||
if activation == "relu":
|
|
||||||
return F.relu
|
|
||||||
if activation == "gelu":
|
|
||||||
return F.gelu
|
|
||||||
if activation == "glu":
|
|
||||||
return F.glu
|
|
||||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformerSALayer(nn.Module):
|
|
||||||
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
|
||||||
super().__init__()
|
|
||||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
|
||||||
# Implementation of Feedforward model - MLP
|
|
||||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
|
||||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
|
||||||
|
|
||||||
self.norm1 = nn.LayerNorm(embed_dim)
|
|
||||||
self.norm2 = nn.LayerNorm(embed_dim)
|
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
self.activation = _get_activation_fn(activation)
|
|
||||||
|
|
||||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
|
||||||
return tensor if pos is None else tensor + pos
|
|
||||||
|
|
||||||
def forward(self, tgt,
|
|
||||||
tgt_mask: Optional[Tensor] = None,
|
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
|
||||||
query_pos: Optional[Tensor] = None):
|
|
||||||
|
|
||||||
# self attention
|
|
||||||
tgt2 = self.norm1(tgt)
|
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
|
||||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
|
||||||
key_padding_mask=tgt_key_padding_mask)[0]
|
|
||||||
tgt = tgt + self.dropout1(tgt2)
|
|
||||||
|
|
||||||
# ffn
|
|
||||||
tgt2 = self.norm2(tgt)
|
|
||||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
|
||||||
tgt = tgt + self.dropout2(tgt2)
|
|
||||||
return tgt
|
|
||||||
|
|
||||||
class Fuse_sft_block(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
|
||||||
|
|
||||||
self.scale = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
|
||||||
|
|
||||||
self.shift = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
|
||||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
|
||||||
|
|
||||||
def forward(self, enc_feat, dec_feat, w=1):
|
|
||||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
|
||||||
scale = self.scale(enc_feat)
|
|
||||||
shift = self.shift(enc_feat)
|
|
||||||
residual = w * (dec_feat * scale + shift)
|
|
||||||
out = dec_feat + residual
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class CodeFormer(VQAutoEncoder):
|
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
|
||||||
codebook_size=1024, latent_size=256,
|
|
||||||
connect_list=('32', '64', '128', '256'),
|
|
||||||
fix_modules=('quantize', 'generator')):
|
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
|
||||||
|
|
||||||
if fix_modules is not None:
|
|
||||||
for module in fix_modules:
|
|
||||||
for param in getattr(self, module).parameters():
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
self.connect_list = connect_list
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.dim_embd = dim_embd
|
|
||||||
self.dim_mlp = dim_embd*2
|
|
||||||
|
|
||||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
|
||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
|
||||||
|
|
||||||
# transformer
|
|
||||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
|
||||||
for _ in range(self.n_layers)])
|
|
||||||
|
|
||||||
# logits_predict head
|
|
||||||
self.idx_pred_layer = nn.Sequential(
|
|
||||||
nn.LayerNorm(dim_embd),
|
|
||||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
|
||||||
|
|
||||||
self.channels = {
|
|
||||||
'16': 512,
|
|
||||||
'32': 256,
|
|
||||||
'64': 256,
|
|
||||||
'128': 128,
|
|
||||||
'256': 128,
|
|
||||||
'512': 64,
|
|
||||||
}
|
|
||||||
|
|
||||||
# after second residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
|
||||||
# after first residual block for > 16, before attn layer for ==16
|
|
||||||
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
|
||||||
|
|
||||||
# fuse_convs_dict
|
|
||||||
self.fuse_convs_dict = nn.ModuleDict()
|
|
||||||
for f_size in self.connect_list:
|
|
||||||
in_ch = self.channels[f_size]
|
|
||||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
|
||||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.LayerNorm):
|
|
||||||
module.bias.data.zero_()
|
|
||||||
module.weight.data.fill_(1.0)
|
|
||||||
|
|
||||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
|
||||||
# ################### Encoder #####################
|
|
||||||
enc_feat_dict = {}
|
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in out_list:
|
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
|
||||||
|
|
||||||
lq_feat = x
|
|
||||||
# ################# Transformer ###################
|
|
||||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
|
||||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
|
||||||
# BCHW -> BC(HW) -> (HW)BC
|
|
||||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
|
||||||
query_emb = feat_emb
|
|
||||||
# Transformer encoder
|
|
||||||
for layer in self.ft_layers:
|
|
||||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
|
||||||
|
|
||||||
# output logits
|
|
||||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
|
||||||
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
|
||||||
|
|
||||||
if code_only: # for training stage II
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return logits, lq_feat
|
|
||||||
|
|
||||||
# ################# Quantization ###################
|
|
||||||
# if self.training:
|
|
||||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
|
||||||
# # b(hw)c -> bc(hw) -> bchw
|
|
||||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
|
||||||
# ------------
|
|
||||||
soft_one_hot = F.softmax(logits, dim=2)
|
|
||||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
|
||||||
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
|
||||||
# preserve gradients
|
|
||||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
|
||||||
|
|
||||||
if detach_16:
|
|
||||||
quant_feat = quant_feat.detach() # for training stage III
|
|
||||||
if adain:
|
|
||||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
|
||||||
|
|
||||||
# ################## Generator ####################
|
|
||||||
x = quant_feat
|
|
||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
|
||||||
x = block(x)
|
|
||||||
if i in fuse_list: # fuse after i-th block
|
|
||||||
f_size = str(x.shape[-1])
|
|
||||||
if w>0:
|
|
||||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
|
||||||
out = x
|
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
|
||||||
return out, logits, lq_feat
|
|
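For reference, a minimal sketch of how this (now removed) CodeFormer architecture is driven: an aligned face crop normalized to [-1, 1] goes through the network with a fidelity weight w, and the first return value is the restored image tensor. The 512x512 input, the chosen w, and the torch.no_grad() wrapper are illustrative assumptions, not part of the diff.

import torch

# Hypothetical sketch; assumes the deleted codeformer_arch module (and its imports) is still available.
net = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
                 connect_list=['32', '64', '128', '256'])
net.eval()

face = torch.randn(1, 3, 512, 512)  # assumed: aligned face crop, normalized to [-1, 1]
with torch.no_grad():
    restored, logits, lq_feat = net(face, w=0.5, adain=True)
print(restored.shape)  # torch.Size([1, 3, 512, 512])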
@@ -1,435 +0,0 @@
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
|
||||||
|
|
||||||
'''
|
|
||||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
|
||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
|
||||||
|
|
||||||
'''
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
|
|
||||||
def normalize(in_channels):
|
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def swish(x):
|
|
||||||
return x*torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
# Define VQVAE classes
|
|
||||||
class VectorQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, beta):
|
|
||||||
super(VectorQuantizer, self).__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
|
||||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
|
||||||
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
# reshape z -> (batch, height, width, channel) and flatten
|
|
||||||
z = z.permute(0, 2, 3, 1).contiguous()
|
|
||||||
z_flattened = z.view(-1, self.emb_dim)
|
|
||||||
|
|
||||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
|
||||||
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
|
||||||
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
|
||||||
|
|
||||||
mean_distance = torch.mean(d)
|
|
||||||
# find closest encodings
|
|
||||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
|
||||||
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
|
||||||
# [0-1], higher score, higher confidence
|
|
||||||
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
|
||||||
|
|
||||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
|
||||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
|
||||||
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
|
||||||
# compute loss for embedding
|
|
||||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
|
||||||
# preserve gradients
|
|
||||||
z_q = z + (z_q - z).detach()
|
|
||||||
|
|
||||||
# perplexity
|
|
||||||
e_mean = torch.mean(min_encodings, dim=0)
|
|
||||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
|
||||||
# reshape back to match original input shape
|
|
||||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q, loss, {
|
|
||||||
"perplexity": perplexity,
|
|
||||||
"min_encodings": min_encodings,
|
|
||||||
"min_encoding_indices": min_encoding_indices,
|
|
||||||
"min_encoding_scores": min_encoding_scores,
|
|
||||||
"mean_distance": mean_distance
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_codebook_feat(self, indices, shape):
|
|
||||||
# input indices: batch*token_num -> (batch*token_num)*1
|
|
||||||
# shape: batch, height, width, channel
|
|
||||||
indices = indices.view(-1,1)
|
|
||||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
|
||||||
min_encodings.scatter_(1, indices, 1)
|
|
||||||
# get quantized latent vectors
|
|
||||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
|
||||||
|
|
||||||
if shape is not None: # reshape back to match original input shape
|
|
||||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
|
||||||
|
|
||||||
return z_q
|
|
||||||
|
|
||||||
|
|
||||||
class GumbelQuantizer(nn.Module):
|
|
||||||
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
|
||||||
super().__init__()
|
|
||||||
self.codebook_size = codebook_size # number of embeddings
|
|
||||||
self.emb_dim = emb_dim # dimension of embedding
|
|
||||||
self.straight_through = straight_through
|
|
||||||
self.temperature = temp_init
|
|
||||||
self.kl_weight = kl_weight
|
|
||||||
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
|
||||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
|
||||||
|
|
||||||
def forward(self, z):
|
|
||||||
hard = self.straight_through if self.training else True
|
|
||||||
|
|
||||||
logits = self.proj(z)
|
|
||||||
|
|
||||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
|
||||||
|
|
||||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
|
||||||
|
|
||||||
# + kl divergence to the prior loss
|
|
||||||
qy = F.softmax(logits, dim=1)
|
|
||||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
|
||||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
|
||||||
|
|
||||||
return z_q, diff, {
|
|
||||||
"min_encoding_indices": min_encoding_indices
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Downsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
pad = (0, 1, 0, 1)
|
|
||||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
|
||||||
x = self.conv(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Upsample(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class ResBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels=None):
|
|
||||||
super(ResBlock, self).__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels = in_channels if out_channels is None else out_channels
|
|
||||||
self.norm1 = normalize(in_channels)
|
|
||||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
self.norm2 = normalize(out_channels)
|
|
||||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
|
||||||
|
|
||||||
def forward(self, x_in):
|
|
||||||
x = x_in
|
|
||||||
x = self.norm1(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = swish(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
if self.in_channels != self.out_channels:
|
|
||||||
x_in = self.conv_out(x_in)
|
|
||||||
|
|
||||||
return x + x_in
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = normalize(in_channels)
|
|
||||||
self.q = torch.nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0
|
|
||||||
)
|
|
||||||
self.k = torch.nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0
|
|
||||||
)
|
|
||||||
self.v = torch.nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0
|
|
||||||
)
|
|
||||||
self.proj_out = torch.nn.Conv2d(
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q.shape
|
|
||||||
q = q.reshape(b, c, h*w)
|
|
||||||
q = q.permute(0, 2, 1)
|
|
||||||
k = k.reshape(b, c, h*w)
|
|
||||||
w_ = torch.bmm(q, k)
|
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
|
||||||
w_ = F.softmax(w_, dim=2)
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v = v.reshape(b, c, h*w)
|
|
||||||
w_ = w_.permute(0, 2, 1)
|
|
||||||
h_ = torch.bmm(v, w_)
|
|
||||||
h_ = h_.reshape(b, c, h, w)
|
|
||||||
|
|
||||||
h_ = self.proj_out(h_)
|
|
||||||
|
|
||||||
return x+h_
|
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
|
||||||
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.num_resolutions = len(ch_mult)
|
|
||||||
self.num_res_blocks = num_res_blocks
|
|
||||||
self.resolution = resolution
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
|
|
||||||
curr_res = self.resolution
|
|
||||||
in_ch_mult = (1,)+tuple(ch_mult)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial convultion
|
|
||||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
|
||||||
for i in range(self.num_resolutions):
|
|
||||||
block_in_ch = nf * in_ch_mult[i]
|
|
||||||
block_out_ch = nf * ch_mult[i]
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
if curr_res in attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != self.num_resolutions - 1:
|
|
||||||
blocks.append(Downsample(block_in_ch))
|
|
||||||
curr_res = curr_res // 2
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
# normalise and convert to latent size
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
|
||||||
super().__init__()
|
|
||||||
self.nf = nf
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.num_resolutions = len(self.ch_mult)
|
|
||||||
self.num_res_blocks = res_blocks
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions
|
|
||||||
self.in_channels = emb_dim
|
|
||||||
self.out_channels = 3
|
|
||||||
block_in_ch = self.nf * self.ch_mult[-1]
|
|
||||||
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
|
||||||
|
|
||||||
blocks = []
|
|
||||||
# initial conv
|
|
||||||
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
# non-local attention block
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
|
||||||
|
|
||||||
for i in reversed(range(self.num_resolutions)):
|
|
||||||
block_out_ch = self.nf * self.ch_mult[i]
|
|
||||||
|
|
||||||
for _ in range(self.num_res_blocks):
|
|
||||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
|
||||||
block_in_ch = block_out_ch
|
|
||||||
|
|
||||||
if curr_res in self.attn_resolutions:
|
|
||||||
blocks.append(AttnBlock(block_in_ch))
|
|
||||||
|
|
||||||
if i != 0:
|
|
||||||
blocks.append(Upsample(block_in_ch))
|
|
||||||
curr_res = curr_res * 2
|
|
||||||
|
|
||||||
blocks.append(normalize(block_in_ch))
|
|
||||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
for block in self.blocks:
|
|
||||||
x = block(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQAutoEncoder(nn.Module):
|
|
||||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
|
||||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
|
||||||
super().__init__()
|
|
||||||
logger = get_root_logger()
|
|
||||||
self.in_channels = 3
|
|
||||||
self.nf = nf
|
|
||||||
self.n_blocks = res_blocks
|
|
||||||
self.codebook_size = codebook_size
|
|
||||||
self.embed_dim = emb_dim
|
|
||||||
self.ch_mult = ch_mult
|
|
||||||
self.resolution = img_size
|
|
||||||
self.attn_resolutions = attn_resolutions or [16]
|
|
||||||
self.quantizer_type = quantizer
|
|
||||||
self.encoder = Encoder(
|
|
||||||
self.in_channels,
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions
|
|
||||||
)
|
|
||||||
if self.quantizer_type == "nearest":
|
|
||||||
self.beta = beta #0.25
|
|
||||||
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
|
||||||
elif self.quantizer_type == "gumbel":
|
|
||||||
self.gumbel_num_hiddens = emb_dim
|
|
||||||
self.straight_through = gumbel_straight_through
|
|
||||||
self.kl_weight = gumbel_kl_weight
|
|
||||||
self.quantize = GumbelQuantizer(
|
|
||||||
self.codebook_size,
|
|
||||||
self.embed_dim,
|
|
||||||
self.gumbel_num_hiddens,
|
|
||||||
self.straight_through,
|
|
||||||
self.kl_weight
|
|
||||||
)
|
|
||||||
self.generator = Generator(
|
|
||||||
self.nf,
|
|
||||||
self.embed_dim,
|
|
||||||
self.ch_mult,
|
|
||||||
self.n_blocks,
|
|
||||||
self.resolution,
|
|
||||||
self.attn_resolutions
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location='cpu')
|
|
||||||
if 'params_ema' in chkpt:
|
|
||||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
|
||||||
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
|
||||||
elif 'params' in chkpt:
|
|
||||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
|
||||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong params!')
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encoder(x)
|
|
||||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
|
||||||
x = self.generator(quant)
|
|
||||||
return x, codebook_loss, quant_stats
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# patch based discriminator
|
|
||||||
@ARCH_REGISTRY.register()
|
|
||||||
class VQGANDiscriminator(nn.Module):
|
|
||||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
|
||||||
ndf_mult = 1
|
|
||||||
ndf_mult_prev = 1
|
|
||||||
for n in range(1, n_layers): # gradually increase the number of filters
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2 ** n, 8)
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True)
|
|
||||||
]
|
|
||||||
|
|
||||||
ndf_mult_prev = ndf_mult
|
|
||||||
ndf_mult = min(2 ** n_layers, 8)
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
|
||||||
nn.BatchNorm2d(ndf * ndf_mult),
|
|
||||||
nn.LeakyReLU(0.2, True)
|
|
||||||
]
|
|
||||||
|
|
||||||
layers += [
|
|
||||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
|
||||||
self.main = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
if model_path is not None:
|
|
||||||
chkpt = torch.load(model_path, map_location='cpu')
|
|
||||||
if 'params_d' in chkpt:
|
|
||||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
|
||||||
elif 'params' in chkpt:
|
|
||||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
|
||||||
else:
|
|
||||||
raise ValueError('Wrong params!')
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.main(x)
|
|
@@ -1,132 +1,64 @@
import os
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import cv2
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import modules.face_restoration
|
from modules import (
|
||||||
import modules.shared
|
devices,
|
||||||
from modules import shared, devices, modelloader, errors
|
errors,
|
||||||
from modules.paths import models_path
|
face_restoration,
|
||||||
|
face_restoration_utils,
|
||||||
|
modelloader,
|
||||||
|
shared,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# codeformer people made a choice to include modified basicsr library to their project which makes
|
|
||||||
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
|
||||||
# I am making a choice to include some files from codeformer to work around this issue.
|
|
||||||
model_dir = "Codeformer"
|
|
||||||
model_path = os.path.join(models_path, model_dir)
|
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
|
model_download_name = 'codeformer-v0.1.0.pth'
|
||||||
|
|
||||||
codeformer = None
|
# used by e.g. postprocessing_codeformer.py
|
||||||
|
codeformer: face_restoration.FaceRestoration | None = None
|
||||||
|
|
||||||
|
|
||||||
def setup_model(dirname):
|
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||||
os.makedirs(model_path, exist_ok=True)
|
|
||||||
|
|
||||||
path = modules.paths.paths.get("CodeFormer", None)
|
|
||||||
if path is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
|
||||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
|
||||||
from facelib.detection.retinaface import retinaface
|
|
||||||
|
|
||||||
net_class = CodeFormer
|
|
||||||
|
|
||||||
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
|
||||||
def name(self):
|
def name(self):
|
||||||
return "CodeFormer"
|
return "CodeFormer"
|
||||||
|
|
||||||
def __init__(self, dirname):
|
def load_net(self) -> torch.Module:
|
||||||
self.net = None
|
for model_path in modelloader.load_models(
|
||||||
self.face_helper = None
|
model_path=self.model_path,
|
||||||
self.cmd_dir = dirname
|
model_url=model_url,
|
||||||
|
command_path=self.model_path,
|
||||||
|
download_name=model_download_name,
|
||||||
|
ext_filter=['.pth'],
|
||||||
|
):
|
||||||
|
return modelloader.load_spandrel_model(
|
||||||
|
model_path,
|
||||||
|
device=devices.device_codeformer,
|
||||||
|
expected_architecture='CodeFormer',
|
||||||
|
).model
|
||||||
|
raise ValueError("No codeformer model found")
|
||||||
|
|
||||||
def create_models(self):
|
def get_device(self):
|
||||||
|
return devices.device_codeformer
|
||||||
|
|
||||||
if self.net is not None and self.face_helper is not None:
|
def restore(self, np_image, w: float | None = None):
|
||||||
self.net.to(devices.device_codeformer)
|
if w is None:
|
||||||
return self.net, self.face_helper
|
w = getattr(shared.opts, "code_former_weight", 0.5)
|
||||||
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth'])
|
|
||||||
if len(model_paths) != 0:
|
|
||||||
ckpt_path = model_paths[0]
|
|
||||||
else:
|
|
||||||
print("Unable to load codeformer model.")
|
|
||||||
return None, None
|
|
||||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
|
||||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
|
||||||
net.load_state_dict(checkpoint)
|
|
||||||
net.eval()
|
|
||||||
|
|
||||||
if hasattr(retinaface, 'device'):
|
def restore_face(cropped_face_t):
|
||||||
retinaface.device = devices.device_codeformer
|
assert self.net is not None
|
||||||
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
|
return self.net(cropped_face_t, w=w, adain=True)[0]
|
||||||
|
|
||||||
self.net = net
|
return self.restore_with_helper(np_image, restore_face)
|
||||||
self.face_helper = face_helper
|
|
||||||
|
|
||||||
return net, face_helper
|
|
||||||
|
|
||||||
def send_model_to(self, device):
|
|
||||||
self.net.to(device)
|
|
||||||
self.face_helper.face_det.to(device)
|
|
||||||
self.face_helper.face_parse.to(device)
|
|
||||||
|
|
||||||
def restore(self, np_image, w=None):
|
|
||||||
np_image = np_image[:, :, ::-1]
|
|
||||||
|
|
||||||
original_resolution = np_image.shape[0:2]
|
|
||||||
|
|
||||||
self.create_models()
|
|
||||||
if self.net is None or self.face_helper is None:
|
|
||||||
return np_image
|
|
||||||
|
|
||||||
self.send_model_to(devices.device_codeformer)
|
|
||||||
|
|
||||||
self.face_helper.clean_all()
|
|
||||||
self.face_helper.read_image(np_image)
|
|
||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
|
||||||
self.face_helper.align_warp_face()
|
|
||||||
|
|
||||||
for cropped_face in self.face_helper.cropped_faces:
|
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
|
|
||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
|
||||||
del output
|
|
||||||
devices.torch_gc()
|
|
||||||
except Exception:
|
|
||||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
|
||||||
|
|
||||||
restored_face = restored_face.astype('uint8')
|
|
||||||
self.face_helper.add_restored_face(restored_face)
|
|
||||||
|
|
||||||
self.face_helper.get_inverse_affine(None)
|
|
||||||
|
|
||||||
restored_img = self.face_helper.paste_faces_to_input_image()
|
|
||||||
restored_img = restored_img[:, :, ::-1]
|
|
||||||
|
|
||||||
if original_resolution != restored_img.shape[0:2]:
|
|
||||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
|
||||||
|
|
||||||
self.face_helper.clean_all()
|
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
|
||||||
self.send_model_to(devices.cpu)
|
|
||||||
|
|
||||||
return restored_img
|
|
||||||
|
|
||||||
|
def setup_model(dirname: str) -> None:
|
||||||
global codeformer
|
global codeformer
|
||||||
|
try:
|
||||||
codeformer = FaceRestorerCodeFormer(dirname)
|
codeformer = FaceRestorerCodeFormer(dirname)
|
||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report("Error setting up CodeFormer", exc_info=True)
|
errors.report("Error setting up CodeFormer", exc_info=True)
|
||||||
|
|
||||||
# sys.path = stored_sys_path
|
|
||||||
|
79
modules/dat_model.py
Normal file
@@ -0,0 +1,79 @@
import os

from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerDAT(Upscaler):
    def __init__(self, user_path):
        self.name = "DAT"
        self.user_path = user_path
        self.scalers = []
        super().__init__()

        for file in self.find_models(ext_filter=[".pt", ".pth"]):
            name = modelloader.friendly_name(file)
            scaler_data = UpscalerData(name, file, upscaler=self, scale=None)
            self.scalers.append(scaler_data)

        for model in get_dat_models(self):
            if model.name in opts.dat_enabled_models:
                self.scalers.append(model)

    def do_upscale(self, img, path):
        try:
            info = self.load_model(path)
        except Exception:
            errors.report(f"Unable to load DAT model {path}", exc_info=True)
            return img

        model_descriptor = modelloader.load_spandrel_model(
            info.local_data_path,
            device=self.device,
            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
            expected_architecture="DAT",
        )
        return upscale_with_model(
            model_descriptor,
            img,
            tile_size=opts.DAT_tile,
            tile_overlap=opts.DAT_tile_overlap,
        )

    def load_model(self, path):
        for scaler in self.scalers:
            if scaler.data_path == path:
                if scaler.local_data_path.startswith("http"):
                    scaler.local_data_path = modelloader.load_file_from_url(
                        scaler.data_path,
                        model_dir=self.model_download_path,
                    )
                if not os.path.exists(scaler.local_data_path):
                    raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
                return scaler
        raise ValueError(f"Unable to find model info: {path}")


def get_dat_models(scaler):
    return [
        UpscalerData(
            name="DAT x2",
            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
            scale=2,
            upscaler=scaler,
        ),
        UpscalerData(
            name="DAT x3",
            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
            scale=3,
            upscaler=scaler,
        ),
        UpscalerData(
            name="DAT x4",
            path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
            scale=4,
            upscaler=scaler,
        ),
    ]
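A rough usage sketch for the new DAT upscaler follows. It assumes an initialized webui runtime (shared options such as dat_enabled_models, DAT_tile and DAT_tile_overlap must exist), and the input file name is made up for illustration.

from PIL import Image

# Hypothetical sketch; UpscalerDAT only works inside a configured webui process.
upscaler = UpscalerDAT(user_path="models/DAT")                # assumed models directory
dat_x4 = next(s for s in upscaler.scalers if s.name == "DAT x4")

img = Image.open("input.png")                                 # assumed input image
result = upscaler.do_upscale(img, dat_x4.data_path)           # downloads the weights on first use
result.save("input_upscaled.png")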
@@ -23,6 +23,23 @@ def has_mps() -> bool:
         return mac_specific.has_mps
 
 
+def cuda_no_autocast(device_id=None) -> bool:
+    if device_id is None:
+        device_id = get_cuda_device_id()
+    return (
+        torch.cuda.get_device_capability(device_id) == (7, 5)
+        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
+    )
+
+
+def get_cuda_device_id():
+    return (
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        else 0
+    ) or torch.cuda.current_device()
+
+
 def get_cuda_device_string():
     if shared.cmd_opts.device_id is not None:
         return f"cuda:{shared.cmd_opts.device_id}"
@@ -79,8 +96,7 @@ def enable_tf32():
 
         # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
         # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
-        device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
-        if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
+        if cuda_no_autocast():
             torch.backends.cudnn.benchmark = True
 
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -90,6 +106,7 @@
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -98,6 +115,7 @@ device_codeformer: torch.device = None
 dtype: torch.dtype = torch.float16
 dtype_vae: torch.dtype = torch.float16
 dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False
 
 
@@ -110,15 +128,89 @@ def cond_cast_float(input):
 
 
 nv_rng = None
+patch_module_list = [
+    torch.nn.Linear,
+    torch.nn.Conv2d,
+    torch.nn.MultiheadAttention,
+    torch.nn.GroupNorm,
+    torch.nn.LayerNorm,
+]
+
+
+def manual_cast_forward(target_dtype):
+    def forward_wrapper(self, *args, **kwargs):
+        if any(
+            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+            for arg in args
+        ):
+            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+        org_dtype = target_dtype
+        for param in self.parameters():
+            if param.dtype != target_dtype:
+                org_dtype = param.dtype
+                break
+
+        if org_dtype != target_dtype:
+            self.to(target_dtype)
+        result = self.org_forward(*args, **kwargs)
+        if org_dtype != target_dtype:
+            self.to(org_dtype)
+
+        if target_dtype != dtype_inference:
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
+        return result
+    return forward_wrapper
+
+
+@contextlib.contextmanager
+def manual_cast(target_dtype):
+    applied = False
+    for module_type in patch_module_list:
+        if hasattr(module_type, "org_forward"):
+            continue
+        applied = True
+        org_forward = module_type.forward
+        if module_type == torch.nn.MultiheadAttention:
+            module_type.forward = manual_cast_forward(torch.float32)
+        else:
+            module_type.forward = manual_cast_forward(target_dtype)
+        module_type.org_forward = org_forward
+    try:
+        yield None
+    finally:
+        if applied:
+            for module_type in patch_module_list:
+                if hasattr(module_type, "org_forward"):
+                    module_type.forward = module_type.org_forward
+                    delattr(module_type, "org_forward")
+
+
 def autocast(disable=False):
     if disable:
         return contextlib.nullcontext()
 
-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
+    if fp8 and dtype_inference == torch.float32:
+        return manual_cast(dtype)
+
+    if dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
 
+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype)
+
     return torch.autocast("cuda")
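To make the manual_cast mechanism above concrete, here is a self-contained sketch of the same patch-and-restore idea on a single layer type: nn.Linear.forward is temporarily wrapped so tensor inputs are cast to the layer's dtype, then the original forward is restored on exit. This is a simplified illustration rather than the webui implementation; bfloat16 is used because it is broadly supported on CPU.

import contextlib

import torch


@contextlib.contextmanager
def cast_linear_inputs(target_dtype):
    # Patch nn.Linear.forward so inputs are cast to target_dtype, then restore the original.
    org_forward = torch.nn.Linear.forward

    def wrapper(self, x):
        return org_forward(self, x.to(target_dtype))

    torch.nn.Linear.forward = wrapper
    try:
        yield
    finally:
        torch.nn.Linear.forward = org_forward


layer = torch.nn.Linear(4, 2).to(torch.bfloat16)   # low-precision weights
x = torch.randn(1, 4)                              # float32 input would normally mismatch the weights
with cast_linear_inputs(torch.bfloat16):
    y = layer(x)                                   # input is cast before the matmul
print(y.dtype)                                     # torch.bfloat16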
@@ -107,8 +107,8 @@ def check_versions():
     import torch
     import gradio
 
-    expected_torch_version = "2.0.0"
-    expected_xformers_version = "0.0.20"
+    expected_torch_version = "2.1.2"
+    expected_xformers_version = "0.0.23.post1"
     expected_gradio_version = "3.41.2"
 
     if version.parse(torch.__version__) < version.parse(expected_torch_version):
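The check above relies on PEP 440 version ordering rather than plain string comparison, so post-releases such as 0.0.23.post1 compare correctly against 0.0.20. A quick illustration, assuming the packaging library is what provides the version module used here:

from packaging import version

print(version.parse("2.0.1") < version.parse("2.1.2"))           # True, so the "outdated torch" warning branch would trigger
print(version.parse("0.0.23.post1") > version.parse("0.0.20"))   # True, post-release ordering handled correctly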
@@ -1,121 +1,7 @@
import sys
|
from modules import modelloader, devices, errors
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
|
||||||
from modules import modelloader, images, devices
|
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.upscaler_utils import upscale_with_model
|
||||||
|
|
||||||
def mod2normal(state_dict):
|
|
||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
|
||||||
if 'conv_first.weight' in state_dict:
|
|
||||||
crt_net = {}
|
|
||||||
items = list(state_dict)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
|
||||||
|
|
||||||
for k in items.copy():
|
|
||||||
if 'RDB' in k:
|
|
||||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
|
||||||
if '.weight' in k:
|
|
||||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
|
||||||
elif '.bias' in k:
|
|
||||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
|
||||||
crt_net[ori_k] = state_dict[k]
|
|
||||||
items.remove(k)
|
|
||||||
|
|
||||||
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
|
|
||||||
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
|
|
||||||
crt_net['model.3.weight'] = state_dict['upconv1.weight']
|
|
||||||
crt_net['model.3.bias'] = state_dict['upconv1.bias']
|
|
||||||
crt_net['model.6.weight'] = state_dict['upconv2.weight']
|
|
||||||
crt_net['model.6.bias'] = state_dict['upconv2.bias']
|
|
||||||
crt_net['model.8.weight'] = state_dict['HRconv.weight']
|
|
||||||
crt_net['model.8.bias'] = state_dict['HRconv.bias']
|
|
||||||
crt_net['model.10.weight'] = state_dict['conv_last.weight']
|
|
||||||
crt_net['model.10.bias'] = state_dict['conv_last.bias']
|
|
||||||
state_dict = crt_net
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def resrgan2normal(state_dict, nb=23):
|
|
||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
|
||||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
|
||||||
re8x = 0
|
|
||||||
crt_net = {}
|
|
||||||
items = list(state_dict)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
|
||||||
|
|
||||||
for k in items.copy():
|
|
||||||
if "rdb" in k:
|
|
||||||
ori_k = k.replace('body.', 'model.1.sub.')
|
|
||||||
ori_k = ori_k.replace('.rdb', '.RDB')
|
|
||||||
if '.weight' in k:
|
|
||||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
|
||||||
elif '.bias' in k:
|
|
||||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
|
||||||
crt_net[ori_k] = state_dict[k]
|
|
||||||
items.remove(k)
|
|
||||||
|
|
||||||
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
|
|
||||||
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
|
|
||||||
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
|
|
||||||
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
|
|
||||||
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
|
|
||||||
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
|
|
||||||
|
|
||||||
if 'conv_up3.weight' in state_dict:
|
|
||||||
# modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
|
|
||||||
re8x = 3
|
|
||||||
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
|
|
||||||
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
|
|
||||||
|
|
||||||
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
|
|
||||||
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
|
|
||||||
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
|
|
||||||
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
|
|
||||||
|
|
||||||
state_dict = crt_net
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def infer_params(state_dict):
|
|
||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
|
||||||
scale2x = 0
|
|
||||||
scalemin = 6
|
|
||||||
n_uplayer = 0
|
|
||||||
plus = False
|
|
||||||
|
|
||||||
for block in list(state_dict):
|
|
||||||
parts = block.split(".")
|
|
||||||
n_parts = len(parts)
|
|
||||||
if n_parts == 5 and parts[2] == "sub":
|
|
||||||
nb = int(parts[3])
|
|
||||||
elif n_parts == 3:
|
|
||||||
part_num = int(parts[1])
|
|
||||||
if (part_num > scalemin
|
|
||||||
and parts[0] == "model"
|
|
||||||
and parts[2] == "weight"):
|
|
||||||
scale2x += 1
|
|
||||||
if part_num > n_uplayer:
|
|
||||||
n_uplayer = part_num
|
|
||||||
out_nc = state_dict[block].shape[0]
|
|
||||||
if not plus and "conv1x1" in block:
|
|
||||||
plus = True
|
|
||||||
|
|
||||||
nf = state_dict["model.0.weight"].shape[0]
|
|
||||||
in_nc = state_dict["model.0.weight"].shape[1]
|
|
||||||
out_nc = out_nc
|
|
||||||
scale = 2 ** scale2x
|
|
||||||
|
|
||||||
return in_nc, out_nc, nf, nb, plus, scale
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerESRGAN(Upscaler):
|
class UpscalerESRGAN(Upscaler):
|
||||||
@ -143,12 +29,11 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
def do_upscale(self, img, selected_model):
|
def do_upscale(self, img, selected_model):
|
||||||
try:
|
try:
|
||||||
model = self.load_model(selected_model)
|
model = self.load_model(selected_model)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True)
|
||||||
return img
|
return img
|
||||||
model.to(devices.device_esrgan)
|
model.to(devices.device_esrgan)
|
||||||
img = esrgan_upscale(model, img)
|
return esrgan_upscale(model, img)
|
||||||
return img
|
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if path.startswith("http"):
|
if path.startswith("http"):
|
||||||
@ -161,69 +46,17 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
return modelloader.load_spandrel_model(
|
||||||
|
filename,
|
||||||
if "params_ema" in state_dict:
|
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||||
state_dict = state_dict["params_ema"]
|
expected_architecture='ESRGAN',
|
||||||
elif "params" in state_dict:
|
)
|
||||||
state_dict = state_dict["params"]
|
|
||||||
num_conv = 16 if "realesr-animevideov3" in filename else 32
|
|
||||||
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
|
|
||||||
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
|
|
||||||
state_dict = resrgan2normal(state_dict, nb)
|
|
||||||
elif "conv_first.weight" in state_dict:
|
|
||||||
state_dict = mod2normal(state_dict)
|
|
||||||
elif "model.0.weight" not in state_dict:
|
|
||||||
raise Exception("The file is not a recognized ESRGAN model.")
|
|
||||||
|
|
||||||
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
|
|
||||||
|
|
||||||
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def upscale_without_tiling(model, img):
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(img)
|
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
|
||||||
return Image.fromarray(output, 'RGB')
|
|
||||||
|
|
||||||
|
|
||||||
def esrgan_upscale(model, img):
|
def esrgan_upscale(model, img):
|
||||||
if opts.ESRGAN_tile == 0:
|
return upscale_with_model(
|
||||||
return upscale_without_tiling(model, img)
|
model,
|
||||||
|
img,
|
||||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
tile_size=opts.ESRGAN_tile,
|
||||||
newtiles = []
|
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||||
scale_factor = 1
|
)
|
||||||
|
|
||||||
for y, h, row in grid.tiles:
|
|
||||||
newrow = []
|
|
||||||
for tiledata in row:
|
|
||||||
x, w, tile = tiledata
|
|
||||||
|
|
||||||
output = upscale_without_tiling(model, tile)
|
|
||||||
scale_factor = output.width // tile.width
|
|
||||||
|
|
||||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
|
||||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
|
||||||
|
|
||||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
|
||||||
output = images.combine_grid(newgrid)
|
|
||||||
return output
|
|
||||||
|
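The esrgan_upscale rewrite above delegates tiling to the shared upscale_with_model helper with tile_size and tile_overlap taken from the options. As a rough self-contained illustration of the tile/overlap idea (not the webui helper itself), the sketch below splits an image into overlapping tiles, runs a stand-in 2x resize in place of the model, and pastes the results back, with later tiles simply overwriting the overlap; the tile size and overlap values are made up.

from PIL import Image

SCALE, TILE, OVERLAP = 2, 192, 8          # made-up values; the webui reads these from opts


def upscale_tile(tile: Image.Image) -> Image.Image:
    # Stand-in for the real model call: a plain 2x resize.
    return tile.resize((tile.width * SCALE, tile.height * SCALE))


def tiled_upscale(img: Image.Image) -> Image.Image:
    out = Image.new("RGB", (img.width * SCALE, img.height * SCALE))
    step = TILE - OVERLAP
    for y in range(0, img.height, step):
        for x in range(0, img.width, step):
            tile = img.crop((x, y, min(x + TILE, img.width), min(y + TILE, img.height)))
            out.paste(upscale_tile(tile), (x * SCALE, y * SCALE))  # overlap region is overwritten
    return out


result = tiled_upscale(Image.new("RGB", (500, 300), "gray"))
print(result.size)  # (1000, 600)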
@@ -1,465 +0,0 @@
# this file is adapted from https://github.com/victorca25/iNNfer
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
####################
|
|
||||||
# RRDBNet Generator
|
|
||||||
####################
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
|
||||||
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
|
|
||||||
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
|
|
||||||
finalact=None, gaussian_noise=False, plus=False):
|
|
||||||
super(RRDBNet, self).__init__()
|
|
||||||
n_upscale = int(math.log(upscale, 2))
|
|
||||||
if upscale == 3:
|
|
||||||
n_upscale = 1
|
|
||||||
|
|
||||||
self.resrgan_scale = 0
|
|
||||||
if in_nc % 16 == 0:
|
|
||||||
self.resrgan_scale = 1
|
|
||||||
elif in_nc != 4 and in_nc % 4 == 0:
|
|
||||||
self.resrgan_scale = 2
|
|
||||||
|
|
||||||
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
|
||||||
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
|
||||||
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
|
|
||||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
|
|
||||||
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
|
|
||||||
|
|
||||||
if upsample_mode == 'upconv':
|
|
||||||
upsample_block = upconv_block
|
|
||||||
elif upsample_mode == 'pixelshuffle':
|
|
||||||
upsample_block = pixelshuffle_block
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
|
||||||
if upscale == 3:
|
|
||||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
|
||||||
else:
|
|
||||||
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
|
|
||||||
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
|
|
||||||
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
|
|
||||||
|
|
||||||
outact = act(finalact) if finalact else None
|
|
||||||
|
|
||||||
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
|
|
||||||
*upsampler, HR_conv0, HR_conv1, outact)
|
|
||||||
|
|
||||||
def forward(self, x, outm=None):
|
|
||||||
if self.resrgan_scale == 1:
|
|
||||||
feat = pixel_unshuffle(x, scale=4)
|
|
||||||
elif self.resrgan_scale == 2:
|
|
||||||
feat = pixel_unshuffle(x, scale=2)
|
|
||||||
else:
|
|
||||||
feat = x
|
|
||||||
|
|
||||||
return self.model(feat)
|
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
|
||||||
"""
|
|
||||||
Residual in Residual Dense Block
|
|
||||||
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
|
|
||||||
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
|
|
||||||
spectral_norm=False, gaussian_noise=False, plus=False):
|
|
||||||
super(RRDB, self).__init__()
|
|
||||||
# This is for backwards compatibility with existing models
|
|
||||||
if nr == 3:
|
|
||||||
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
|
||||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
|
||||||
gaussian_noise=gaussian_noise, plus=plus)
|
|
||||||
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
|
||||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
|
||||||
gaussian_noise=gaussian_noise, plus=plus)
|
|
||||||
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
|
||||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
|
||||||
gaussian_noise=gaussian_noise, plus=plus)
|
|
||||||
else:
|
|
||||||
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
|
|
||||||
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
|
|
||||||
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
|
|
||||||
self.RDBs = nn.Sequential(*RDB_list)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if hasattr(self, 'RDB1'):
|
|
||||||
out = self.RDB1(x)
|
|
||||||
out = self.RDB2(out)
|
|
||||||
out = self.RDB3(out)
|
|
||||||
else:
|
|
||||||
out = self.RDBs(x)
|
|
||||||
return out * 0.2 + x
|
|
||||||
|
|
||||||
|
|
||||||
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
                 spectral_norm=False, gaussian_noise=False, plus=False):
        super(ResidualDenseBlock_5C, self).__init__()

        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nf, gc) if plus else None

        self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
                                spectral_norm=spectral_norm)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        if self.noise:
            return self.noise(x5.mul(0.2) + x)
        else:
            return x5 * 0.2 + x
####################
# ESRGANplus
####################

class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            self.noise = self.noise.to(x.device)
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
####################
# SRVGGNetCompact
####################

class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.
    This class is copied from https://github.com/xinntao/Real-ESRGAN
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
####################
# Upsampler
####################

class Upsample(nn.Module):
    r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
    The input data is assumed to be of the form
    `minibatch x channels x [optional depth] x [optional height] x width`.
    """

    def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
        super(Upsample, self).__init__()
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.size = size
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)

    def extra_repr(self):
        if self.scale_factor is not None:
            info = f'scale_factor={self.scale_factor}'
        else:
            info = f'size={self.size}'
        info += f', mode={self.mode}'
        return info
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.
    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
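For illustration (not part of the commit), pixel_unshuffle trades spatial resolution for channels: a (b, c, H, W) tensor becomes (b, c*scale^2, H/scale, W/scale), the shape-level inverse of nn.PixelShuffle. A minimal sketch, assuming the pixel_unshuffle defined above is in scope:

# Illustrative sketch only.
import torch

x = torch.randn(1, 3, 8, 8)
y = pixel_unshuffle(x, scale=2)
assert y.shape == (1, 12, 4, 4)   # channels grow by scale**2, spatial dims shrink by scale
z = torch.nn.PixelShuffle(2)(y)   # PixelShuffle undoes the shape change
assert z.shape == x.shape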
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
    """ Upconv layer """
    upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
    upsample = Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                      pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
    return sequential(upsample, conv)


####################
# Basic blocks
####################
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block. (block)
        num_basic_block (int): number of blocks. (n_layers)
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
    """ activation helper """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type in ('leakyrelu', 'lrelu'):
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act_type == 'tanh':  # [-1, 1] range output
        layer = nn.Tanh()
    elif act_type == 'sigmoid':  # [0, 1] range output
        layer = nn.Sigmoid()
    else:
        raise NotImplementedError(f'activation layer [{act_type}] is not found')
    return layer
class Identity(nn.Module):
    def __init__(self, *kwargs):
        super(Identity, self).__init__()

    def forward(self, x, *kwargs):
        return x
def norm(norm_type, nc):
    """ Return a normalization layer """
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
    return layer
def pad(pad_type, padding):
    """ padding layer helper """
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        layer = nn.ZeroPad2d(padding)
    else:
        raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
    return layer
def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding
class ShortcutBlock(nn.Module):
    """ Elementwise sum the output of a submodule to its input """
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
def sequential(*args):
    """ Flatten Sequential. It unwraps nn.Sequential. """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
               spectral_norm=False):
    """ Conv layer with padding, normalization, activation """
    assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    if convtype == 'PartialConv2D':
        from torchvision.ops import PartialConv2d  # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
        c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                          dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'DeformConv2D':
        from torchvision.ops import DeformConv2d  # not tested
        c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                         dilation=dilation, bias=bias, groups=groups)
    elif convtype == 'Conv3D':
        c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)
    else:
        c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=bias, groups=groups)

    if spectral_norm:
        c = nn.utils.spectral_norm(c)

    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
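As a quick sanity check (not part of the diff, just a sketch assuming torch and the classes above are importable), the compact SRVGG network maps an image to an `upscale`-times larger one:

# Illustrative sketch only.
import torch

net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
with torch.no_grad():
    out = net(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 256, 256]): spatial dimensions scaled by `upscale`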
@ -32,7 +32,8 @@ class ExtensionMetadata:
         self.config = configparser.ConfigParser()

         filepath = os.path.join(path, self.filename)
-        if os.path.isfile(filepath):
+        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),
+        # so no need to check whether the file exists beforehand.
         try:
             self.config.read(filepath)
         except Exception:
@ -223,13 +224,16 @@ def list_extensions():

     # check for requirements
     for extension in extensions:
+        if not extension.enabled:
+            continue
+
         for req in extension.metadata.requires:
             required_extension = loaded_extensions.get(req)
             if required_extension is None:
                 errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
                 continue

-            if not extension.enabled:
+            if not required_extension.enabled:
                 errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
                 continue
@ -206,7 +206,7 @@ def parse_prompts(prompts):
     return res, extra_data


-def get_user_metadata(filename):
+def get_user_metadata(filename, lister=None):
     if filename is None:
         return {}

@ -215,7 +215,8 @@ def get_user_metadata(filename):

     metadata = {}
     try:
-        if os.path.isfile(metadata_filename):
+        exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)
+        if exists:
             with open(metadata_filename, "r", encoding="utf8") as file:
                 metadata = json.load(file)
     except Exception as e:
180  modules/face_restoration_utils.py  Normal file
@ -0,0 +1,180 @@
from __future__ import annotations

import logging
import os
from functools import cached_property
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
import torch

from modules import devices, errors, face_restoration, shared

if TYPE_CHECKING:
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper

logger = logging.getLogger(__name__)


def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
    """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
    assert img.shape[2] == 3, "image must be RGB"
    if img.dtype == "float64":
        img = img.astype("float32")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(img.transpose(2, 0, 1)).float()


def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
    """
    Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
    """
    tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
    assert tensor.dim() == 3, "tensor must be RGB"
    img_np = tensor.numpy().transpose(1, 2, 0)
    if img_np.shape[2] == 1:  # gray image, no RGB/BGR required
        return np.squeeze(img_np, axis=2)
    return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)


def create_face_helper(device) -> FaceRestoreHelper:
    from facexlib.detection import retinaface
    from facexlib.utils.face_restoration_helper import FaceRestoreHelper
    if hasattr(retinaface, 'device'):
        retinaface.device = device
    return FaceRestoreHelper(
        upscale_factor=1,
        face_size=512,
        crop_ratio=(1, 1),
        det_model='retinaface_resnet50',
        save_ext='png',
        use_parse=True,
        device=device,
    )


def restore_with_face_helper(
    np_image: np.ndarray,
    face_helper: FaceRestoreHelper,
    restore_face: Callable[[torch.Tensor], torch.Tensor],
) -> np.ndarray:
    """
    Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.

    `restore_face` should take a cropped face image and return a restored face image.
    """
    from torchvision.transforms.functional import normalize
    np_image = np_image[:, :, ::-1]
    original_resolution = np_image.shape[0:2]

    try:
        logger.debug("Detecting faces...")
        face_helper.clean_all()
        face_helper.read_image(np_image)
        face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
        face_helper.align_warp_face()
        logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
        for cropped_face in face_helper.cropped_faces:
            cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)

            try:
                with torch.no_grad():
                    cropped_face_t = restore_face(cropped_face_t)
                devices.torch_gc()
            except Exception:
                errors.report('Failed face-restoration inference', exc_info=True)

            restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
            restored_face = (restored_face * 255.0).astype('uint8')
            face_helper.add_restored_face(restored_face)

        logger.debug("Merging restored faces into image")
        face_helper.get_inverse_affine(None)
        img = face_helper.paste_faces_to_input_image()
        img = img[:, :, ::-1]
        if original_resolution != img.shape[0:2]:
            img = cv2.resize(
                img,
                (0, 0),
                fx=original_resolution[1] / img.shape[1],
                fy=original_resolution[0] / img.shape[0],
                interpolation=cv2.INTER_LINEAR,
            )
        logger.debug("Face restoration complete")
    finally:
        face_helper.clean_all()
    return img


class CommonFaceRestoration(face_restoration.FaceRestoration):
    net: torch.Module | None
    model_url: str
    model_download_name: str

    def __init__(self, model_path: str):
        super().__init__()
        self.net = None
        self.model_path = model_path
        os.makedirs(model_path, exist_ok=True)

    @cached_property
    def face_helper(self) -> FaceRestoreHelper:
        return create_face_helper(self.get_device())

    def send_model_to(self, device):
        if self.net:
            logger.debug("Sending %s to %s", self.net, device)
            self.net.to(device)
        if self.face_helper:
            logger.debug("Sending face helper to %s", device)
            self.face_helper.face_det.to(device)
            self.face_helper.face_parse.to(device)

    def get_device(self):
        raise NotImplementedError("get_device must be implemented by subclasses")

    def load_net(self) -> torch.Module:
        raise NotImplementedError("load_net must be implemented by subclasses")

    def restore_with_helper(
        self,
        np_image: np.ndarray,
        restore_face: Callable[[torch.Tensor], torch.Tensor],
    ) -> np.ndarray:
        try:
            if self.net is None:
                self.net = self.load_net()
        except Exception:
            logger.warning("Unable to load face-restoration model", exc_info=True)
            return np_image

        try:
            self.send_model_to(self.get_device())
            return restore_with_face_helper(np_image, self.face_helper, restore_face)
        finally:
            if shared.opts.face_restoration_unload:
                self.send_model_to(devices.cpu)


def patch_facexlib(dirname: str) -> None:
    import facexlib.detection
    import facexlib.parsing

    det_facex_load_file_from_url = facexlib.detection.load_file_from_url
    par_facex_load_file_from_url = facexlib.parsing.load_file_from_url

    def update_kwargs(kwargs):
        return dict(kwargs, save_dir=dirname, model_dir=None)

    def facex_load_file_from_url(**kwargs):
        return det_facex_load_file_from_url(**update_kwargs(kwargs))

    def facex_load_file_from_url2(**kwargs):
        return par_facex_load_file_from_url(**update_kwargs(kwargs))

    facexlib.detection.load_file_from_url = facex_load_file_from_url
    facexlib.parsing.load_file_from_url = facex_load_file_from_url2
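For orientation, a minimal sketch (not part of the commit) of how a restorer plugs into CommonFaceRestoration: the subclass only supplies a device, a model loader and a per-face callback. The class name, model path and model choice below are hypothetical.

# Hypothetical subclass, for illustration only; it mirrors the pattern FaceRestorerGFPGAN uses below.
from modules import devices, face_restoration_utils, modelloader


class MyRestorer(face_restoration_utils.CommonFaceRestoration):
    def name(self):
        return "MyRestorer"

    def get_device(self):
        return devices.device_gfpgan  # any torch device works here

    def load_net(self):
        # load some spandrel-supported face model; the path is a placeholder
        return modelloader.load_spandrel_model("models/MyRestorer/model.pth", device=self.get_device()).model

    def restore(self, np_image):
        def restore_face(cropped_face_t):
            # receives a cropped, normalized face tensor and must return the restored tensor
            return self.net(cropped_face_t)

        return self.restore_with_helper(np_image, restore_face)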
@ -1,125 +1,71 @@
+from __future__ import annotations
+
+import logging
 import os

-import facexlib
-import gfpgan
+import torch

-import modules.face_restoration
-from modules import paths, shared, devices, modelloader, errors
-
-model_dir = "GFPGAN"
-user_path = None
-model_path = os.path.join(paths.models_path, model_dir)
-model_file_path = None
+from modules import (
+    devices,
+    errors,
+    face_restoration,
+    face_restoration_utils,
+    modelloader,
+    shared,
+)
+
+logger = logging.getLogger(__name__)
+
 model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-have_gfpgan = False
-loaded_gfpgan_model = None
-
-
-def gfpgann():
-    global loaded_gfpgan_model
-    global model_path
-    global model_file_path
-    if loaded_gfpgan_model is not None:
-        loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
-        return loaded_gfpgan_model
-
-    if gfpgan_constructor is None:
-        return None
-
-    models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
-
-    if len(models) == 1 and models[0].startswith("http"):
-        model_file = models[0]
-    elif len(models) != 0:
-        gfp_models = []
-        for item in models:
-            if 'GFPGAN' in os.path.basename(item):
-                gfp_models.append(item)
-        latest_file = max(gfp_models, key=os.path.getctime)
-        model_file = latest_file
-    else:
-        print("Unable to load gfpgan model!")
-        return None
-
-    if hasattr(facexlib.detection.retinaface, 'device'):
-        facexlib.detection.retinaface.device = devices.device_gfpgan
-    model_file_path = model_file
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
-    loaded_gfpgan_model = model
-
-    return model
-
-
-def send_model_to(model, device):
-    model.gfpgan.to(device)
-    model.face_helper.face_det.to(device)
-    model.face_helper.face_parse.to(device)
-
-
-def gfpgan_fix_faces(np_image):
-    model = gfpgann()
-    if model is None:
-        return np_image
-
-    send_model_to(model, devices.device_gfpgan)
-
-    np_image_bgr = np_image[:, :, ::-1]
-    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
-    np_image = gfpgan_output_bgr[:, :, ::-1]
-
-    model.face_helper.clean_all()
-
-    if shared.opts.face_restoration_unload:
-        send_model_to(model, devices.cpu)
-
-    return np_image
-
-
-gfpgan_constructor = None
-
-
-def setup_model(dirname):
-    try:
-        os.makedirs(model_path, exist_ok=True)
-        from gfpgan import GFPGANer
-        from facexlib import detection, parsing  # noqa: F401
-        global user_path
-        global have_gfpgan
-        global gfpgan_constructor
-        global model_file_path
-
-        facexlib_path = model_path
-
-        if dirname is not None:
-            facexlib_path = dirname
-
-        load_file_from_url_orig = gfpgan.utils.load_file_from_url
-        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
-        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
-
-        def my_load_file_from_url(**kwargs):
-            return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
-
-        def facex_load_file_from_url(**kwargs):
-            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        def facex_load_file_from_url2(**kwargs):
-            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
-
-        gfpgan.utils.load_file_from_url = my_load_file_from_url
-        facexlib.detection.load_file_from_url = facex_load_file_from_url
-        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
-        user_path = dirname
-        have_gfpgan = True
-        gfpgan_constructor = GFPGANer
-
-        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
-            def name(self):
-                return "GFPGAN"
-
-            def restore(self, np_image):
-                return gfpgan_fix_faces(np_image)
-
-        shared.face_restorers.append(FaceRestorerGFPGAN())
+model_download_name = "GFPGANv1.4.pth"
+gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
+
+
+class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
+    def name(self):
+        return "GFPGAN"
+
+    def get_device(self):
+        return devices.device_gfpgan
+
+    def load_net(self) -> torch.Module:
+        for model_path in modelloader.load_models(
+            model_path=self.model_path,
+            model_url=model_url,
+            command_path=self.model_path,
+            download_name=model_download_name,
+            ext_filter=['.pth'],
+        ):
+            if 'GFPGAN' in os.path.basename(model_path):
+                model = modelloader.load_spandrel_model(
+                    model_path,
+                    device=self.get_device(),
+                    expected_architecture='GFPGAN',
+                ).model
+                model.different_w = True  # see https://github.com/chaiNNer-org/spandrel/pull/81
+                return model
+        raise ValueError("No GFPGAN model found")
+
+    def restore(self, np_image):
+        def restore_face(cropped_face_t):
+            assert self.net is not None
+            return self.net(cropped_face_t, return_rgb=False)[0]
+
+        return self.restore_with_helper(np_image, restore_face)
+
+
+def gfpgan_fix_faces(np_image):
+    if gfpgan_face_restorer:
+        return gfpgan_face_restorer.restore(np_image)
+    logger.warning("GFPGAN face restorer not set up")
+    return np_image
+
+
+def setup_model(dirname: str) -> None:
+    global gfpgan_face_restorer
+
+    try:
+        face_restoration_utils.patch_facexlib(dirname)
+        gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
+        shared.face_restorers.append(gfpgan_face_restorer)
     except Exception:
         errors.report("Error setting up GFPGAN", exc_info=True)
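For orientation (not part of the diff), the module-level entry points keep their old shape: setup_model() registers the restorer and gfpgan_fix_faces() returns the input unchanged when no model is set up. The models directory below is an assumed example path.

# Illustrative sketch only; mirrors how the webui startup code calls this module.
import numpy as np
from modules import gfpgan_model

gfpgan_model.setup_model("models/GFPGAN")           # assumed path; patches facexlib and registers the restorer
np_image = np.zeros((512, 512, 3), dtype=np.uint8)  # stand-in for a real RGB image
restored = gfpgan_model.gfpgan_fix_faces(np_image)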
43  modules/hat_model.py  Normal file
@ -0,0 +1,43 @@
import os
import sys

from modules import modelloader, devices
from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model


class UpscalerHAT(Upscaler):
    def __init__(self, dirname):
        self.name = "HAT"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        for file in self.find_models(ext_filter=[".pt", ".pth"]):
            name = modelloader.friendly_name(file)
            scale = 4  # TODO: scale might not be 4, but we can't know without loading the model
            scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        try:
            model = self.load_model(selected_model)
        except Exception as e:
            print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)  # TODO: should probably be device_hat
        return upscale_with_model(
            model,
            img,
            tile_size=opts.ESRGAN_tile,  # TODO: should probably be HAT_tile
            tile_overlap=opts.ESRGAN_tile_overlap,  # TODO: should probably be HAT_tile_overlap
        )

    def load_model(self, path: str):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Model file {path} not found")
        return modelloader.load_spandrel_model(
            path,
            device=devices.device_esrgan,  # TODO: should probably be device_hat
            expected_architecture='HAT',
        )
@ -61,12 +61,17 @@ def image_grid(imgs, batch_size=1, rows=None):
     return grid


-Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+class Grid(namedtuple("_Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])):
+    @property
+    def tile_count(self) -> int:
+        """
+        The total number of tiles in the grid.
+        """
+        return sum(len(row[2]) for row in self.tiles)


-def split_grid(image, tile_w=512, tile_h=512, overlap=64):
-    w = image.width
-    h = image.height
+def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
+    w, h = image.size

     non_overlap_width = tile_w - overlap
     non_overlap_height = tile_h - overlap
@ -316,7 +321,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
     return res


-invalid_filename_chars = '<>:"/\\|?*\n\r\t'
+invalid_filename_chars = '#<>:"/\\|?*\n\r\t'
 invalid_filename_prefix = ' '
 invalid_filename_postfix = ' .'
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
@ -791,3 +796,4 @@ def flatten(img, bgcolor):
         img = background

     return img.convert('RGB')
+
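A small illustration (not part of the diff) of the new Grid.tile_count property; the image size and tile settings below are arbitrary example values:

# Illustrative sketch only, assuming modules.images is importable as in the webui codebase.
from PIL import Image
from modules import images

grid = images.split_grid(Image.new("RGB", (200, 200)), tile_w=96, tile_h=96, overlap=32)
print(grid.tile_count)  # total number of tiles (rows times tiles per row) for the 200x200 input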
@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr
 import gradio as gr

 from modules import images as imgutil
-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 from modules.sd_models import get_closet_checkpoint_match
@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         if state.skipped:
             state.skipped = False

-        if state.interrupted:
+        if state.interrupted or state.stopping_generation:
             break

         try:
@ -222,9 +222,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
     if shared.opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

-    if mask:
-        p.extra_generation_params["Mask blur"] = mask_blur
-
     with closing(p):
         if is_batch:
             assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
@ -4,12 +4,15 @@ import io
 import json
 import os
 import re
+import sys

 import gradio as gr
 from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
 from PIL import Image

+sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name
+
 re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
@ -28,6 +31,19 @@ class ParamBinding:
         self.paste_field_names = paste_field_names or []


+class PasteField(tuple):
+    def __new__(cls, component, target, *, api=None):
+        return super().__new__(cls, (component, target))
+
+    def __init__(self, component, target, *, api=None):
+        super().__init__()
+
+        self.api = api
+        self.component = component
+        self.label = target if isinstance(target, str) else None
+        self.function = target if callable(target) else None
+
+
 paste_fields: dict[str, dict] = {}
 registered_param_bindings: list[ParamBinding] = []

@ -84,6 +100,12 @@ def image_from_url_text(filedata):


 def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
+
+    if fields:
+        for i in range(len(fields)):
+            if not isinstance(fields[i], PasteField):
+                fields[i] = PasteField(*fields[i])
+
     paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

     # backwards compatibility for existing extensions
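A small illustration (not part of the diff) of what the new PasteField wrapper carries; the component below is a stand-in object, since a real Gradio component is not needed to show the bookkeeping:

# Illustrative sketch only.
component = object()   # normally a Gradio component such as a textbox
field = PasteField(component, "Prompt", api="prompt")

assert tuple(field) == (component, "Prompt")   # still unpacks like the old (component, label) tuple
assert field.label == "Prompt" and field.api == "prompt"
assert field.function is None                  # a callable target would populate .function instead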
@ -208,7 +230,7 @@ def restore_old_hires_fix_params(res):
     res['Hires resize-2'] = height


-def parse_generation_parameters(x: str):
+def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
     """parses generation parameters string, the one you see in text field under the picture in UI:
 ```
 girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
@ -218,6 +240,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model

     returns a dict with field values
     """
+    if skip_fields is None:
+        skip_fields = shared.opts.infotext_skip_pasting

     res = {}

@ -290,6 +314,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires negative prompt" not in res:
         res["Hires negative prompt"] = ""

+    if "Mask mode" not in res:
+        res["Mask mode"] = "Inpaint masked"
+
+    if "Masked content" not in res:
+        res["Masked content"] = 'original'
+
+    if "Inpaint area" not in res:
+        res["Inpaint area"] = "Whole picture"
+
+    if "Masked area padding" not in res:
+        res["Masked area padding"] = 32
+
     restore_old_hires_fix_params(res)

     # Missing RNG means the default was set, which is GPU RNG
@ -314,8 +350,16 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "VAE Decoder" not in res:
         res["VAE Decoder"] = "Full"

-    skip = set(shared.opts.infotext_skip_pasting)
-    res = {k: v for k, v in res.items() if k not in skip}
+    if "FP8 weight" not in res:
+        res["FP8 weight"] = "Disable"
+
+    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
+        res["Cache FP16 weight for LoRA"] = False
+
+    infotext_versions.backcompat(res)
+
+    for key in skip_fields:
+        res.pop(key, None)

     return res

@ -365,13 +409,57 @@ def create_override_settings_dict(text_pairs):
     return res


+def get_override_settings(params, *, skip_fields=None):
+    """Returns a list of settings overrides from the infotext parameters dictionary.
+
+    This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
+    a list of tuples containing the parameter name, setting name, and new value cast to correct type.
+
+    It checks for conditions before adding an override:
+    - ignores settings that match the current value
+    - ignores parameter keys present in skip_fields argument.
+
+    Example input:
+        {"Clip skip": "2"}
+
+    Example output:
+        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
+    """
+
+    res = []
+
+    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
+    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
+        if param_name in (skip_fields or {}):
+            continue
+
+        v = params.get(param_name, None)
+        if v is None:
+            continue
+
+        if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
+            continue
+
+        v = shared.opts.cast_value(setting_name, v)
+        current_value = getattr(shared.opts, setting_name, None)
+
+        if v == current_value:
+            continue
+
+        res.append((param_name, setting_name, v))
+
+    return res
+
+
 def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
     def paste_func(prompt):
         if not prompt and not shared.cmd_opts.hide_ui_dir_config:
             filename = os.path.join(data_path, "params.txt")
-            if os.path.exists(filename):
+            try:
                 with open(filename, "r", encoding="utf8") as file:
                     prompt = file.read()
+            except OSError:
+                pass

         params = parse_generation_parameters(prompt)
         script_callbacks.infotext_pasted_callback(prompt, params)
@ -393,6 +481,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,

                 if valtype == bool and v == "False":
                     val = False
+                elif valtype == int:
+                    val = float(v)
                 else:
                     val = valtype(v)

@ -406,29 +496,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
     already_handled_fields = {key: 1 for _, key in paste_fields}

     def paste_settings(params):
-        vals = {}
-
-        mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
-        for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
-            if param_name in already_handled_fields:
-                continue
-
-            v = params.get(param_name, None)
-            if v is None:
-                continue
-
-            if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
-                continue
-
-            v = shared.opts.cast_value(setting_name, v)
-            current_value = getattr(shared.opts, setting_name, None)
-
-            if v == current_value:
-                continue
-
-            vals[param_name] = v
-
-        vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
+        vals = get_override_settings(params, skip_fields=already_handled_fields)
+
+        vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]

         return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
42  modules/infotext_versions.py  Normal file
@ -0,0 +1,42 @@
from modules import shared
from packaging import version
import re


v160 = version.parse("1.6.0")
v170_tsnr = version.parse("v1.7.0-225")


def parse_version(text):
    if text is None:
        return None

    m = re.match(r'([^-]+-[^-]+)-.*', text)
    if m:
        text = m.group(1)

    try:
        return version.parse(text)
    except Exception:
        return None


def backcompat(d):
    """Checks infotext Version field, and enables backwards compatibility options according to it."""

    if not shared.opts.auto_backcompat:
        return

    ver = parse_version(d.get("Version"))
    if ver is None:
        return

    if ver < v160 and '[' in d.get('Prompt', ''):
        d["Old prompt editing timelines"] = True

    if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
        d["Pad conds v0"] = True

    if ver < v170_tsnr:
        d["Downcast alphas_cumprod"] = True
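For illustration (not part of the diff), parse_version trims a git-describe style suffix so that, say, "v1.7.0-225-gabcdef0" parses the same as "v1.7.0-225"; the exact strings below are made-up examples:

# Illustrative sketch only; relies on parse_version above and the packaging library.
from packaging import version

assert parse_version(None) is None
assert parse_version("1.6.0") == version.parse("1.6.0")
assert parse_version("v1.7.0-225-gabcdef0") == version.parse("v1.7.0-225")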
@ -1,5 +1,6 @@
 import importlib
 import logging
+import os
 import sys
 import warnings
 from threading import Thread
@ -18,6 +19,7 @@ def imports():
     warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
     warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

+    os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
     import gradio  # noqa: F401
     startup_timer.record("import gradio")

@ -54,9 +56,6 @@ def initialize():
     initialize_util.configure_sigint_handler()
     initialize_util.configure_opts_onchange()

-    from modules import modelloader
-    modelloader.cleanup_models()
-
     from modules import sd_models
     sd_models.setup_model()
     startup_timer.record("setup SD model")
@ -177,6 +177,8 @@ def configure_opts_onchange():
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
+    shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
+    shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
     startup_timer.record("opts onchange")

@ -10,14 +10,14 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

-from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils

 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'

 Category = namedtuple("Category", ["name", "topn", "items"])

-re_topn = re.compile(r"\.top(\d+)\.")
+re_topn = re.compile(r"\.top(\d+)$")

 def category_types():
     return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
@ -131,7 +131,7 @@ class InterrogateModels:

         self.clip_model = self.clip_model.to(devices.device_interrogate)

-        self.dtype = next(self.clip_model.parameters()).dtype
+        self.dtype = torch_utils.get_param(self.clip_model).dtype

     def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:
@ -27,8 +27,7 @@ dir_repos = "repositories"
|
|||||||
# Whether to default to printing command output
|
# Whether to default to printing command output
|
||||||
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||||
|
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
|
||||||
|
|
||||||
|
|
||||||
def check_python_version():
|
def check_python_version():
|
||||||
@ -245,11 +244,13 @@ def list_extensions(settings_file):
|
|||||||
settings = {}
|
settings = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.isfile(settings_file):
|
|
||||||
with open(settings_file, "r", encoding="utf8") as file:
|
with open(settings_file, "r", encoding="utf8") as file:
|
||||||
settings = json.load(file)
|
settings = json.load(file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report("Could not load settings", exc_info=True)
|
errors.report(f'\nCould not load settings\nThe config file "{settings_file}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
|
||||||
|
os.replace(settings_file, os.path.join(script_path, "tmp", "config.json"))
|
||||||
|
|
||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
@@ -314,8 +315,8 @@ def requirements_met(requirements_file):


 def prepare_environment():
-    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
-    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
+    torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
+    torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}")
     if args.use_ipex:
         if platform.system() == "Windows":
             # The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
@@ -338,20 +339,20 @@ def prepare_environment():
             torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
     requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

-    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
+    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

+    assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git")
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
     stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
-    codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

+    assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917")
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
-    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

     try:
@@ -405,18 +406,14 @@ def prepare_environment():

     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

+    git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash)
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
     git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
-    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

     startup_timer.record("clone repositores")

-    if not is_installed("lpips"):
-        run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
-        startup_timer.record("install CodeFormer requirements")
-
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
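Note (illustrative, not part of the diff): every default above is only a fallback for an os.environ.get() lookup, so a wrapper can pin different builds without editing the file. A hedged sketch of such a launcher; the CPU wheel index and the --skip-torch-cuda-test flag are example values:

    import os
    import subprocess

    # Override the torch defaults that prepare_environment() reads via environment variables.
    env = dict(os.environ)
    env["TORCH_INDEX_URL"] = "https://download.pytorch.org/whl/cpu"
    env["TORCH_COMMAND"] = "pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url https://download.pytorch.org/whl/cpu"

    subprocess.run(["python", "launch.py", "--skip-torch-cuda-test"], env=env, check=True)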
@@ -1,41 +1,58 @@
-import os
 import logging
+import os

 try:
-    from tqdm.auto import tqdm
+    from tqdm import tqdm

     class TqdmLoggingHandler(logging.Handler):
-        def __init__(self, level=logging.INFO):
-            super().__init__(level)
+        def __init__(self, fallback_handler: logging.Handler):
+            super().__init__()
+            self.fallback_handler = fallback_handler

         def emit(self, record):
             try:
-                msg = self.format(record)
-                tqdm.write(msg)
-                self.flush()
+                # If there are active tqdm progress bars,
+                # attempt to not interfere with them.
+                if tqdm._instances:
+                    tqdm.write(self.format(record))
+                else:
+                    self.fallback_handler.emit(record)
             except Exception:
-                self.handleError(record)
+                self.fallback_handler.emit(record)

-    TQDM_IMPORTED = True
 except ImportError:
-    # tqdm does not exist before first launch
-    # I will import once the UI finishes seting up the enviroment and reloads.
-    TQDM_IMPORTED = False
+    TqdmLoggingHandler = None


 def setup_logging(loglevel):
     if loglevel is None:
         loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")

-    loghandlers = []
-
-    if TQDM_IMPORTED:
-        loghandlers.append(TqdmLoggingHandler())
-
-    if loglevel:
-        log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
-        logging.basicConfig(
-            level=log_level,
-            format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
-            datefmt='%Y-%m-%d %H:%M:%S',
-            handlers=loghandlers
-        )
+    if not loglevel:
+        return
+
+    if logging.root.handlers:
+        # Already configured, do not interfere
+        return
+
+    formatter = logging.Formatter(
+        '%(asctime)s %(levelname)s [%(name)s] %(message)s',
+        '%Y-%m-%d %H:%M:%S',
+    )
+
+    if os.environ.get("SD_WEBUI_RICH_LOG"):
+        from rich.logging import RichHandler
+        handler = RichHandler()
+    else:
+        handler = logging.StreamHandler()
+        handler.setFormatter(formatter)
+
+    if TqdmLoggingHandler:
+        handler = TqdmLoggingHandler(handler)
+        handler.setFormatter(formatter)
+
+    log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
+    logging.root.setLevel(log_level)
+    logging.root.addHandler(handler)
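Note (illustrative, not part of the diff): the rewritten setup_logging() configures the root logger directly instead of going through logging.basicConfig(), and backs off if something else already installed handlers. A standalone sketch of the resulting handler chain, with the RichHandler and tqdm wrapping omitted:

    import logging
    import os

    # Level comes from SD_WEBUI_LOG_LEVEL; nothing is touched if logging is already configured.
    loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL", "INFO")

    if not logging.root.handlers:
        formatter = logging.Formatter(
            '%(asctime)s %(levelname)s [%(name)s] %(message)s',
            '%Y-%m-%d %H:%M:%S',
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)

        logging.root.setLevel(getattr(logging, loglevel.upper(), logging.INFO))
        logging.root.addHandler(handler)

    logging.getLogger("sketch").info("logging configured")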
@@ -3,40 +3,15 @@ from PIL import Image, ImageFilter, ImageOps


 def get_crop_region(mask, pad=0):
     """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
-    For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
-
-    h, w = mask.shape
-
-    crop_left = 0
-    for i in range(w):
-        if not (mask[:, i] == 0).all():
-            break
-        crop_left += 1
-
-    crop_right = 0
-    for i in reversed(range(w)):
-        if not (mask[:, i] == 0).all():
-            break
-        crop_right += 1
-
-    crop_top = 0
-    for i in range(h):
-        if not (mask[i] == 0).all():
-            break
-        crop_top += 1
-
-    crop_bottom = 0
-    for i in reversed(range(h)):
-        if not (mask[i] == 0).all():
-            break
-        crop_bottom += 1
-
-    return (
-        int(max(crop_left-pad, 0)),
-        int(max(crop_top-pad, 0)),
-        int(min(w - crop_right + pad, w)),
-        int(min(h - crop_bottom + pad, h))
-    )
+    For example, if a user has painted the top-right part of a 512x512 image, the result may be (256, 0, 512, 256)"""
+    mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
+    box = mask_img.getbbox()
+    if box:
+        x1, y1, x2, y2 = box
+    else:  # when no box is found
+        x1, y1 = mask_img.size
+        x2 = y2 = 0
+    return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])


 def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
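Note (illustrative, not part of the diff): the rewrite replaces the four edge-scanning loops with a single PIL getbbox() call. A self-contained check of the same logic on a toy mask; the helper below is a local restatement, not an import:

    import numpy as np
    from PIL import Image


    def get_crop_region(mask, pad=0):
        # Same approach as the new code: bounding box of non-zero pixels, padded and clamped.
        mask_img = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)
        box = mask_img.getbbox()
        if box:
            x1, y1, x2, y2 = box
        else:  # empty mask
            x1, y1 = mask_img.size
            x2 = y2 = 0
        return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask_img.size[0]), min(y2 + pad, mask_img.size[1])


    mask = np.zeros((512, 512), dtype=np.uint8)
    mask[0:256, 256:512] = 255          # user painted the top-right quarter
    print(get_crop_region(mask))        # (256, 0, 512, 256)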
@@ -1,13 +1,20 @@
 from __future__ import annotations

-import os
-import shutil
 import importlib
+import logging
+import os
+from typing import TYPE_CHECKING
 from urllib.parse import urlparse

+import torch
+
 from modules import shared
 from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
-from modules.paths import script_path, models_path
+
+if TYPE_CHECKING:
+    import spandrel
+
+logger = logging.getLogger(__name__)


 def load_file_from_url(
@@ -90,54 +97,6 @@ def friendly_name(file: str):
     return model_name


-def cleanup_models():
-    # This code could probably be more efficient if we used a tuple list or something to store the src/destinations
-    # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
-    # somehow auto-register and just do these things...
-    root_path = script_path
-    src_path = models_path
-    dest_path = os.path.join(models_path, "Stable-diffusion")
-    move_files(src_path, dest_path, ".ckpt")
-    move_files(src_path, dest_path, ".safetensors")
-    src_path = os.path.join(root_path, "ESRGAN")
-    dest_path = os.path.join(models_path, "ESRGAN")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(models_path, "BSRGAN")
-    dest_path = os.path.join(models_path, "ESRGAN")
-    move_files(src_path, dest_path, ".pth")
-    src_path = os.path.join(root_path, "gfpgan")
-    dest_path = os.path.join(models_path, "GFPGAN")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(root_path, "SwinIR")
-    dest_path = os.path.join(models_path, "SwinIR")
-    move_files(src_path, dest_path)
-    src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
-    dest_path = os.path.join(models_path, "LDSR")
-    move_files(src_path, dest_path)
-
-
-def move_files(src_path: str, dest_path: str, ext_filter: str = None):
-    try:
-        os.makedirs(dest_path, exist_ok=True)
-        if os.path.exists(src_path):
-            for file in os.listdir(src_path):
-                fullpath = os.path.join(src_path, file)
-                if os.path.isfile(fullpath):
-                    if ext_filter is not None:
-                        if ext_filter not in file:
-                            continue
-                    print(f"Moving {file} from {src_path} to {dest_path}.")
-                    try:
-                        shutil.move(fullpath, dest_path)
-                    except Exception:
-                        pass
-            if len(os.listdir(src_path)) == 0:
-                print(f"Removing empty folder: {src_path}")
-                shutil.rmtree(src_path, True)
-    except Exception:
-        pass
-
-
 def load_upscalers():
     # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
     # so we'll try to import any _model.py files before looking in __subclasses__
@@ -177,3 +136,34 @@ def load_upscalers():
         # Special case for UpscalerNone keeps it at the beginning of the list.
         key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
     )
+
+
+def load_spandrel_model(
+    path: str | os.PathLike,
+    *,
+    device: str | torch.device | None,
+    prefer_half: bool = False,
+    dtype: str | torch.dtype | None = None,
+    expected_architecture: str | None = None,
+) -> spandrel.ModelDescriptor:
+    import spandrel
+    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
+    if expected_architecture and model_descriptor.architecture != expected_architecture:
+        logger.warning(
+            f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
+        )
+    half = False
+    if prefer_half:
+        if model_descriptor.supports_half:
+            model_descriptor.model.half()
+            half = True
+        else:
+            logger.info("Model %s does not support half precision, ignoring --half", path)
+    if dtype:
+        model_descriptor.model.to(dtype=dtype)
+    model_descriptor.model.eval()
+    logger.debug(
+        "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
+        model_descriptor, path, device, half, dtype,
+    )
+    return model_descriptor
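Note (illustrative, not part of the diff): a usage sketch for the new load_spandrel_model() helper. The module path modules.modelloader and the checkpoint filename are assumptions, and the call requires the spandrel package to be installed:

    import torch

    from modules import modelloader  # assumed location of the helper above

    # Hypothetical ESRGAN-family checkpoint; any file spandrel can identify works.
    model = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        prefer_half=True,                 # downcast only if the architecture supports it
        expected_architecture="ESRGAN",   # mismatches are logged, not fatal
    )
    print(model.architecture)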
@@ -1,3 +1,4 @@
+import os
 import json
 import sys
 from dataclasses import dataclass
@@ -6,6 +7,7 @@ import gradio as gr

 from modules import errors
 from modules.shared_cmd_options import cmd_opts
+from modules.paths_internal import script_path


 class OptionInfo:
@@ -91,18 +93,35 @@ class Options:

         if self.data is not None:
             if key in self.data or key in self.data_labels:
+
+                # Check that settings aren't globally frozen
                 assert not cmd_opts.freeze_settings, "changing settings is disabled"

+                # Get the info related to the setting being changed
                 info = self.data_labels.get(key, None)
                 if info.do_not_save:
                     return

+                # Restrict component arguments
                 comp_args = info.component_args if info else None
                 if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted")

+                # Check that this section isn't frozen
+                if cmd_opts.freeze_settings_in_sections is not None:
+                    frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(',')))  # Trim whitespace from section names
+                    section_key = info.section[0]
+                    section_name = info.section[1]
+                    assert section_key not in frozen_sections, f"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections"
+
+                # Check that this specific setting isn't frozen
+                if cmd_opts.freeze_specific_settings is not None:
+                    frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(',')))  # Trim whitespace from setting keys
+                    assert key not in frozen_keys, f"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings"
+
+                # Check shorthand option which disables editing options in "saving-paths"
                 if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
-                    raise RuntimeError(f"not possible to set {key} because it is restricted")
+                    raise RuntimeError(f"not possible to set '{key}' because it is restricted with --hide_ui_dir_config")

                 self.data[key] = value
                 return
@@ -176,9 +195,15 @@ class Options:
         return type_x == type_y

     def load(self, filename):
-        with open(filename, "r", encoding="utf8") as file:
-            self.data = json.load(file)
+        try:
+            with open(filename, "r", encoding="utf8") as file:
+                self.data = json.load(file)
+        except FileNotFoundError:
+            self.data = {}
+        except Exception:
+            errors.report(f'\nCould not load settings\nThe config file "{filename}" is likely corrupted\nIt has been moved to the "tmp/config.json"\nReverting config to default\n\n''', exc_info=True)
+            os.replace(filename, os.path.join(script_path, "tmp", "config.json"))
+            self.data = {}

         # 1.6.0 VAE defaults
         if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
             self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
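Note (illustrative, not part of the diff): --freeze-settings-in-sections and --freeze-specific-settings are parsed as comma-separated lists and checked before a value is stored. A minimal sketch of the two checks, with made-up option and section names standing in for cmd_opts and OptionInfo.section:

    # Example command-line values; the real code reads them from cmd_opts.
    freeze_settings_in_sections = "saving-images, training"   # --freeze-settings-in-sections
    freeze_specific_settings = "sd_model_checkpoint"          # --freeze-specific-settings

    frozen_sections = list(map(str.strip, freeze_settings_in_sections.split(',')))
    frozen_keys = list(map(str.strip, freeze_specific_settings.split(',')))


    def check_can_set(key, section_key):
        assert section_key not in frozen_sections, f"not possible to set '{key}': section '{section_key}' is frozen"
        assert key not in frozen_keys, f"not possible to set '{key}': setting is frozen"


    check_can_set("CLIP_stop_at_last_layers", "sd")    # passes: neither the key nor its section is frozen

    try:
        check_can_set("outdir_samples", "saving-images")
    except AssertionError as e:
        print(e)                                        # section 'saving-images' is frozen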
@@ -38,7 +38,6 @@ mute_sdxl_imports()
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
     (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
-    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
 ]
@@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
 extensions_dir = os.path.join(data_path, "extensions")
 extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
 config_states_dir = os.path.join(script_path, "config_states")
+default_output_dir = os.path.join(data_path, "output")

 roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
@@ -2,7 +2,7 @@ import os

 from PIL import Image

-from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
+from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils
 from modules.shared import opts


@@ -62,8 +62,6 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         else:
             image_data = image_placeholder

-        shared.state.assign_current_image(image_data)
-
         parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters
@@ -86,22 +84,25 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
             basename = ''
             forced_filename = None

-        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+        infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])

         if opts.enable_pnginfo:
             pp.image.info = existing_pnginfo
             pp.image.info["postprocessing"] = infotext

+        shared.state.assign_current_image(pp.image)
+
         if save_output:
             fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)

             if pp.caption:
                 caption_filename = os.path.splitext(fullfn)[0] + ".txt"
-                if os.path.isfile(caption_filename):
+                existing_caption = ""
+                try:
                     with open(caption_filename, encoding="utf8") as file:
                         existing_caption = file.read().strip()
-                else:
-                    existing_caption = ""
+                except FileNotFoundError:
+                    pass

                 action = shared.opts.postprocessing_existing_caption_action
                 if action == 'Prepend' and existing_caption:
@@ -16,7 +16,7 @@ from skimage import exposure
 from typing import Any

 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
 from modules.rng import slerp  # noqa: F401
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -62,28 +62,35 @@ def apply_color_correction(correction, original_image):
     return image.convert('RGB')


-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
-        return image
-
-    overlay = overlays[index]
-
-    if paste_loc is not None:
-        x, y, w, h = paste_loc
-        base_image = Image.new('RGBA', (overlay.width, overlay.height))
-        image = images.resize_image(1, image, w, h)
-        base_image.paste(image, (x, y))
-        image = base_image
+def uncrop(image, dest_size, paste_loc):
+    x, y, w, h = paste_loc
+    base_image = Image.new('RGBA', dest_size)
+    image = images.resize_image(1, image, w, h)
+    base_image.paste(image, (x, y))
+    image = base_image
+
+    return image
+
+
+def apply_overlay(image, paste_loc, overlay):
+    if overlay is None:
+        return image
+
+    if paste_loc is not None:
+        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

     image = image.convert('RGBA')
     image.alpha_composite(overlay)
     image = image.convert('RGB')

     return image

-def create_binary_mask(image):
+
+def create_binary_mask(image, round=True):
     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        if round:
+            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        else:
+            image = image.split()[-1].convert("L")
     else:
         image = image.convert('L')
     return image
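Note (illustrative, not part of the diff): the refactor splits pasting the processed crop back onto a full-size canvas (uncrop) from compositing the inpainting overlay (apply_overlay), so the un-cropped image can be reused later in the pipeline. A tiny self-contained sketch of the paste-back step, using PIL's resize in place of the webui's images.resize_image:

    from PIL import Image


    def uncrop(image, dest_size, paste_loc):
        # Paste a processed crop back onto a transparent full-size canvas.
        x, y, w, h = paste_loc
        base = Image.new('RGBA', dest_size)
        base.paste(image.resize((w, h)), (x, y))
        return base


    crop = Image.new('RGB', (128, 128), 'red')           # stand-in for the denoised crop
    full = uncrop(crop, (512, 512), (256, 0, 256, 256))  # back into the top-right quarter
    print(full.size, full.getpixel((300, 50)))           # (512, 512) (255, 0, 0, 255)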
@@ -106,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
         return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

     else:
+        sd = sd_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None:
+            if diffusion_model_input.shape[1] == 9:
+                # The "masked-image" in this case will just be all 0.5 since the entire image is masked.
+                image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
+                image_conditioning = images_tensor_to_samples(image_conditioning,
+                                                              approximation_indexes.get(opts.sd_vae_encode_method))
+
+                # Add the fake full 1s mask to the first dimension.
+                image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+                image_conditioning = image_conditioning.to(x.dtype)
+
+                return image_conditioning
+
         # Dummy zero conditioning if we're not using inpainting or unclip models.
         # Still takes up a bit of memory, but no encoder call.
         # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
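Note (illustrative, not part of the diff): Stable Diffusion inpainting checkpoints use a UNet whose first convolution takes 9 input channels (4 noise latent + 4 masked-image latent + 1 mask), which is why the new branch inspects the shape of diffusion_model.input_blocks.0.0.weight. A toy sketch of the conditioning layout it builds for plain txt2img on such a model; the shapes are illustrative:

    import torch

    batch, latent_h, latent_w = 1, 64, 64

    # Stands in for the VAE-encoded all-0.5 "masked image".
    masked_image_latent = torch.zeros(batch, 4, latent_h, latent_w)

    # Pad a constant 1.0 channel in front: "everything is masked".
    image_conditioning = torch.nn.functional.pad(masked_image_latent, (0, 0, 0, 0, 1, 0), value=1.0)
    print(image_conditioning.shape)   # torch.Size([1, 5, 64, 64]); plus the 4 noise latents = 9 channels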
@@ -157,6 +179,7 @@ class StableDiffusionProcessing:
     token_merging_ratio = 0
     token_merging_ratio_hr = 0
     disable_extra_networks: bool = False
+    firstpass_image: Image = None

     scripts_value: scripts.ScriptRunner = field(default=None, init=False)
     script_args_value: list = field(default=None, init=False)
@@ -308,7 +331,7 @@ class StableDiffusionProcessing:
             c_adm = torch.cat((c_adm, noise_level_emb), 1)
         return c_adm

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         self.is_using_inpainting_conditioning = True

         # Handle the different mask inputs
@@ -320,8 +343,10 @@ class StableDiffusionProcessing:
             conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
             conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

-            # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-            conditioning_mask = torch.round(conditioning_mask)
+            if round_image_mask:
+                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+
         else:
             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

@@ -345,7 +370,7 @@ class StableDiffusionProcessing:

         return image_conditioning

-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         source_image = devices.cond_cast_float(source_image)

         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,11 +382,17 @@ class StableDiffusionProcessing:
             return self.edit_image_conditioning(source_image)

         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)

+        sd = self.sampler.model_wrap.inner_model.model.state_dict()
+        diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+        if diffusion_model_input is not None:
+            if diffusion_model_input.shape[1] == 9:
+                return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

@@ -422,6 +453,8 @@ class StableDiffusionProcessing:
             opts.sdxl_crop_top,
             self.width,
             self.height,
+            opts.fp8_storage,
+            opts.cache_fp16_weight,
         )

     def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
@@ -596,20 +629,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
         sample = decode_first_stage(model, batch[i:i + 1])[0]

         if check_for_nans:
+
             try:
                 devices.test_for_nans(sample, "vae")
             except devices.NansException as e:
-                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+                if shared.opts.auto_vae_precision_bfloat16:
+                    autofix_dtype = torch.bfloat16
+                    autofix_dtype_text = "bfloat16"
+                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+                    autofix_dtype_comment = ""
+                elif shared.opts.auto_vae_precision:
+                    autofix_dtype = torch.float32
+                    autofix_dtype_text = "32-bit float"
+                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+                else:
+                    raise e
+
+                if devices.dtype_vae == autofix_dtype:
                     raise e

                 errors.print_error_explanation(
                     "A tensor with all NaNs was produced in VAE.\n"
-                    "Web UI will now convert VAE into 32-bit float and retry.\n"
-                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
-                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
                 )

-                devices.dtype_vae = torch.float32
+                devices.dtype_vae = autofix_dtype
                 model.first_stage_model.to(devices.dtype_vae)
                 batch = batch.to(devices.dtype_vae)
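Note (illustrative, not part of the diff): the NaN recovery now prefers bfloat16 when the new "Automatically convert VAE to bfloat16" option is enabled and only falls back to 32-bit floats under the older option, re-raising otherwise. A small standalone sketch of that selection order, with plain booleans standing in for shared.opts:

    import torch


    def pick_vae_autofix_dtype(auto_bfloat16: bool, auto_float32: bool, current_dtype: torch.dtype):
        # Mirrors the new selection order: the bfloat16 option wins, then float32;
        # None means the NaN exception should be re-raised (also if we're already there).
        if auto_bfloat16:
            target = torch.bfloat16
        elif auto_float32:
            target = torch.float32
        else:
            return None
        return None if current_dtype == target else target


    print(pick_vae_autofix_dtype(True, True, torch.float16))    # torch.bfloat16
    print(pick_vae_autofix_dtype(False, True, torch.float16))   # torch.float32
    print(pick_vae_autofix_dtype(False, False, torch.float16))  # None -> re-raise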
@@ -679,12 +725,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Size": f"{p.width}x{p.height}",
         "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
         "Model": p.sd_model_name if opts.add_model_name_to_info else None,
+        "FP8 weight": opts.fp8_storage if devices.fp8 else None,
+        "Cache FP16 weight for LoRA": opts.cache_fp16_weight if devices.fp8 else None,
         "VAE hash": p.sd_vae_hash if opts.add_vae_hash_to_info else None,
         "VAE": p.sd_vae_name if opts.add_vae_name_to_info else None,
         "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
-        "Denoising strength": getattr(p, 'denoising_strength', None),
+        "Denoising strength": p.extra_generation_params.get("Denoising strength"),
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
@@ -699,7 +747,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "User": p.user if opts.add_user_name_to_info else None,
     }

-    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+    generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
@@ -818,7 +866,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.skipped:
                 state.skipped = False

-            if state.interrupted:
+            if state.interrupted or state.stopping_generation:
                 break

             sd_models.reload_model_weights()  # model can be changed for example by refiner
@@ -864,9 +912,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"

+            def rescale_zero_terminal_snr_abar(alphas_cumprod):
+                alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+                # Store old values.
+                alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+                alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+                # Shift so the last timestep is zero.
+                alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+                # Scale so the first timestep is back to the old value.
+                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+                # Convert alphas_bar_sqrt to betas
+                alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+                alphas_bar[-1] = 4.8973451890853435e-08
+                return alphas_bar
+
+            if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
+                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
+
+                if opts.use_downcasted_alpha_bar:
+                    p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
+                    p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
+                if opts.sd_noise_schedule == "Zero Terminal SNR":
+                    p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
+                    p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
+
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

+            if p.scripts is not None:
+                ps = scripts.PostSampleArgs(samples_ddim)
+                p.scripts.post_sample(p, ps)
+                samples_ddim = ps.samples
+
             if getattr(samples_ddim, 'already_decoded', False):
                 x_samples_ddim = samples_ddim
             else:
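Note (illustrative, not part of the diff): rescale_zero_terminal_snr_abar() shifts and rescales sqrt(alpha_bar) so the last timestep's alpha_bar is driven to (numerically) zero while the first is preserved, which is what the "Zero Terminal SNR" noise-schedule option relies on. A quick self-contained check on a toy beta schedule:

    import torch


    def rescale_zero_terminal_snr_abar(alphas_cumprod):
        # Same math as the helper added in the diff.
        alphas_bar_sqrt = alphas_cumprod.sqrt()
        a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt -= aT                       # shift: last timestep -> 0
        alphas_bar_sqrt *= a0 / (a0 - aT)           # rescale: first timestep unchanged
        alphas_bar = alphas_bar_sqrt ** 2
        alphas_bar[-1] = 4.8973451890853435e-08     # keep the terminal value strictly positive
        return alphas_bar


    # Toy linear-beta schedule, just to see the endpoints move as expected.
    betas = torch.linspace(0.00085, 0.012, 1000)
    abar = torch.cumprod(1.0 - betas, dim=0)
    fixed = rescale_zero_terminal_snr_abar(abar.clone())
    print(float(abar[0]), float(fixed[0]))    # first value preserved
    print(float(abar[-1]), float(fixed[-1]))  # last value driven to ~0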
@@ -922,13 +1003,42 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     pp = scripts.PostprocessImageArgs(image)
                     p.scripts.postprocess_image(p, pp)
                     image = pp.image

+                mask_for_overlay = getattr(p, "mask_for_overlay", None)
+
+                if not shared.opts.overlay_inpaint:
+                    overlay_image = None
+                elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images):
+                    overlay_image = p.overlay_images[i]
+                else:
+                    overlay_image = None
+
+                if p.scripts is not None:
+                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                    p.scripts.postprocess_maskoverlay(p, ppmo)
+                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if save_samples and opts.save_images_before_color_correction:
-                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)

-                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                # If the intention is to show the output from the model
+                # that is being composited over the original image,
+                # we need to keep the original image around
+                # and use it in the composite step.
+                original_denoised_image = image.copy()
+
+                if p.paste_to is not None:
+                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+                image = apply_overlay(image, p.paste_to, overlay_image)
+
+                if p.scripts is not None:
+                    pp = scripts.PostprocessImageArgs(image)
+                    p.scripts.postprocess_image_after_composite(p, pp)
+                    image = pp.image

                 if save_samples:
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +1048,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+
+                if mask_for_overlay is not None:
                     if opts.return_mask or opts.save_mask:
-                        image_mask = p.mask_for_overlay.convert('RGB')
+                        image_mask = mask_for_overlay.convert('RGB')
                         if save_samples and opts.save_mask:
                             images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                         if opts.return_mask:
                             output_images.append(image_mask)

                     if opts.return_mask_composite or opts.save_mask_composite:
-                        image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                         if save_samples and opts.save_mask_composite:
                             images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                         if opts.return_mask_composite:
@@ -1025,6 +1136,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     hr_sampler_name: str = None
     hr_prompt: str = ''
     hr_negative_prompt: str = ''
+    force_task_id: str = None

     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
@@ -1097,7 +1209,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
-            if self.hr_checkpoint_name:
+            self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
+            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
                 self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

                 if self.hr_checkpoint_info is None:
@@ -1124,8 +1238,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if not state.processing_has_refined_job_count:
                 if state.job_count == -1:
                     state.job_count = self.n_iter

-                shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
+                if getattr(self, 'txt2img_upscale', False):
+                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count
+                else:
+                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count
+
+                shared.total_tqdm.updateTotal(total_steps)
                 state.job_count = state.job_count * 2
                 state.processing_has_refined_job_count = True
@@ -1138,12 +1255,39 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

+        if self.firstpass_image is not None and self.enable_hr:
+            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
+
+            if self.latent_scale_mode is None:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+                image = np.moveaxis(image, 2, 0)
+
+                samples = None
+                decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+            else:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+                image = np.moveaxis(image, 2, 0)
+                image = torch.from_numpy(np.expand_dims(image, axis=0))
+                image = image.to(shared.device, dtype=devices.dtype_vae)
+
+                if opts.sd_vae_encode_method != 'Full':
+                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+                decoded_samples = None
+                devices.torch_gc()
+
+        else:
+            # here we generate an image normally
+
-        x = self.rng.next()
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
-        del x
+            x = self.rng.next()
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+            del x

-        if not self.enable_hr:
-            return samples
+            if not self.enable_hr:
+                return samples

-        devices.torch_gc()
+            devices.torch_gc()

-        if self.latent_scale_mode is None:
+            if self.latent_scale_mode is None:
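Note (illustrative, not part of the diff): when firstpass_image is set and hires fix is enabled, sample() now skips the first sampling pass entirely: non-latent upscale modes keep a decoded [-1, 1] image tensor, while latent modes prepare a [0, 1] tensor for the VAE encoder. A toy sketch of those two preparations on a dummy image instead of a real generation:

    import numpy as np
    import torch
    from PIL import Image

    firstpass_image = Image.new('RGB', (512, 512), 'gray')

    # Non-latent upscale modes: keep a decoded image tensor in [-1, 1].
    decoded = np.array(firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
    decoded = torch.asarray(np.expand_dims(np.moveaxis(decoded, 2, 0), 0))
    print(decoded.shape)      # torch.Size([1, 3, 512, 512])

    # Latent upscale modes: a [0, 1] tensor that would go through the VAE encoder.
    to_encode = np.array(firstpass_image).astype(np.float32) / 255.0
    to_encode = torch.from_numpy(np.expand_dims(np.moveaxis(to_encode, 2, 0), axis=0))
    print(to_encode.shape)    # torch.Size([1, 3, 512, 512])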
@@ -1351,12 +1495,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     mask_blur_x: int = 4
     mask_blur_y: int = 4
     mask_blur: int = None
+    mask_round: bool = True
     inpainting_fill: int = 0
     inpaint_full_res: bool = True
     inpaint_full_res_padding: int = 0
     inpainting_mask_invert: int = 0
     initial_noise_multiplier: float = None
     latent_mask: Image = None
+    force_task_id: str = None

     image_mask: Any = field(default=None, init=False)
@@ -1386,6 +1532,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.mask_blur_y = value

     def init(self, all_prompts, all_seeds, all_subseeds):
+        self.extra_generation_params["Denoising strength"] = self.denoising_strength
+
         self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
@@ -1396,10 +1544,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         if image_mask is not None:
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
+            image_mask = create_binary_mask(image_mask, round=self.mask_round)

             if self.inpainting_mask_invert:
                 image_mask = ImageOps.invert(image_mask)
+                self.extra_generation_params["Mask mode"] = "Inpaint not masked"

             if self.mask_blur_x > 0:
                 np_mask = np.array(image_mask)
@@ -1413,16 +1562,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                 np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
                 image_mask = Image.fromarray(np_mask)

+            if self.mask_blur_x > 0 or self.mask_blur_y > 0:
+                self.extra_generation_params["Mask blur"] = self.mask_blur
+
             if self.inpaint_full_res:
                 self.mask_for_overlay = image_mask
                 mask = image_mask.convert('L')
-                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
+                crop_region = masking.get_crop_region(mask, self.inpaint_full_res_padding)
                 crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                 x1, y1, x2, y2 = crop_region

                 mask = mask.crop(crop_region)
                 image_mask = images.resize_image(2, mask, self.width, self.height)
                 self.paste_to = (x1, y1, x2-x1, y2-y1)
+
+                self.extra_generation_params["Inpaint area"] = "Only masked"
+                self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding
             else:
                 image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
                 np_mask = np.array(image_mask)
@@ -1442,7 +1597,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             # Save init image
             if opts.save_init_img:
                 self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
-                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)

             image = images.flatten(img, opts.img2img_background_color)

@@ -1464,6 +1619,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             if self.inpainting_fill != 1:
                 image = masking.fill(image, latent_mask)

+                if self.inpainting_fill == 0:
+                    self.extra_generation_params["Masked content"] = 'fill'
+
             if add_color_corrections:
                 self.color_corrections.append(setup_color_correction(image))

@@ -1503,6 +1661,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
-            latmask = np.around(latmask)
+            if self.mask_round:
+                latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))

@@ -1512,10 +1671,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             # this needs to be fixed to be done in sample() using actual seeds for batches
             if self.inpainting_fill == 2:
                 self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+                self.extra_generation_params["Masked content"] = 'latent noise'
+
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask
+                self.extra_generation_params["Masked content"] = 'latent nothing'

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()
@ -1527,7 +1689,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
blended_samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
|
if self.scripts is not None:
|
||||||
|
mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
|
||||||
|
self.scripts.on_mask_blend(self, mba)
|
||||||
|
blended_samples = mba.blended_latent
|
||||||
|
|
||||||
|
samples = blended_samples
|
||||||
|
|
||||||
del x
|
del x
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
@ -1,6 +1,7 @@
 import gradio as gr

 from modules import scripts, sd_models
+from modules.infotext_utils import PasteField
 from modules.ui_common import create_refresh_button
 from modules.ui_components import InputAccordion

@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
             return None if info is None else info.title

         self.infotext_fields = [
-            (enable_refiner, lambda d: 'Refiner' in d),
-            (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
-            (refiner_switch_at, 'Refiner switch at'),
+            PasteField(enable_refiner, lambda d: 'Refiner' in d),
+            PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
+            PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
         ]

         return enable_refiner, refiner_checkpoint, refiner_switch_at
@ -3,8 +3,10 @@ import json
 import gradio as gr

 from modules import scripts, ui, errors
+from modules.infotext_utils import PasteField
 from modules.shared import cmd_opts
 from modules.ui_components import ToolButton
+from modules import infotext_utils


 class ScriptSeed(scripts.ScriptBuiltinUI):
@ -51,12 +53,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
         seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])

         self.infotext_fields = [
-            (self.seed, "Seed"),
-            (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
-            (subseed, "Variation seed"),
-            (subseed_strength, "Variation seed strength"),
-            (seed_resize_from_w, "Seed resize from-1"),
-            (seed_resize_from_h, "Seed resize from-2"),
+            PasteField(self.seed, "Seed", api="seed"),
+            PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
+            PasteField(subseed, "Variation seed", api="subseed"),
+            PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
+            PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
+            PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
         ]

         self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
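A minimal sketch of how a third-party script could use the `PasteField` wrapper introduced above; the slider, its label, and the `api` key below are illustrative, not part of this commit:

```python
import gradio as gr

from modules import scripts
from modules.infotext_utils import PasteField


class ExampleScript(scripts.ScriptBuiltinUI):
    def title(self):
        return "Example"

    def ui(self, is_img2img):
        strength = gr.Slider(minimum=0.0, maximum=1.0, label="Example strength")
        # first argument: the component filled when infotext is pasted
        # second: the infotext key, or a callable extracting a value from the parsed dict
        # api=: lets API callers address the same field by name
        self.infotext_fields = [
            PasteField(strength, "Example strength", api="example_strength"),
        ]
        return [strength]
```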
@ -76,7 +78,6 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
         p.seed_resize_from_h = seed_resize_from_h

-

 def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
     """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
     (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
@ -84,21 +85,14 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:

     def copy_seed(gen_info_string: str, index):
         res = -1

         try:
             gen_info = json.loads(gen_info_string)
-            index -= gen_info.get('index_of_first_image', 0)
-
-            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
-                all_subseeds = gen_info.get('all_subseeds', [-1])
-                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
-            else:
-                all_seeds = gen_info.get('all_seeds', [-1])
-                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
-        except json.decoder.JSONDecodeError:
+            infotext = gen_info.get('infotexts')[index]
+            gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])
+            res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))
+        except Exception:
             if gen_info_string:
-                errors.report(f"Error parsing JSON generation info: {gen_info_string}")
+                errors.report(f"Error retrieving seed from generation info: {gen_info_string}", exc_info=True)

         return [res, gr.update()]
@ -8,10 +8,13 @@ from pydantic import BaseModel, Field
 from modules.shared import opts

 import modules.shared as shared
+from collections import OrderedDict
+import string
+import random
+from typing import List

 current_task = None
-pending_tasks = {}
+pending_tasks = OrderedDict()
 finished_tasks = []
 recorded_results = []
 recorded_results_limit = 2
@ -34,6 +37,11 @@ def finish_task(id_task):
     if len(finished_tasks) > 16:
         finished_tasks.pop(0)


+def create_task_id(task_type):
+    N = 7
+    res = ''.join(random.choices(string.ascii_uppercase +
+                  string.digits, k=N))
+    return f"task({task_type}-{res})"
+
 def record_results(id_task, res):
     recorded_results.append((id_task, res))
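For reference, `create_task_id` builds ids in the same `task(...)` form the UI already sends; the suffix below is made up, since it is seven random uppercase letters or digits:

```python
>>> create_task_id("txt2img")
'task(txt2img-K3X9Q1A)'  # actual suffix varies per call
```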
@ -44,6 +52,9 @@ def record_results(id_task, res):
 def add_task_to_queue(id_job):
     pending_tasks[id_job] = time.time()


+class PendingTasksResponse(BaseModel):
+    size: int = Field(title="Pending task size")
+    tasks: List[str] = Field(title="Pending task ids")

 class ProgressRequest(BaseModel):
     id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
@ -63,9 +74,16 @@ class ProgressResponse(BaseModel):


 def setup_progress_api(app):
+    app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
     return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)


+def get_pending_tasks():
+    pending_tasks_ids = list(pending_tasks)
+    pending_len = len(pending_tasks_ids)
+    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)
+
+
 def progressapi(req: ProgressRequest):
     active = req.id_task == current_task
     queued = req.id_task in pending_tasks
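A quick way to exercise the new route, assuming a webui instance listening on the default 127.0.0.1:7860; the `requests` usage is illustrative and not part of the commit:

```python
import requests

resp = requests.get("http://127.0.0.1:7860/internal/pending-tasks")
resp.raise_for_status()
data = resp.json()
# matches PendingTasksResponse: an integer count plus the queued task ids
print(data["size"], data["tasks"])
```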
@ -1,12 +1,9 @@
 import os

-import numpy as np
-from PIL import Image
-from realesrgan import RealESRGANer
-
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
 from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model


 class UpscalerRealESRGAN(Upscaler):
@ -14,13 +11,9 @@ class UpscalerRealESRGAN(Upscaler):
         self.name = "RealESRGAN"
         self.user_path = path
         super().__init__()
-        try:
-            from basicsr.archs.rrdbnet_arch import RRDBNet  # noqa: F401
-            from realesrgan import RealESRGANer  # noqa: F401
-            from realesrgan.archs.srvgg_arch import SRVGGNetCompact  # noqa: F401
-            self.enable = True
-            self.scalers = []
-            scalers = self.load_models(path)
+        self.enable = True
+        self.scalers = []
+        scalers = get_realesrgan_models(self)

         local_model_paths = self.find_models(ext_filter=[".pth"])
         for scaler in scalers:
@ -33,11 +26,6 @@ class UpscalerRealESRGAN(Upscaler):
             if scaler.name in opts.realesrgan_enabled_models:
                 self.scalers.append(scaler)

-        except Exception:
-            errors.report("Error importing Real-ESRGAN", exc_info=True)
-            self.enable = False
-            self.scalers = []
-
     def do_upscale(self, img, path):
         if not self.enable:
             return img
@ -48,20 +36,19 @@ class UpscalerRealESRGAN(Upscaler):
             errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img

-        upsampler = RealESRGANer(
-            scale=info.scale,
-            model_path=info.local_data_path,
-            model=info.model(),
-            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
-            tile=opts.ESRGAN_tile,
-            tile_pad=opts.ESRGAN_tile_overlap,
-            device=self.device,
-        )
-
-        upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
-
-        image = Image.fromarray(upsampled)
-        return image
+        model_descriptor = modelloader.load_spandrel_model(
+            info.local_data_path,
+            device=self.device,
+            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+            expected_architecture="ESRGAN",  # "RealESRGAN" isn't a specific thing for Spandrel
+        )
+        return upscale_with_model(
+            model_descriptor,
+            img,
+            tile_size=opts.ESRGAN_tile,
+            tile_overlap=opts.ESRGAN_tile_overlap,
+            # TODO: `outscale`?
+        )

     def load_model(self, path):
         for scaler in self.scalers:
@ -76,58 +63,43 @@ class UpscalerRealESRGAN(Upscaler):
                 return scaler
         raise ValueError(f"Unable to find model info: {path}")

-    def load_models(self, _):
-        return get_realesrgan_models(self)

-
-def get_realesrgan_models(scaler):
-    try:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-        models = [
+def get_realesrgan_models(scaler: UpscalerRealESRGAN):
+    return [
         UpscalerData(
             name="R-ESRGAN General 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN General WDN 4xV3",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN AnimeVideo",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
         ),
         UpscalerData(
             name="R-ESRGAN 4x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 4x+ Anime6B",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
             scale=4,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
         ),
         UpscalerData(
             name="R-ESRGAN 2x+",
             path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
             scale=2,
             upscaler=scaler,
-            model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
         ),
     ]
-        return models
-    except Exception:
-        errors.report("Error making Real-ESRGAN models list", exc_info=True)
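A standalone sketch of the new Spandrel-based upscale path used above. It assumes a running webui environment (Spandrel installed, a local .pth checkpoint already downloaded); the file path, device string, and tile numbers are illustrative:

```python
from PIL import Image

from modules import modelloader
from modules.upscaler_utils import upscale_with_model

model_descriptor = modelloader.load_spandrel_model(
    "models/RealESRGAN/RealESRGAN_x4plus.pth",  # hypothetical local path
    device="cuda",
    prefer_half=True,
    expected_architecture="ESRGAN",
)
img = Image.open("input.png")
result = upscale_with_model(model_descriptor, img, tile_size=192, tile_overlap=8)
result.save("output.png")
```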
@ -41,7 +41,7 @@ class ExtraNoiseParams:


 class CFGDenoiserParams:
-    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
+    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):
         self.x = x
         """Latent image representation in the process of being denoised"""

@ -63,6 +63,9 @@ class CFGDenoiserParams:
         self.text_uncond = text_uncond
         """ Encoder hidden states of text conditioning from negative prompt"""

+        self.denoiser = denoiser
+        """Current CFGDenoiser object with processing parameters"""
+

 class CFGDenoisedParams:
     def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
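A hedged sketch of a callback that takes advantage of the new `denoiser` field; `on_cfg_denoiser` is the existing registration helper in `script_callbacks`, the rest is illustrative:

```python
from modules import script_callbacks


def log_mask_blending(params: script_callbacks.CFGDenoiserParams):
    # denoiser is None for callers that do not pass the new argument
    denoiser = params.denoiser
    if denoiser is not None and getattr(denoiser, "mask", None) is not None:
        print(f"inpainting blend active at step {params.sampling_step}/{params.total_sampling_steps}")


script_callbacks.on_cfg_denoiser(log_mask_blending)
```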
@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,

 AlwaysVisible = object()


+class MaskBlendArgs:
+    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+        self.current_latent = current_latent
+        self.nmask = nmask
+        self.init_latent = init_latent
+        self.mask = mask
+        self.blended_latent = blended_latent
+
+        self.denoiser = denoiser
+        self.is_final_blend = denoiser is None
+        self.sigma = sigma
+
+
+class PostSampleArgs:
+    def __init__(self, samples):
+        self.samples = samples
+
+
 class PostprocessImageArgs:
     def __init__(self, image):
         self.image = image

+
+class PostProcessMaskOverlayArgs:
+    def __init__(self, index, mask_for_overlay, overlay_image):
+        self.index = index
+        self.mask_for_overlay = mask_for_overlay
+        self.overlay_image = overlay_image
+

 class PostprocessBatchListArgs:
     def __init__(self, images):
@ -71,6 +91,9 @@ class Script:
     setup_for_ui_only = False
     """If true, the script setup will only be run in Gradio UI, not in API"""

+    controls = None
+    """A list of controls retured by the ui()."""
+
     def title(self):
         """this function should return the title of the script. This is what will be displayed in the dropdown menu."""

@ -206,6 +229,25 @@ class Script:

         pass

+    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+        """
+        Called in inpainting mode when the original content is blended with the inpainted content.
+        This is called at every step in the denoising process and once at the end.
+        If is_final_blend is true, this is called for the final blending stage.
+        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+        """
+
+        pass
+
+    def post_sample(self, p, ps: PostSampleArgs, *args):
+        """
+        Called after the samples have been generated,
+        but before they have been decoded by the VAE, if applicable.
+        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+        """
+
+        pass
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
         """
         Called for every image after it has been generated.
@ -213,6 +255,22 @@ class Script:

         pass

+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+        """
+        Called for every image after it has been generated.
+        """
+
+        pass
+
+    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):
+        """
+        Called for every image after it has been generated.
+        Same as postprocess_image but after inpaint_full_res composite
+        So that it operates on the full image instead of the inpaint_full_res crop region.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
|
|||||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||||
|
|
||||||
for script_data in auto_processing_scripts + scripts_data:
|
for script_data in auto_processing_scripts + scripts_data:
|
||||||
|
try:
|
||||||
script = script_data.script_class()
|
script = script_data.script_class()
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
|
||||||
|
continue
|
||||||
|
|
||||||
script.filename = script_data.path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
@ -573,6 +636,7 @@ class ScriptRunner:
|
|||||||
import modules.api.models as api_models
|
import modules.api.models as api_models
|
||||||
|
|
||||||
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
|
||||||
|
script.controls = controls
|
||||||
|
|
||||||
if controls is None:
|
if controls is None:
|
||||||
return
|
return
|
||||||
@ -645,6 +709,8 @@ class ScriptRunner:
|
|||||||
self.setup_ui_for_section(None, self.selectable_scripts)
|
self.setup_ui_for_section(None, self.selectable_scripts)
|
||||||
|
|
||||||
def select_script(script_index):
|
def select_script(script_index):
|
||||||
|
if script_index is None:
|
||||||
|
script_index = 0
|
||||||
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
|
||||||
|
|
||||||
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
|
||||||
@ -688,7 +754,7 @@ class ScriptRunner:
|
|||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
script_index = args[0]
|
script_index = args[0]
|
||||||
|
|
||||||
if script_index == 0:
|
if script_index == 0 or script_index is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
script = self.selectable_scripts[script_index-1]
|
script = self.selectable_scripts[script_index-1]
|
||||||
@ -767,6 +833,22 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def post_sample(self, p, ps: PostSampleArgs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.post_sample(p, ps, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def on_mask_blend(self, p, mba: MaskBlendArgs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.on_mask_blend(p, mba, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@ -775,6 +857,22 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess_maskoverlay(p, ppmo, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess_image_after_composite(p, pp, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running postprocess_image_after_composite: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||||
try:
|
try:
|
||||||
@ -841,6 +939,35 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running setup: {script.filename}", exc_info=True)

+    def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):
+        """Locate an arg of a specific script in script_args and set its value
+        Args:
+            args: all script args of process p, p.script_args
+            script_name: the name target script name to
+            arg_elem_id: the elem_id of the target arg
+            value: the value to set
+            fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match
+        Returns:
+            Updated script args
+        when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError
+        """
+        script = next((x for x in self.scripts if x.name == script_name), None)
+        if script is None:
+            raise RuntimeError(f"script {script_name} not found")
+
+        for i, control in enumerate(script.controls):
+            if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:
+                index = script.args_from + i
+
+                if isinstance(args, tuple):
+                    return args[:index] + (value,) + args[index + 1:]
+                elif isinstance(args, list):
+                    args[index] = value
+                    return args
+                else:
+                    raise RuntimeError(f"args is not a list or tuple, but {type(args)}")
+        raise RuntimeError(f"arg_elem_id {arg_elem_id} not found in script {script_name}")
+

 scripts_txt2img: ScriptRunner = None
 scripts_img2img: ScriptRunner = None
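One way an extension or API hook might use `set_named_arg`; the script name and elem_id below are placeholders, not values guaranteed to exist in every install:

```python
from modules import scripts


def force_refiner_switch(p, value=0.7):
    # p is the StableDiffusionProcessing object handed to extension hooks
    p.script_args = scripts.scripts_txt2img.set_named_arg(
        p.script_args,
        script_name="refiner",            # Script.name of the target script (placeholder)
        arg_elem_id="refiner_switch_at",  # elem_id of the control to change (placeholder)
        value=value,
    )
```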
@ -11,10 +11,14 @@ class CondFunc:
                     break
             except ImportError:
                 pass
+        try:
             for attr_name in func_path[i:-1]:
                 resolved_obj = getattr(resolved_obj, attr_name)
             orig_func = getattr(resolved_obj, func_path[-1])
             setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+        except AttributeError:
+            print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
+            pass
         self.__init__(orig_func, sub_func, cond_func)
         return lambda *args, **kwargs: self(*args, **kwargs)
     def __init__(self, orig_func, sub_func, cond_func):
@ -348,10 +348,28 @@ class SkipWritingToConfig:
         SkipWritingToConfig.skip = self.previous


+def check_fp8(model):
+    if model is None:
+        return None
+    if devices.get_optimal_device_name() == "mps":
+        enable_fp8 = False
+    elif shared.opts.fp8_storage == "Enable":
+        enable_fp8 = True
+    elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL":
+        enable_fp8 = True
+    else:
+        enable_fp8 = False
+    return enable_fp8
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")

+    if devices.fp8:
+        # prevent model to load state dict in fp8
+        model.half()
+
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

@ -383,6 +401,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

     if shared.cmd_opts.no_half:
         model.float()
+        model.alphas_cumprod_original = model.alphas_cumprod
         devices.dtype_unet = torch.float32
         timer.record("apply float()")
     else:
@ -396,7 +415,11 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         if shared.cmd_opts.upcast_sampling and depth_model:
             model.depth_model = None

+        alphas_cumprod = model.alphas_cumprod
+        model.alphas_cumprod = None
         model.half()
+        model.alphas_cumprod = alphas_cumprod
+        model.alphas_cumprod_original = alphas_cumprod
         model.first_stage_model = vae
         if depth_model:
             model.depth_model = depth_model
@ -404,6 +427,28 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         devices.dtype_unet = torch.float16
         timer.record("apply half()")

+    for module in model.modules():
+        if hasattr(module, 'fp16_weight'):
+            del module.fp16_weight
+        if hasattr(module, 'fp16_bias'):
+            del module.fp16_bias
+
+    if check_fp8(model):
+        devices.fp8 = True
+        first_stage = model.first_stage_model
+        model.first_stage_model = None
+        for module in model.modules():
+            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+                if shared.opts.cache_fp16_weight:
+                    module.fp16_weight = module.weight.data.clone().cpu().half()
+                    if module.bias is not None:
+                        module.fp16_bias = module.bias.data.clone().cpu().half()
+                module.to(torch.float8_e4m3fn)
+        model.first_stage_model = first_stage
+        timer.record("apply fp8")
+    else:
+        devices.fp8 = False
+
     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

     model.first_stage_model.to(devices.dtype_vae)
|
|||||||
else:
|
else:
|
||||||
weight_dtype_conversion = {
|
weight_dtype_conversion = {
|
||||||
'first_stage_model': None,
|
'first_stage_model': None,
|
||||||
|
'alphas_cumprod': None,
|
||||||
'': torch.float16,
|
'': torch.float16,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -746,7 +792,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
@ -758,11 +804,14 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
current_checkpoint_info = None
|
current_checkpoint_info = None
|
||||||
else:
|
else:
|
||||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if check_fp8(sd_model) != devices.fp8:
|
||||||
|
# load from state dict again to prevent extra numerical errors
|
||||||
|
forced_reload = True
|
||||||
|
elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||||
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
if sd_model is not None:
|
if sd_model is not None:
|
||||||
@ -793,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
|
||||||
timer.record("script callbacks")
|
|
||||||
|
|
||||||
if not sd_model.lowvram:
|
if not sd_model.lowvram:
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
timer.record("script callbacks")
|
||||||
|
|
||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
model_data.set_sd_model(sd_model)
|
model_data.set_sd_model(sd_model)
|
||||||
|
@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
|||||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||||
|
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
|
||||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||||
@ -71,6 +72,9 @@ def guess_model_config_from_state_dict(sd, filename):
|
|||||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||||
|
|
||||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||||
|
if diffusion_model_input.shape[1] == 9:
|
||||||
|
return config_sdxl_inpainting
|
||||||
|
else:
|
||||||
return config_sdxl
|
return config_sdxl
|
||||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||||
return config_sdxl_refiner
|
return config_sdxl_refiner
|
||||||
|
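Background for the `shape[1] == 9` check: an inpainting UNet's first convolution takes 9 input channels, that is the 4 latent channels plus 4 channels of VAE-encoded masked image plus 1 mask channel, while a regular model takes 4. A sketch of the same test against a loaded state dict; the key name follows the usual SD/SDXL layout and is shown for illustration only:

```python
def looks_like_inpainting_unet(state_dict):
    # key name is illustrative; the diff above reads the equivalent key
    w = state_dict.get('model.diffusion_model.input_blocks.0.0.weight')
    return w is not None and w.shape[1] == 9
```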
@ -6,6 +6,7 @@ import sgm.models.diffusion
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
+from modules import torch_utils


 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@ -34,6 +35,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:


 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    sd = self.model.state_dict()
+    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
+    if diffusion_model_input is not None:
+        if diffusion_model_input.shape[1] == 9:
+            x = torch.cat([x] + cond['c_concat'], dim=1)
+
     return self.model(x, t, cond)


@ -84,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
 def extend_sdxl(model):
     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""

-    dtype = next(model.model.diffusion_model.parameters()).dtype
+    dtype = torch_utils.get_param(model.model.diffusion_model).dtype
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
     model.cond_stage_key = 'txt'
@ -93,7 +100,7 @@ def extend_sdxl(model):
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"

     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
-    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)

     model.conditioner.wrapped = torch.nn.Module()

@ -1,4 +1,4 @@
-from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared
+from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared

 # imports for functions that previously were here and are used by other modules
 from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401
@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
 all_samplers = [
     *sd_samplers_kdiffusion.samplers_data_k_diffusion,
     *sd_samplers_timesteps.samplers_data_timesteps,
+    *sd_samplers_lcm.samplers_data_lcm,
 ]
 all_samplers_map = {x.name: x for x in all_samplers}

@ -53,9 +53,13 @@ class CFGDenoiser(torch.nn.Module):
         self.step = 0
         self.image_cfg_scale = None
         self.padded_cond_uncond = False
+        self.padded_cond_uncond_v0 = False
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+
+        # NOTE: masking before denoising can cause the original latents to be oversmoothed
+        # as the original latents do not have noise
         self.mask_before_denoising = False

     @property
@ -88,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
         self.sampler.sampler_extra_args['cond'] = c
         self.sampler.sampler_extra_args['uncond'] = uc

+    def pad_cond_uncond(self, cond, uncond):
+        empty = shared.sd_model.cond_stage_model_empty_prompt
+        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+        if num_repeats < 0:
+            cond = pad_cond(cond, -num_repeats, empty)
+            self.padded_cond_uncond = True
+        elif num_repeats > 0:
+            uncond = pad_cond(uncond, num_repeats, empty)
+            self.padded_cond_uncond = True
+
+        return cond, uncond
+
+    def pad_cond_uncond_v0(self, cond, uncond):
+        """
+        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
+
+        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
+        If 'uncond' is a tensor, it is padded directly.
+
+        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
+        is repeated to match the number of columns in 'cond'.
+
+        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
+        to match the number of columns in 'cond'.
+
+        Args:
+            cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
+            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
+
+        Returns:
+            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
+
+        Note:
+            This is the padding that was always used in DDIM before version 1.6.0
+        """
+
+        is_dict_cond = isinstance(uncond, dict)
+        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
+
+        if uncond_vec.shape[1] < cond.shape[1]:
+            last_vector = uncond_vec[:, -1:]
+            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
+            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
+            self.padded_cond_uncond_v0 = True
+        elif uncond_vec.shape[1] > cond.shape[1]:
+            uncond_vec = uncond_vec[:, :cond.shape[1]]
+            self.padded_cond_uncond_v0 = True
+
+        if is_dict_cond:
+            uncond['crossattn'] = uncond_vec
+        else:
+            uncond = uncond_vec
+
+        return cond, uncond
+
     def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
         if state.interrupted or state.skipped:
             raise sd_samplers_common.InterruptedException
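A small shape-level sanity check of the v0 padding logic above; the 77-token chunking and the numbers are the usual CLIP conventions, shown here purely for illustration:

```python
import torch

cond = torch.zeros(2, 154, 768)    # positive prompt spilled into two 77-token chunks
uncond = torch.zeros(2, 77, 768)   # negative prompt fits in one chunk

last = uncond[:, -1:]
uncond = torch.hstack([uncond, last.repeat([1, cond.shape[1] - uncond.shape[1], 1])])

assert uncond.shape == cond.shape  # (2, 154, 768)
```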
@ -105,8 +165,21 @@ class CFGDenoiser(torch.nn.Module):

         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

+        # If we use masks, blending between the denoised and original latent images occurs here.
+        def apply_blend(current_latent):
+            blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+            if self.p.scripts is not None:
+                from modules import scripts
+                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+                self.p.scripts.on_mask_blend(self.p, mba)
+                blended_latent = mba.blended_latent
+
+            return blended_latent
+
+        # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            x = self.init_latent * self.mask + self.nmask * x
+            x = apply_blend(x)

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@ -130,7 +203,7 @@ class CFGDenoiser(torch.nn.Module):
             sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
             image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

-        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)
         cfg_denoiser_callback(denoiser_params)
         x_in = denoiser_params.x
         image_cond_in = denoiser_params.image_cond
@ -146,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
             sigma_in = sigma_in[:-batch_size]

         self.padded_cond_uncond = False
+        self.padded_cond_uncond_v0 = False
         if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
-            empty = shared.sd_model.cond_stage_model_empty_prompt
-            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
-
-            if num_repeats < 0:
-                tensor = pad_cond(tensor, -num_repeats, empty)
-                self.padded_cond_uncond = True
-            elif num_repeats > 0:
-                uncond = pad_cond(uncond, num_repeats, empty)
-                self.padded_cond_uncond = True
+            tensor, uncond = self.pad_cond_uncond(tensor, uncond)
+        elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
+            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)

         if tensor.shape[1] == uncond.shape[1] or skip_uncond:
             if is_edit_model:
@ -207,8 +275,9 @@ class CFGDenoiser(torch.nn.Module):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

+        # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            denoised = apply_blend(denoised)

         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

@ -335,3 +335,10 @@ class Sampler:

     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         raise NotImplementedError()
+
+    def add_infotext(self, p):
+        if self.model_wrap_cfg.padded_cond_uncond:
+            p.extra_generation_params["Pad conds"] = True
+
+        if self.model_wrap_cfg.padded_cond_uncond_v0:
+            p.extra_generation_params["Pad conds v0"] = True
@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):

         samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
+        self.add_infotext(p)

         return samples

@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):

         samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
+        self.add_infotext(p)

         return samples

104 modules/sd_samplers_lcm.py Normal file
@ -0,0 +1,104 @@
+import torch
+
+from k_diffusion import utils, sampling
+from k_diffusion.external import DiscreteEpsDDPMDenoiser
+from k_diffusion.sampling import default_noise_sampler, trange
+
+from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common
+
+
+class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
+    def __init__(self, model):
+        timesteps = 1000
+        original_timesteps = 50  # LCM Original Timesteps (default=50, for current version of LCM)
+        self.skip_steps = timesteps // original_timesteps
+
+        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+        for x in range(original_timesteps):
+            alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+        super().__init__(model, alphas_cumprod_valid, quantize=None)
+
+
+    def get_sigmas(self, n=None,):
+        if n is None:
+            return sampling.append_zero(self.sigmas.flip(0))
+
+        start = self.sigma_to_t(self.sigma_max)
+        end = self.sigma_to_t(self.sigma_min)
+
+        t = torch.linspace(start, end, n, device=shared.sd_model.device)
+
+        return sampling.append_zero(self.t_to_sigma(t))
+
+
+    def sigma_to_t(self, sigma, quantize=None):
+        log_sigma = sigma.log()
+        dists = log_sigma - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+
+    def t_to_sigma(self, timestep):
+        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+        return super().t_to_sigma(t)
+
+
+    def get_eps(self, *args, **kwargs):
+        return self.inner_model.apply_model(*args, **kwargs)
+
+
+    def get_scaled_out(self, sigma, output, input):
+        sigma_data = 0.5
+        scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
+
+        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+        return c_out * output + c_skip * input
+
+
+    def forward(self, input, sigma, **kwargs):
+        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+        return self.get_scaled_out(sigma, input + eps * c_out, input)
+
+
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+        x = denoised
+        if sigmas[i + 1] > 0:
+            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+    return x
+
+
+class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
+    @property
+    def inner_model(self):
+        if self.model_wrap is None:
+            denoiser = LCMCompVisDenoiser
+            self.model_wrap = denoiser(shared.sd_model)

+        return self.model_wrap
+
+
+class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
+    def __init__(self, funcname, sd_model, options=None):
+        super().__init__(funcname, sd_model, options)
+        self.model_wrap_cfg = CFGDenoiserLCM(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model
+
+
+samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
+samplers_data_lcm = [
+    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
+    for label, funcname, aliases, options in samplers_lcm
+]
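Note on the new LCM sampler file above: the denoiser keeps only 50 of the model's 1000 training timesteps and maps sigmas onto that coarse grid. The snippet below is only an illustrative sketch derived from the constants in the file, not part of the diff:

    # Sketch of the timestep grid LCMCompVisDenoiser builds.
    timesteps = 1000
    original_timesteps = 50
    skip_steps = timesteps // original_timesteps          # 20
    # sigma_to_t() always lands on one of these 50 timesteps: 19, 39, ..., 999
    valid_timesteps = [i * skip_steps + (skip_steps - 1) for i in range(original_timesteps)]
    assert valid_timesteps[0] == 19 and valid_timesteps[-1] == 999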
@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
         self.inner_model = model

     def predict_eps_from_z_and_v(self, x_t, t, v):
-        return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
+        return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t

     def forward(self, input, timesteps, **kwargs):
         model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
@ -80,6 +80,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
         self.eta_default = 0.0

         self.model_wrap_cfg = CFGDenoiserTimesteps(self)
+        self.model_wrap = self.model_wrap_cfg.inner_model

     def get_timesteps(self, p, steps):
         discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@ -132,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):

         samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
+        self.add_infotext(p)

         return samples

@ -157,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
         }
         samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

-        if self.model_wrap_cfg.padded_cond_uncond:
-            p.extra_generation_params["Pad conds"] = True
+        self.add_infotext(p)

         return samples

@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     load_vae(sd_model, vae_file, vae_source)

     sd_hijack.model_hijack.hijack(sd_model)
-    script_callbacks.model_loaded_callback(sd_model)

     if not sd_model.lowvram:
         sd_model.to(devices.device)

+    script_callbacks.model_loaded_callback(sd_model)

     print("VAE weights loaded.")
     return sd_model
@ -1,3 +1,4 @@
+import os
 import sys

 import gradio as gr
@ -11,7 +12,7 @@ parser = shared_cmd_options.parser

 batch_cond_uncond = True  # old field, unused now in favor of shared.opts.batch_cond_uncond
 parallel_processing_allowed = True
-styles_filename = cmd_opts.styles_file
+styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
 config_filename = cmd_opts.ui_settings_file
 hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}

@ -65,3 +65,7 @@ def reload_gradio_theme(theme_name=None):
     except Exception as e:
         errors.display(e, "changing gradio theme")
         shared.gradio_theme = gr.themes.Default(**default_theme_args)
+
+    # append additional values gradio_theme
+    shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity
+    shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity
@ -18,8 +18,10 @@ def initialize():
     shared.options_templates = shared_options.options_templates
     shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
     shared.restricted_opts = shared_options.restricted_opts
-    if os.path.exists(shared.config_filename):
+    try:
         shared.opts.load(shared.config_filename)
+    except FileNotFoundError:
+        pass

     from modules import devices
     devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@ -27,6 +29,7 @@ def initialize():

     devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
     devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
+    devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype

     shared.device = devices.device
     shared.weight_load_location = None if cmd_opts.lowram else "cpu"
@ -8,6 +8,11 @@ def realesrgan_models_names():
     return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]


+def dat_models_names():
+    import modules.dat_model
+    return [x.name for x in modules.dat_model.get_dat_models(None)]
+
+
 def postprocessing_scripts():
     import modules.scripts

@ -67,14 +72,14 @@ def reload_hypernetworks():


 def get_infotext_names():
-    from modules import generation_parameters_copypaste, shared
+    from modules import infotext_utils, shared
     res = {}

     for info in shared.opts.data_labels.values():
         if info.infotext:
             res[info.infotext] = 1

-    for tab_data in generation_parameters_copypaste.paste_fields.values():
+    for tab_data in infotext_utils.paste_fields.values():
         for _, name in tab_data.get("fields") or []:
             if isinstance(name, str):
                 res[name] = 1
@ -1,7 +1,8 @@
+import os
 import gradio as gr

-from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401
+from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
+from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir  # noqa: F401
 from modules.shared_cmd_options import cmd_opts
 from modules.options import options_section, OptionInfo, OptionHTML, categories

@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"

 options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
     "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
-    "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
-    "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
-    "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
+    "outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
+    "outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
+    "outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
     "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
-    "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
-    "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
-    "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
-    "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
+    "outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
+    "outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
+    "outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
+    "outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
 }))

 options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {
@ -96,6 +97,9 @@ options_templates.update(options_section(('upscaling', "Upscaling", "postprocess
     "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
     "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
     "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
+    "dat_enabled_models": OptionInfo(["DAT x2", "DAT x3", "DAT x4"], "Select which DAT models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.dat_models_names()}),
+    "DAT_tile": OptionInfo(192, "Tile size for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"),
+    "DAT_tile_overlap": OptionInfo(8, "Tile overlap for DAT upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"),
     "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}),
 }))

@ -114,6 +118,7 @@ options_templates.update(options_section(('system', "System", "system"), {
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
+    "enable_upscale_progressbar": OptionInfo(True, "Show a progress bar in the console for tiled upscaling."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
@ -176,6 +181,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
+    "auto_vae_precision_bfloat16": OptionInfo(False, "Automatically convert VAE to bfloat16").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below"),
     "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
     "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
@ -195,6 +201,7 @@ options_templates.update(options_section(('img2img', "img2img", "sd"), {
     "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
     "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
     "img2img_batch_show_results_limit": OptionInfo(32, "Show the first N batch img2img results in UI", gr.Slider, {"minimum": -1, "maximum": 1000, "step": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),
+    "overlay_inpaint": OptionInfo(True, "Overlay original for inpaint").info("when inpainting, overlay the original image over the areas that weren't inpainted."),
 }))

 options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
@ -203,12 +210,16 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
-    "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+    "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
+    "pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; ignored if the above is set; changes seeds"),
     "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
+    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
+    "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
 }))

 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
+    "auto_backcompat": OptionInfo(True, "Automatic backward compatibility").info("automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version."),
     "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
     "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
     "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
@ -216,6 +227,7 @@ options_templates.update(options_section(('compatibility', "Compatibility", "sd"
     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
     "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"),
+    "use_downcasted_alpha_bar": OptionInfo(False, "Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.", infotext="Downcast alphas_cumprod")
 }))

 options_templates.update(options_section(('interrogate', "Interrogate"), {
@ -244,6 +256,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks", "s
     "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
     "extra_networks_card_order_field": OptionInfo("Path", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
     "extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
+    "extra_networks_tree_view_default_enabled": OptionInfo(False, "Enables the Extra Networks directory tree view by default").needs_reload_ui(),
     "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
     "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
     "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
@ -267,6 +280,8 @@ options_templates.update(options_section(('ui_gallery', "Gallery", "ui"), {
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Full page image viewer: show images zoomed in by default"),
     "js_modal_lightbox_gamepad": OptionInfo(False, "Full page image viewer: navigate with gamepad"),
     "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Full page image viewer: gamepad repeat period").info("in milliseconds"),
+    "sd_webui_modal_lightbox_icon_opacity": OptionInfo(1, "Full page image viewer: control icon unfocused opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
+    "sd_webui_modal_lightbox_toolbar_opacity": OptionInfo(0.9, "Full page image viewer: tool bar opacity", gr.Slider, {"minimum": 0.0, "maximum": 1, "step": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),
     "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("can be any valid CSS value, for example 768px or 20em").needs_reload_ui(),
 }))

@ -279,6 +294,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives",
     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
     "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(),
     "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(),
+    "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"),
 }))

 options_templates.update(options_section(('ui', "User interface", "ui"), {
@ -354,6 +370,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
     'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'),
     'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"),
     'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'),
+    'sd_noise_schedule': OptionInfo("Default", "Noise schedule for sampling", gr.Radio, {"choices": ["Default", "Zero Terminal SNR"]}, infotext="Noise Schedule").info("for use with zero terminal SNR trained models")
 }))

 options_templates.update(options_section(('postprocessing', "Postprocessing", "postprocessing"), {
@ -12,6 +12,7 @@ log = logging.getLogger(__name__)
 class State:
     skipped = False
     interrupted = False
+    stopping_generation = False
     job = ""
     job_no = 0
     job_count = 0
@ -79,6 +80,10 @@ class State:
         self.interrupted = True
         log.info("Received interrupt request")

+    def stop_generating(self):
+        self.stopping_generation = True
+        log.info("Received stop generating request")
+
     def nextjob(self):
         if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
             self.do_set_current_image()
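Note on the new flag: stop_generating() is a softer variant of interrupt(). A hedged sketch of how a caller could consume it (the helper below is assumed wiring for illustration, not code from this diff):

    from modules import shared

    def should_stop_between_images() -> bool:
        # Assumed usage: polled between images, while interrupt() still aborts mid-step.
        return shared.state.interrupted or shared.state.stopping_generation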
@ -91,6 +96,7 @@ class State:
         obj = {
             "skipped": self.skipped,
             "interrupted": self.interrupted,
+            "stopping_generation": self.stopping_generation,
             "job": self.job,
             "job_count": self.job_count,
             "job_timestamp": self.job_timestamp,
@ -114,6 +120,7 @@ class State:
         self.id_live_preview = 0
         self.skipped = False
         self.interrupted = False
+        self.stopping_generation = False
         self.textinfo = None
         self.job = job
         devices.torch_gc()
@ -1,16 +1,15 @@
+from pathlib import Path
 import csv
-import fnmatch
 import os
-import os.path
 import typing
 import shutil


 class PromptStyle(typing.NamedTuple):
     name: str
-    prompt: str
-    negative_prompt: str
-    path: str = None
+    prompt: str | None
+    negative_prompt: str | None
+    path: str | None = None


 def merge_prompts(style_prompt: str, prompt: str) -> str:
@ -30,38 +29,29 @@ def apply_styles_to_prompt(prompt, styles):
     return prompt


-def unwrap_style_text_from_prompt(style_text, prompt):
-    """
-    Checks the prompt to see if the style text is wrapped around it. If so,
-    returns True plus the prompt text without the style text. Otherwise, returns
-    False with the original prompt.
+def extract_style_text_from_prompt(style_text, prompt):
+    """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.

-    Note that the "cleaned" version of the style text is only used for matching
-    purposes here. It isn't returned; the original style text is not modified.
+    extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
+    extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
+    extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
     """
-    stripped_prompt = prompt
-    stripped_style_text = style_text
+    stripped_prompt = prompt.strip()
+    stripped_style_text = style_text.strip()

     if "{prompt}" in stripped_style_text:
-        # Work out whether the prompt is wrapped in the style text. If so, we
-        # return True and the "inner" prompt text that isn't part of the style.
-        try:
         left, right = stripped_style_text.split("{prompt}", 2)
-        except ValueError as e:
-            # If the style text has multple "{prompt}"s, we can't split it into
-            # two parts. This is an error, but we can't do anything about it.
-            print(f"Unable to compare style text to prompt:\n{style_text}")
-            print(f"Error: {e}")
-            return False, prompt
         if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
-            prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
+            prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
             return True, prompt
     else:
-        # Work out whether the given prompt ends with the style text. If so, we
-        # return True and the prompt text up to where the style text starts.
         if stripped_prompt.endswith(stripped_style_text):
-            prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
-            if prompt.endswith(", "):
+            prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
+            if prompt.endswith(', '):
                 prompt = prompt[:-2]

             return True, prompt

     return False, prompt
@ -76,15 +66,11 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
     if not style.prompt and not style.negative_prompt:
         return False, prompt, negative_prompt

-    match_positive, extracted_positive = unwrap_style_text_from_prompt(
-        style.prompt, prompt
-    )
+    match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
     if not match_positive:
         return False, prompt, negative_prompt

-    match_negative, extracted_negative = unwrap_style_text_from_prompt(
-        style.negative_prompt, negative_prompt
-    )
+    match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
     if not match_negative:
         return False, prompt, negative_prompt

@ -92,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):


 class StyleDatabase:
-    def __init__(self, path: str):
+    def __init__(self, paths: list[str | Path]):
         self.no_style = PromptStyle("None", "", "", None)
         self.styles = {}
-        self.path = path
+        self.paths = paths
+        self.all_styles_files: list[Path] = []

-        folder, file = os.path.split(self.path)
-        filename, _, ext = file.partition('*')
-        self.default_path = os.path.join(folder, filename + ext)
+        folder, file = os.path.split(self.paths[0])
+        if '*' in file or '?' in file:
+            # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
+            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
+            self.paths.insert(0, self.default_path)
+        else:
+            self.default_path = Path(self.paths[0])

         self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

@ -112,33 +103,31 @@ class StyleDatabase:
         """
         self.styles.clear()

-        path, filename = os.path.split(self.path)
-
-        if "*" in filename:
-            fileglob = filename.split("*")[0] + "*.csv"
-            filelist = []
-            for file in os.listdir(path):
-                if fnmatch.fnmatch(file, fileglob):
-                    filelist.append(file)
-                    # Add a visible divider to the style list
-                    half_len = round(len(file) / 2)
-                    divider = f"{'-' * (20 - half_len)} {file.upper()}"
-                    divider = f"{divider} {'-' * (40 - len(divider))}"
-                    self.styles[divider] = PromptStyle(
-                        f"{divider}", None, None, "do_not_save"
-                    )
-                    # Add styles from this CSV file
-                    self.load_from_csv(os.path.join(path, file))
-            if len(filelist) == 0:
-                print(f"No styles found in {path} matching {fileglob}")
-                return
-        elif not os.path.exists(self.path):
-            print(f"Style database not found: {self.path}")
-            return
-        else:
-            self.load_from_csv(self.path)
-
-    def load_from_csv(self, path: str):
+        # scans for all styles files
+        all_styles_files = []
+        for pattern in self.paths:
+            folder, file = os.path.split(pattern)
+            if '*' in file or '?' in file:
+                found_files = Path(folder).glob(file)
+                [all_styles_files.append(file) for file in found_files]
+            else:
+                # if os.path.exists(pattern):
+                all_styles_files.append(Path(pattern))
+
+        # Remove any duplicate entries
+        seen = set()
+        self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
+
+        for styles_file in self.all_styles_files:
+            if len(all_styles_files) > 1:
+                # add divider when more than styles file
+                # '---------------- STYLES ----------------'
+                divider = f' {styles_file.stem.upper()} '.center(40, '-')
+                self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
+            if styles_file.is_file():
+                self.load_from_csv(styles_file)
+
+    def load_from_csv(self, path: str | Path):
         with open(path, "r", encoding="utf-8-sig", newline="") as file:
             reader = csv.DictReader(file, skipinitialspace=True)
             for row in reader:
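Note on the reworked StyleDatabase: it now accepts a list of paths and wildcard patterns instead of a single CSV, and the reload logic above expands patterns, drops duplicates, and adds a divider entry per extra file. A small illustration follows (the file names are hypothetical examples, not from the diff):

    from modules.styles import StyleDatabase

    # Wildcards are expanded and duplicates removed before each existing CSV is loaded.
    styles = StyleDatabase(["styles.csv", "extra-styles/*.csv"])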
@ -150,7 +139,7 @@ class StyleDatabase:
                 negative_prompt = row.get("negative_prompt", "")
                 # Add style to database
                 self.styles[row["name"]] = PromptStyle(
-                    row["name"], prompt, negative_prompt, path
+                    row["name"], prompt, negative_prompt, str(path)
                 )

     def get_style_paths(self) -> set:
@ -158,11 +147,11 @@ class StyleDatabase:
         # Update any styles without a path to the default path
         for style in list(self.styles.values()):
             if not style.path:
-                self.styles[style.name] = style._replace(path=self.default_path)
+                self.styles[style.name] = style._replace(path=str(self.default_path))

         # Create a list of all distinct paths, including the default path
         style_paths = set()
-        style_paths.add(self.default_path)
+        style_paths.add(str(self.default_path))
         for _, style in self.styles.items():
             if style.path:
                 style_paths.add(style.path)
@ -190,7 +179,6 @@ class StyleDatabase:

     def save_styles(self, path: str = None) -> None:
         # The path argument is deprecated, but kept for backwards compatibility
-        _ = path

         style_paths = self.get_style_paths()

@ -24,13 +24,13 @@ environment_whitelist = {
     "XFORMERS_PACKAGE",
     "CLIP_PACKAGE",
     "OPENCLIP_PACKAGE",
+    "ASSETS_REPO",
     "STABLE_DIFFUSION_REPO",
     "K_DIFFUSION_REPO",
-    "CODEFORMER_REPO",
     "BLIP_REPO",
+    "ASSETS_COMMIT_HASH",
     "STABLE_DIFFUSION_COMMIT_HASH",
     "K_DIFFUSION_COMMIT_HASH",
-    "CODEFORMER_COMMIT_HASH",
     "BLIP_COMMIT_HASH",
     "COMMANDLINE_ARGS",
     "IGNORE_CMD_ARGS_ERRORS",
@ -11,7 +11,6 @@ import safetensors.torch

 import numpy as np
 from PIL import Image, PngImagePlugin
-from torch.utils.tensorboard import SummaryWriter

 from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
@ -348,6 +347,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
     })

 def tensorboard_setup(log_directory):
+    from torch.utils.tensorboard import SummaryWriter
     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
     return SummaryWriter(
         log_dir=os.path.join(log_directory, "tensorboard"),
@ -452,8 +452,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     old_parallel_processing_allowed = shared.parallel_processing_allowed

+    tensorboard_writer = None
     if shared.opts.training_enable_tensorboard:
+        try:
             tensorboard_writer = tensorboard_setup(log_directory)
+        except ImportError:
+            errors.report("Error initializing tensorboard", exc_info=True)

     pin_memory = shared.opts.pin_memory

@ -626,7 +630,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                 last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                 last_saved_image += f", prompt: {preview_text}"

-                if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+                if tensorboard_writer and shared.opts.training_tensorboard_save_images:
                     tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

             if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
17 modules/torch_utils.py Normal file
@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+    """
+    Find the first parameter in a model or module.
+    """
+    if hasattr(model, "model") and hasattr(model.model, "parameters"):
+        # Unpeel a model descriptor to get at the actual Torch module.
+        model = model.model
+
+    for param in model.parameters():
+        return param
+
+    raise ValueError(f"No parameters found in model {model!r}")
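Note on modules/torch_utils.py: get_param() is a small helper for peeking at a module's first parameter. A hedged usage sketch (the wrapper below is illustrative only, not part of the repository):

    from modules import torch_utils

    def module_device_and_dtype(module):
        # The first parameter is enough to tell where the module lives and its precision.
        p = torch_utils.get_param(module)
        return p.device, p.dtype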
@ -1,17 +1,22 @@
+import json
 from contextlib import closing

 import modules.scripts
-from modules import processing
-from modules.generation_parameters_copypaste import create_override_settings_dict
+from modules import processing, infotext_utils
+from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
 from modules.shared import opts
 import modules.shared as shared
 from modules.ui import plaintext_to_html
+from PIL import Image
 import gradio as gr


-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
     override_settings = create_override_settings_dict(override_settings_texts)

+    if force_enable_hr:
+        enable_hr = True
+
     p = processing.StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@ -27,7 +32,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
enable_hr=enable_hr,
|
enable_hr=enable_hr,
|
||||||
denoising_strength=denoising_strength if enable_hr else None,
|
denoising_strength=denoising_strength,
|
||||||
hr_scale=hr_scale,
|
hr_scale=hr_scale,
|
||||||
hr_upscaler=hr_upscaler,
|
hr_upscaler=hr_upscaler,
|
||||||
hr_second_pass_steps=hr_second_pass_steps,
|
hr_second_pass_steps=hr_second_pass_steps,
|
||||||
@ -48,8 +53,58 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
if shared.opts.enable_console_prompts:
|
if shared.opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
|
||||||
|
assert len(gallery) > 0, 'No image to upscale'
|
||||||
|
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
|
||||||
|
|
||||||
|
p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)
|
||||||
|
p.batch_size = 1
|
||||||
|
p.n_iter = 1
|
||||||
|
# txt2img_upscale attribute that signifies this is called by txt2img_upscale
|
||||||
|
p.txt2img_upscale = True
|
||||||
|
|
||||||
|
geninfo = json.loads(generation_info)
|
||||||
|
|
||||||
|
image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
|
||||||
|
p.firstpass_image = infotext_utils.image_from_url_text(image_info)
|
||||||
|
|
||||||
|
parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])
|
||||||
|
p.seed = parameters.get('Seed', -1)
|
||||||
|
p.subseed = parameters.get('Variation seed', -1)
|
||||||
|
|
||||||
|
p.override_settings['save_images_before_highres_fix'] = False
|
||||||
|
|
||||||
with closing(p):
|
with closing(p):
|
||||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
|
||||||
|
|
||||||
|
if processed is None:
|
||||||
|
processed = processing.process_images(p)
|
||||||
|
|
||||||
|
shared.total_tqdm.clear()
|
||||||
|
|
||||||
|
new_gallery = []
|
||||||
|
for i, image in enumerate(gallery):
|
||||||
|
if i == gallery_index:
|
||||||
|
geninfo["infotexts"][gallery_index: gallery_index+1] = processed.infotexts
|
||||||
|
new_gallery.extend(processed.images)
|
||||||
|
else:
|
||||||
|
fake_image = Image.new(mode="RGB", size=(1, 1))
|
||||||
|
fake_image.already_saved_as = image["name"].rsplit('?', 1)[0]
|
||||||
|
new_gallery.append(fake_image)
|
||||||
|
|
||||||
|
geninfo["infotexts"][gallery_index] = processed.info
|
||||||
|
|
||||||
|
return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname="comments")
|
||||||
|
|
||||||
|
|
||||||
|
def txt2img(id_task: str, request: gr.Request, *args):
|
||||||
|
p = txt2img_create_processing(id_task, request, *args)
|
||||||
|
|
||||||
|
with closing(p):
|
||||||
|
processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)
|
||||||
|
|
||||||
if processed is None:
|
if processed is None:
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
|
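The new txt2img_upscale handler re-seeds the hires pass from the stored infotext rather than from live UI state. A rough sketch of how parse_generation_parameters supplies those fields (the infotext string is invented for illustration):

    from modules.infotext_utils import parse_generation_parameters

    infotext = "a castle on a hill\nSteps: 20, Sampler: Euler a, CFG scale: 7, Seed: 12345, Size: 512x512"

    params = parse_generation_parameters(infotext, [])
    seed = params.get('Seed', -1)               # used to re-seed the hires pass
    subseed = params.get('Variation seed', -1)  # falls back to -1 when absent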
131  modules/ui.py
@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

 from modules import gradio_extensons  # noqa: F401
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button

@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript

 from modules.shared import opts, cmd_opts

-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext_utils as parameters_copypaste
 import modules.hypernetworks.ui as hypernetworks_ui
 import modules.textual_inversion.ui as textual_inversion_ui
 import modules.textual_inversion.textual_inversion as textual_inversion
 import modules.shared as shared
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext_utils import image_from_url_text, PasteField

 create_setting_component = ui_settings.create_setting_component

@ -177,7 +177,6 @@ def update_negative_prompt_token_counter(text, steps):
     return update_token_counter(text, steps, is_positive=False)

-

 def setup_progressbar(*args, **kwargs):
     pass

@ -267,7 +266,7 @@ def create_ui():

     dummy_component = gr.Label(visible=False)

-    extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
+    extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs", elem_classes=["extra-networks"])
     extra_tabs.__enter__()

     with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):

@ -376,12 +375,9 @@ def create_ui():
             show_progress=False,
         )

-        txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+        output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)

-        txt2img_args = dict(
-            fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
-            _js="submit",
-            inputs=[
+        txt2img_inputs = [
             dummy_component,
             toprow.prompt,
             toprow.negative_prompt,

@ -405,21 +401,34 @@ def create_ui():
             hr_prompt,
             hr_negative_prompt,
             override_settings,
-        ] + custom_inputs,
+        ] + custom_inputs

-        outputs=[
-            txt2img_gallery,
-            generation_info,
-            html_info,
-            html_log,
-        ],
+        txt2img_outputs = [
+            output_panel.gallery,
+            output_panel.generation_info,
+            output_panel.infotext,
+            output_panel.html_log,
+        ]
+
+        txt2img_args = dict(
+            fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
+            _js="submit",
+            inputs=txt2img_inputs,
+            outputs=txt2img_outputs,
             show_progress=False,
         )

         toprow.prompt.submit(**txt2img_args)
         toprow.submit.click(**txt2img_args)

+        output_panel.button_upscale.click(
+            fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
+            _js="submit_txt2img_upscale",
+            inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
+            outputs=txt2img_outputs,
+            show_progress=False,
+        )
+
         res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

         toprow.restore_progress_button.click(

@ -427,37 +436,37 @@ def create_ui():
             _js="restoreProgressTxt2img",
             inputs=[dummy_component],
             outputs=[
-                txt2img_gallery,
-                generation_info,
-                html_info,
-                html_log,
+                output_panel.gallery,
+                output_panel.generation_info,
+                output_panel.infotext,
+                output_panel.html_log,
             ],
             show_progress=False,
         )

         txt2img_paste_fields = [
-            (toprow.prompt, "Prompt"),
-            (toprow.negative_prompt, "Negative prompt"),
-            (steps, "Steps"),
-            (sampler_name, "Sampler"),
-            (cfg_scale, "CFG scale"),
-            (width, "Size-1"),
-            (height, "Size-2"),
-            (batch_size, "Batch size"),
-            (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
-            (denoising_strength, "Denoising strength"),
-            (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
-            (hr_scale, "Hires upscale"),
-            (hr_upscaler, "Hires upscaler"),
-            (hr_second_pass_steps, "Hires steps"),
-            (hr_resize_x, "Hires resize-1"),
-            (hr_resize_y, "Hires resize-2"),
-            (hr_checkpoint_name, "Hires checkpoint"),
-            (hr_sampler_name, "Hires sampler"),
-            (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
-            (hr_prompt, "Hires prompt"),
-            (hr_negative_prompt, "Hires negative prompt"),
-            (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
+            PasteField(toprow.prompt, "Prompt", api="prompt"),
+            PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
+            PasteField(steps, "Steps", api="steps"),
+            PasteField(sampler_name, "Sampler", api="sampler_name"),
+            PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
+            PasteField(width, "Size-1", api="width"),
+            PasteField(height, "Size-2", api="height"),
+            PasteField(batch_size, "Batch size", api="batch_size"),
+            PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
+            PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
+            PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
+            PasteField(hr_scale, "Hires upscale", api="hr_scale"),
+            PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
+            PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
+            PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
+            PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
+            PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
+            PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
+            PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
+            PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
+            PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
+            PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
             *scripts.scripts_txt2img.infotext_fields
         ]
         parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)

@ -480,7 +489,7 @@ def create_ui():
         toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])

         extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
-        ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+        ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)

         extra_tabs.__exit__()

@ -490,7 +499,7 @@ def create_ui():
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
         toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)

-        extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
+        extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs", elem_classes=["extra-networks"])
         extra_tabs.__enter__()

         with gr.Tab("Generation", id="img2img_generation") as img2img_generation_tab, ResizeHandleRow(equal_height=False):

@ -523,7 +532,7 @@ def create_ui():

                 if category == "image":
                     with gr.Tabs(elem_id="mode_img2img"):
-                        img2img_selected_tab = gr.State(0)
+                        img2img_selected_tab = gr.Number(value=0, visible=False)

                         with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                             init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)

@ -604,7 +613,7 @@ def create_ui():
                 elif category == "dimensions":
                     with FormRow():
                         with gr.Column(elem_id="img2img_column_size", scale=4):
-                            selected_scale_tab = gr.State(value=0)
+                            selected_scale_tab = gr.Number(value=0, visible=False)

                            with gr.Tabs():
                                with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:

@ -711,7 +720,7 @@ def create_ui():
                 outputs=[inpaint_controls, mask_alpha],
             )

-            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
+            output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)

             img2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),

@ -756,10 +765,10 @@ def create_ui():
                     img2img_batch_png_info_dir,
                 ] + custom_inputs,
                 outputs=[
-                    img2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
+                    output_panel.gallery,
+                    output_panel.generation_info,
+                    output_panel.infotext,
+                    output_panel.html_log,
                 ],
                 show_progress=False,
             )

@ -797,10 +806,10 @@ def create_ui():
                 _js="restoreProgressImg2img",
                 inputs=[dummy_component],
                 outputs=[
-                    img2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
+                    output_panel.gallery,
+                    output_panel.generation_info,
+                    output_panel.infotext,
+                    output_panel.html_log,
                 ],
                 show_progress=False,
             )

@ -831,6 +840,10 @@ def create_ui():
             (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
             (denoising_strength, "Denoising strength"),
             (mask_blur, "Mask blur"),
+            (inpainting_mask_invert, 'Mask mode'),
+            (inpainting_fill, 'Masked content'),
+            (inpaint_full_res, 'Inpaint area'),
+            (inpaint_full_res_padding, 'Masked area padding'),
             *scripts.scripts_img2img.infotext_fields
         ]
         parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)

@ -840,7 +853,7 @@ def create_ui():
         ))

         extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
-        ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+        ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)

         extra_tabs.__exit__()

@ -1086,6 +1099,7 @@ def create_ui():
     )

     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
+    ui_settings_from_file = loadsave.ui_settings.copy()

     settings = ui_settings.UiSettings()
     settings.create_ui(loadsave, dummy_component)

@ -1146,6 +1160,7 @@ def create_ui():

     modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])

-    loadsave.dump_defaults()
+    if ui_settings_from_file != loadsave.ui_settings:
+        loadsave.dump_defaults()
     demo.ui_loadsave = loadsave

@ -1208,3 +1223,5 @@ def setup_ui_api(app):
     app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"])
     app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"])

+    import fastapi.staticfiles
+    app.mount("/webui-assets", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name="webui-assets")
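The final hunk above serves a static-assets repository through the web app. A minimal, self-contained sketch of the same FastAPI pattern (directory name is an example):

    from fastapi import FastAPI
    from fastapi.staticfiles import StaticFiles

    app = FastAPI()
    # Files under ./webui-assets become available under the /webui-assets URL prefix.
    app.mount("/webui-assets", StaticFiles(directory="webui-assets"), name="webui-assets")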
@ -1,3 +1,5 @@
+import csv
+import dataclasses
 import json
 import html
 import os

@ -8,10 +10,10 @@ import gradio as gr
 import subprocess as sp

 from modules import call_queue, shared
-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext_utils import image_from_url_text
 import modules.images
 from modules.ui_components import ToolButton
-import modules.generation_parameters_copypaste as parameters_copypaste
+import modules.infotext_utils as parameters_copypaste

 folder_symbol = '\U0001f4c2'  # 📂
 refresh_symbol = '\U0001f504'  # 🔄

@ -35,12 +37,38 @@ def plaintext_to_html(text, classname=None):
     return f"<p class='{classname}'>{content}</p>" if classname else f"<p>{content}</p>"


+def update_logfile(logfile_path, fields):
+    """Update a logfile from old format to new format to maintain CSV integrity."""
+    with open(logfile_path, "r", encoding="utf8", newline="") as file:
+        reader = csv.reader(file)
+        rows = list(reader)
+
+    # blank file: leave it as is
+    if not rows:
+        return
+
+    # file is already synced, do nothing
+    if len(rows[0]) == len(fields):
+        return
+
+    rows[0] = fields
+
+    # append new fields to each row as empty values
+    for row in rows[1:]:
+        while len(row) < len(fields):
+            row.append("")
+
+    with open(logfile_path, "w", encoding="utf8", newline="") as file:
+        writer = csv.writer(file)
+        writer.writerows(rows)
+
+
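A small sketch of what update_logfile does to an existing log.csv (file name and row contents invented): the header is rewritten to the new field list and old rows are padded with empty cells so the column count stays consistent.

    import csv

    from modules.ui_common import update_logfile

    fields = ["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt", "sd_model_name", "sd_model_hash"]

    # Write a log in the old 9-column format.
    with open("log.csv", "w", encoding="utf8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(fields[:9])
        writer.writerow(["a cat", "1", "512", "512", "Euler a", "7", "20", "img.png", ""])

    update_logfile("log.csv", fields)
    # log.csv now has the 11-column header and the data row gained two empty trailing cells.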
 def save_files(js_data, images, do_make_zip, index):
-    import csv
     filenames = []
     fullfns = []
+    parsed_infotexts = []

-    #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+    # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
     class MyObject:
         def __init__(self, d=None):
             if d is not None:

@ -48,35 +76,55 @@ def save_files(js_data, images, do_make_zip, index):
                     setattr(self, key, value)

     data = json.loads(js_data)

     p = MyObject(data)

     path = shared.opts.outdir_save
     save_to_dirs = shared.opts.use_save_to_dirs_for_ui
     extension: str = shared.opts.samples_format
     start_index = 0
-    only_one = False

     if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-        only_one = True
         images = [images[index]]
         start_index = index

     os.makedirs(shared.opts.outdir_save, exist_ok=True)

-    with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+    fields = [
+        "prompt",
+        "seed",
+        "width",
+        "height",
+        "sampler",
+        "cfgs",
+        "steps",
+        "filename",
+        "negative_prompt",
+        "sd_model_name",
+        "sd_model_hash",
+    ]
+    logfile_path = os.path.join(shared.opts.outdir_save, "log.csv")
+
+    # NOTE: ensure csv integrity when fields are added by
+    # updating headers and padding with delimeters where needed
+    if os.path.exists(logfile_path):
+        update_logfile(logfile_path, fields)
+
+    with open(logfile_path, "a", encoding="utf8", newline='') as file:
         at_start = file.tell() == 0
         writer = csv.writer(file)
         if at_start:
-            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+            writer.writerow(fields)

         for image_index, filedata in enumerate(images, start_index):
             image = image_from_url_text(filedata)

             is_grid = image_index < p.index_of_first_image
-            i = 0 if is_grid else (image_index - p.index_of_first_image)

             p.batch_index = image_index-1
-            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+            parameters = parameters_copypaste.parse_generation_parameters(data["infotexts"][image_index], [])
+            parsed_infotexts.append(parameters)
+            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)

             filename = os.path.relpath(fullfn, path)
             filenames.append(filename)

@ -85,12 +133,12 @@ def save_files(js_data, images, do_make_zip, index):
                 filenames.append(os.path.basename(txt_fullfn))
                 fullfns.append(txt_fullfn)

-        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+        writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], parsed_infotexts[0]['Negative prompt'], data["sd_model_name"], data["sd_model_hash"]])

     # Make Zip
     if do_make_zip:
-        zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0]
-        namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True)
+        p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]
+        namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)
         zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")
         zip_filepath = os.path.join(path, f"{zip_filename}.zip")

@ -104,7 +152,17 @@ def save_files(js_data, images, do_make_zip, index):
     return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")


+@dataclasses.dataclass
+class OutputPanel:
+    gallery = None
+    generation_info = None
+    infotext = None
+    html_log = None
+    button_upscale = None
+
+
 def create_output_panel(tabname, outdir, toprow=None):
+    res = OutputPanel()
+
     def open_folder(f):
         if not os.path.exists(f):

@ -136,9 +194,8 @@ Requested path was: {f}

     with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
         with gr.Group(elem_id=f"{tabname}_gallery_container"):
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+            res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)

-        generation_info = None
         with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
             open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")

@ -152,6 +209,9 @@ Requested path was: {f}
                 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
             }

+            if tabname == 'txt2img':
+                res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")
+
         open_folder_button.click(
             fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
             inputs=[],

@ -162,17 +222,17 @@ Requested path was: {f}
             download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')

         with gr.Group():
-            html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
-            html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
+            res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+            res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")

-            generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+            res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
             if tabname == 'txt2img' or tabname == 'img2img':
                 generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                 generation_info_button.click(
                     fn=update_generation_info,
                     _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
-                    inputs=[generation_info, html_info, html_info],
-                    outputs=[html_info, html_info],
+                    inputs=[res.generation_info, res.infotext, res.infotext],
+                    outputs=[res.infotext, res.infotext],
                     show_progress=False,
                 )

@ -180,14 +240,14 @@ Requested path was: {f}
                 fn=call_queue.wrap_gradio_call(save_files),
                 _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                 inputs=[
-                    generation_info,
-                    result_gallery,
-                    html_info,
-                    html_info,
+                    res.generation_info,
+                    res.gallery,
+                    res.infotext,
+                    res.infotext,
                 ],
                 outputs=[
                     download_files,
-                    html_log,
+                    res.html_log,
                 ],
                 show_progress=False,
             )

@ -196,21 +256,21 @@ Requested path was: {f}
                 fn=call_queue.wrap_gradio_call(save_files),
                 _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                 inputs=[
-                    generation_info,
-                    result_gallery,
-                    html_info,
-                    html_info,
+                    res.generation_info,
+                    res.gallery,
+                    res.infotext,
+                    res.infotext,
                 ],
                 outputs=[
                     download_files,
-                    html_log,
+                    res.html_log,
                 ]
             )

         else:
-            html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
-            html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
-            html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+            res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
+            res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+            res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')

         paste_field_names = []
         if tabname == "txt2img":

@ -220,11 +280,11 @@ Requested path was: {f}

         for paste_tabname, paste_button in buttons.items():
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
+                paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery,
                 paste_field_names=paste_field_names
             ))

-    return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+    return res


 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
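create_output_panel now returns a single OutputPanel object rather than a 4-tuple, so callers use named attributes. A hypothetical caller-side sketch (tab name and output directory are examples; the call has to happen inside a gr.Blocks context):

    import gradio as gr

    from modules import ui_common

    with gr.Blocks():
        panel = ui_common.create_output_panel("txt2img", "outputs/txt2img-images")

    outputs = [panel.gallery, panel.generation_info, panel.infotext, panel.html_log]
    if panel.button_upscale is not None:
        print("this tab exposes the hires-fix upscale button")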
@ -2,23 +2,22 @@ import functools
 import os.path
 import urllib.parse
 from pathlib import Path
+from typing import Optional, Union
+from dataclasses import dataclass

-from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util
 from modules.images import read_info_from_image, save_image_with_geninfo
 import gradio as gr
 import json
 import html
 from fastapi.exceptions import HTTPException

-from modules.generation_parameters_copypaste import image_from_url_text
+from modules.infotext_utils import image_from_url_text
-from modules.ui_components import ToolButton

 extra_pages = []
 allowed_dirs = set()

 default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]


 @functools.cache
 def allowed_preview_extensions_with_extra(extra_extensions=None):
     return set(default_allowed_preview_extensions) | set(extra_extensions or [])

@ -28,6 +27,62 @@ def allowed_preview_extensions():
     return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))


+@dataclass
+class ExtraNetworksItem:
+    """Wrapper for dictionaries representing ExtraNetworks items."""
+    item: dict
+
+
+def get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict:
+    """Recursively builds a directory tree.
+
+    Args:
+        paths: Path or list of paths to directories. These paths are treated as roots from which
+            the tree will be built.
+        items: A dictionary associating filepaths to an ExtraNetworksItem instance.
+
+    Returns:
+        The result directory tree.
+    """
+    if isinstance(paths, (str,)):
+        paths = [paths]
+
+    def _get_tree(_paths: list[str], _root: str):
+        _res = {}
+        for path in _paths:
+            relpath = os.path.relpath(path, _root)
+            if os.path.isdir(path):
+                dir_items = os.listdir(path)
+                # Ignore empty directories.
+                if not dir_items:
+                    continue
+                dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root)
+                # We only want to store non-empty folders in the tree.
+                if dir_tree:
+                    _res[relpath] = dir_tree
+            else:
+                if path not in items:
+                    continue
+                # Add the ExtraNetworksItem to the result.
+                _res[relpath] = items[path]
+        return _res
+
+    res = {}
+    # Handle each root directory separately.
+    # Each root WILL have a key/value at the root of the result dict though
+    # the value can be an empty dict if the directory is empty. We want these
+    # placeholders for empty dirs so we can inform the user later.
+    for path in paths:
+        root = os.path.dirname(path)
+        relpath = os.path.relpath(path, root)
+        # Wrap the path in a list since that is what the `_get_tree` expects.
+        res[relpath] = _get_tree([path], root)
+        if res[relpath]:
+            # We need to pull the inner path out one for these root dirs.
+            res[relpath] = res[relpath][relpath]
+
+    return res
+
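An illustrative call of the new get_tree helper (paths and items invented; the directories must actually exist on disk for os.path.isdir to see them): keys are paths relative to the parent of each root, and ExtraNetworksItem wrappers sit at the leaves.

    items = {
        "models/Lora/style/a.safetensors": ExtraNetworksItem({"name": "a"}),
        "models/Lora/b.safetensors": ExtraNetworksItem({"name": "b"}),
    }

    tree = get_tree("models/Lora", items)
    # Roughly:
    # {
    #     "Lora": {
    #         "Lora/style": {"Lora/style/a.safetensors": <ExtraNetworksItem>},
    #         "Lora/b.safetensors": <ExtraNetworksItem>,
    #     }
    # }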
def register_page(page):
|
def register_page(page):
|
||||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||||
|
|
||||||
@ -80,7 +135,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
|
|||||||
item = page.items.get(name)
|
item = page.items.get(name)
|
||||||
|
|
||||||
page.read_user_metadata(item)
|
page.read_user_metadata(item)
|
||||||
item_html = page.create_html_for_item(item, tabname)
|
item_html = page.create_item_html(tabname, item)
|
||||||
|
|
||||||
return JSONResponse({"html": item_html})
|
return JSONResponse({"html": item_html})
|
||||||
|
|
||||||
@ -96,24 +151,31 @@ def quote_js(s):
|
|||||||
s = s.replace('"', '\\"')
|
s = s.replace('"', '\\"')
|
||||||
return f'"{s}"'
|
return f'"{s}"'
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPage:
|
class ExtraNetworksPage:
|
||||||
def __init__(self, title):
|
def __init__(self, title):
|
||||||
self.title = title
|
self.title = title
|
||||||
self.name = title.lower()
|
self.name = title.lower()
|
||||||
self.id_page = self.name.replace(" ", "_")
|
# This is the actual name of the extra networks tab (not txt2img/img2img).
|
||||||
self.card_page = shared.html("extra-networks-card.html")
|
self.extra_networks_tabname = self.name.replace(" ", "_")
|
||||||
self.allow_prompt = True
|
self.allow_prompt = True
|
||||||
self.allow_negative_prompt = False
|
self.allow_negative_prompt = False
|
||||||
self.metadata = {}
|
self.metadata = {}
|
||||||
self.items = {}
|
self.items = {}
|
||||||
|
self.lister = util.MassFileLister()
|
||||||
|
# HTML Templates
|
||||||
|
self.pane_tpl = shared.html("extra-networks-pane.html")
|
||||||
|
self.card_tpl = shared.html("extra-networks-card.html")
|
||||||
|
self.btn_tree_tpl = shared.html("extra-networks-tree-button.html")
|
||||||
|
self.btn_copy_path_tpl = shared.html("extra-networks-copy-path-button.html")
|
||||||
|
self.btn_metadata_tpl = shared.html("extra-networks-metadata-button.html")
|
||||||
|
self.btn_edit_item_tpl = shared.html("extra-networks-edit-item-button.html")
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def read_user_metadata(self, item):
|
def read_user_metadata(self, item):
|
||||||
filename = item.get("filename", None)
|
filename = item.get("filename", None)
|
||||||
metadata = extra_networks.get_user_metadata(filename)
|
metadata = extra_networks.get_user_metadata(filename, lister=self.lister)
|
||||||
|
|
||||||
desc = metadata.get("description", None)
|
desc = metadata.get("description", None)
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
@ -123,117 +185,74 @@ class ExtraNetworksPage:
|
|||||||
|
|
||||||
def link_preview(self, filename):
|
def link_preview(self, filename):
|
||||||
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
|
||||||
mtime = os.path.getmtime(filename)
|
mtime, _ = self.lister.mctime(filename)
|
||||||
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
|
||||||
|
|
||||||
def search_terms_from_path(self, filename, possible_directories=None):
|
def search_terms_from_path(self, filename, possible_directories=None):
|
||||||
abspath = os.path.abspath(filename)
|
abspath = os.path.abspath(filename)
|
||||||
|
|
||||||
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
for parentdir in (possible_directories if possible_directories is not None else self.allowed_directories_for_previews()):
|
||||||
parentdir = os.path.abspath(parentdir)
|
parentdir = os.path.dirname(os.path.abspath(parentdir))
|
||||||
if abspath.startswith(parentdir):
|
if abspath.startswith(parentdir):
|
||||||
return abspath[len(parentdir):].replace('\\', '/')
|
return os.path.relpath(abspath, parentdir)
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def create_html(self, tabname):
|
def create_item_html(
|
||||||
items_html = ''
|
self,
|
||||||
|
tabname: str,
|
||||||
|
item: dict,
|
||||||
|
template: Optional[str] = None,
|
||||||
|
) -> Union[str, dict]:
|
||||||
|
"""Generates HTML for a single ExtraNetworks Item.
|
||||||
|
|
||||||
self.metadata = {}
|
Args:
|
||||||
|
tabname: The name of the active tab.
|
||||||
|
item: Dictionary containing item information.
|
||||||
|
template: Optional template string to use.
|
||||||
|
|
||||||
subdirs = {}
|
Returns:
|
||||||
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
|
If a template is passed: HTML string generated for this item.
|
||||||
for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
|
Can be empty if the item is not meant to be shown.
|
||||||
for dirname in sorted(dirs, key=shared.natural_sort_key):
|
If no template is passed: A dictionary containing the generated item's attributes.
|
||||||
x = os.path.join(root, dirname)
|
|
||||||
|
|
||||||
if not os.path.isdir(x):
|
|
||||||
continue
|
|
||||||
|
|
||||||
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
|
|
||||||
|
|
||||||
if shared.opts.extra_networks_dir_button_function:
|
|
||||||
if not subdir.startswith("/"):
|
|
||||||
subdir = "/" + subdir
|
|
||||||
else:
|
|
||||||
while subdir.startswith("/"):
|
|
||||||
subdir = subdir[1:]
|
|
||||||
|
|
||||||
is_empty = len(os.listdir(x)) == 0
|
|
||||||
if not is_empty and not subdir.endswith("/"):
|
|
||||||
subdir = subdir + "/"
|
|
||||||
|
|
||||||
if ("/." in subdir or subdir.startswith(".")) and not shared.opts.extra_networks_show_hidden_directories:
|
|
||||||
continue
|
|
||||||
|
|
||||||
subdirs[subdir] = 1
|
|
||||||
|
|
||||||
if subdirs:
|
|
||||||
subdirs = {"": 1, **subdirs}
|
|
||||||
|
|
||||||
subdirs_html = "".join([f"""
|
|
||||||
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'>
|
|
||||||
{html.escape(subdir if subdir!="" else "all")}
|
|
||||||
</button>
|
|
||||||
""" for subdir in subdirs])
|
|
||||||
|
|
||||||
self.items = {x["name"]: x for x in self.list_items()}
|
|
||||||
for item in self.items.values():
|
|
||||||
metadata = item.get("metadata")
|
|
||||||
if metadata:
|
|
||||||
self.metadata[item["name"]] = metadata
|
|
||||||
|
|
||||||
if "user_metadata" not in item:
|
|
||||||
self.read_user_metadata(item)
|
|
||||||
|
|
||||||
items_html += self.create_html_for_item(item, tabname)
|
|
||||||
|
|
||||||
if items_html == '':
|
|
||||||
dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
|
|
||||||
items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
|
|
||||||
|
|
||||||
self_name_id = self.name.replace(" ", "_")
|
|
||||||
|
|
||||||
res = f"""
|
|
||||||
<div id='{tabname}_{self_name_id}_subdirs' class='extra-network-subdirs extra-network-subdirs-cards'>
|
|
||||||
{subdirs_html}
|
|
||||||
</div>
|
|
||||||
<div id='{tabname}_{self_name_id}_cards' class='extra-network-cards'>
|
|
||||||
{items_html}
|
|
||||||
</div>
|
|
||||||
"""
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def create_item(self, name, index=None):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def list_items(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def allowed_directories_for_previews(self):
|
|
||||||
return []
|
|
||||||
|
|
||||||
def create_html_for_item(self, item, tabname):
|
|
||||||
"""
|
"""
|
||||||
Create HTML for card item in tab tabname; can return empty string if the item is not meant to be shown.
|
|
||||||
"""
|
|
||||||
|
|
||||||
preview = item.get("preview", None)
|
preview = item.get("preview", None)
|
||||||
|
style_height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
||||||
|
style_width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
||||||
|
style_font_size = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;"
|
||||||
|
card_style = style_height + style_width + style_font_size
|
||||||
|
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
||||||
|
|
||||||
onclick = item.get("onclick", None)
|
onclick = item.get("onclick", None)
|
||||||
if onclick is None:
|
if onclick is None:
|
||||||
onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
|
# Don't quote prompt/neg_prompt since they are stored as js strings already.
|
||||||
|
onclick_js_tpl = "cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});"
|
||||||
|
onclick = onclick_js_tpl.format(
|
||||||
|
**{
|
||||||
|
"tabname": tabname,
|
||||||
|
"prompt": item["prompt"],
|
||||||
|
"neg_prompt": item.get("negative_prompt", "''"),
|
||||||
|
"allow_neg": str(self.allow_negative_prompt).lower(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
onclick = html.escape(onclick)
|
||||||
|
|
||||||
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
|
btn_copy_path = self.btn_copy_path_tpl.format(**{"filename": item["filename"]})
|
||||||
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
|
btn_metadata = ""
|
||||||
background_image = f'<img src="{html.escape(preview)}" class="preview" loading="lazy">' if preview else ''
|
|
||||||
metadata_button = ""
|
|
||||||
metadata = item.get("metadata")
|
metadata = item.get("metadata")
|
||||||
if metadata:
|
if metadata:
|
||||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
|
btn_metadata = self.btn_metadata_tpl.format(
|
||||||
|
**{
|
||||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
|
"extra_networks_tabname": self.extra_networks_tabname,
|
||||||
|
"name": html.escape(item["name"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
btn_edit_item = self.btn_edit_item_tpl.format(
|
||||||
|
**{
|
||||||
|
"tabname": tabname,
|
||||||
|
"extra_networks_tabname": self.extra_networks_tabname,
|
||||||
|
"name": html.escape(item["name"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
local_path = ""
|
local_path = ""
|
||||||
filename = item.get("filename", "")
|
filename = item.get("filename", "")
|
||||||
@@ -253,36 +272,292 @@ class ExtraNetworksPage:
         if search_only and shared.opts.extra_networks_hidden_models == "Never":
             return ""
 
-        sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
+        sort_keys = " ".join(
+            [
+                f'data-sort-{k}="{html.escape(str(v))}"'
+                for k, v in item.get("sort_keys", {}).items()
+            ]
+        ).strip()
 
+        search_terms_html = ""
+        search_term_template = "<span class='hidden {class}'>{search_term}</span>"
+        for search_term in item.get("search_terms", []):
+            search_terms_html += search_term_template.format(
+                **{
+                    "class": f"search_terms{' search_only' if search_only else ''}",
+                    "search_term": search_term,
+                }
+            )
+
+        # Some items here might not be used depending on HTML template used.
         args = {
             "background_image": background_image,
-            "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
-            "prompt": item.get("prompt", None),
-            "tabname": quote_js(tabname),
-            "local_preview": quote_js(item["local_preview"]),
-            "name": html.escape(item["name"]),
-            "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
             "card_clicked": onclick,
-            "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
-            "search_term": item.get("search_term", ""),
-            "metadata_button": metadata_button,
-            "edit_button": edit_button,
+            "copy_path_button": btn_copy_path,
+            "description": (item.get("description", "") or "" if shared.opts.extra_networks_card_show_desc else ""),
+            "edit_button": btn_edit_item,
+            "local_preview": quote_js(item["local_preview"]),
+            "metadata_button": btn_metadata,
+            "name": html.escape(item["name"]),
+            "prompt": item.get("prompt", None),
+            "save_card_preview": html.escape(f"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');"),
             "search_only": " search_only" if search_only else "",
+            "search_terms": search_terms_html,
             "sort_keys": sort_keys,
+            "style": card_style,
+            "tabname": tabname,
+            "extra_networks_tabname": self.extra_networks_tabname,
         }
 
-        return self.card_page.format(**args)
+        if template:
+            return template.format(**args)
+        else:
+            return args
+
+    def create_tree_dir_item_html(
+        self,
+        tabname: str,
+        dir_path: str,
+        content: Optional[str] = None,
+    ) -> Optional[str]:
+        """Generates HTML for a directory item in the tree.
+
+        The generated HTML is of the format:
+        ```html
+        <li class="tree-list-item tree-list-item--has-subitem">
+            <div class="tree-list-content tree-list-content-dir"></div>
+            <ul class="tree-list tree-list--subgroup">
+                {content}
+            </ul>
+        </li>
+        ```
+
+        Args:
+            tabname: The name of the active tab.
+            dir_path: Path to the directory for this item.
+            content: Optional HTML string that will be wrapped by this <ul>.
+
+        Returns:
+            HTML formatted string.
+        """
+        if not content:
+            return None
+
+        btn = self.btn_tree_tpl.format(
+            **{
+                "search_terms": "",
+                "subclass": "tree-list-content-dir",
+                "tabname": tabname,
+                "extra_networks_tabname": self.extra_networks_tabname,
+                "onclick_extra": "",
+                "data_path": dir_path,
+                "data_hash": "",
+                "action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
+                "action_list_item_visual_leading": "🗀",
+                "action_list_item_label": os.path.basename(dir_path),
+                "action_list_item_visual_trailing": "",
+                "action_list_item_action_trailing": "",
+            }
+        )
+        ul = f"<ul class='tree-list tree-list--subgroup' hidden>{content}</ul>"
+        return (
+            "<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>"
+            f"{btn}{ul}"
+            "</li>"
+        )
+
+    def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str:
+        """Generates HTML for a file item in the tree.
+
+        The generated HTML is of the format:
+        ```html
+        <li class="tree-list-item tree-list-item--subitem">
+            <span data-filterable-item-text hidden></span>
+            <div class="tree-list-content tree-list-content-file"></div>
+        </li>
+        ```
+
+        Args:
+            tabname: The name of the active tab.
+            file_path: The path to the file for this item.
+            item: Dictionary containing the item information.
+
+        Returns:
+            HTML formatted string.
+        """
+        item_html_args = self.create_item_html(tabname, item)
+        action_buttons = "".join(
+            [
+                item_html_args["copy_path_button"],
+                item_html_args["metadata_button"],
+                item_html_args["edit_button"],
+            ]
+        )
+        action_buttons = f"<div class=\"button-row\">{action_buttons}</div>"
+        btn = self.btn_tree_tpl.format(
+            **{
+                "search_terms": "",
+                "subclass": "tree-list-content-file",
+                "tabname": tabname,
+                "extra_networks_tabname": self.extra_networks_tabname,
+                "onclick_extra": item_html_args["card_clicked"],
+                "data_path": file_path,
+                "data_hash": item["shorthash"],
+                "action_list_item_action_leading": "<i class='tree-list-item-action-chevron'></i>",
+                "action_list_item_visual_leading": "🗎",
+                "action_list_item_label": item["name"],
+                "action_list_item_visual_trailing": "",
+                "action_list_item_action_trailing": action_buttons,
+            }
+        )
+        return (
+            "<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>"
+            f"{btn}"
+            "</li>"
+        )
+
+    def create_tree_view_html(self, tabname: str) -> str:
+        """Generates HTML for displaying folders in a tree view.
+
+        Args:
+            tabname: The name of the active tab.
+
+        Returns:
+            HTML string generated for this tree view.
+        """
+        res = ""
+
+        # Setup the tree dictionary.
+        roots = self.allowed_directories_for_previews()
+        tree_items = {v["filename"]: ExtraNetworksItem(v) for v in self.items.values()}
+        tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items)
+
+        if not tree:
+            return res
+
+        def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]:
+            """Recursively builds HTML for a tree.
+
+            Args:
+                data: Dictionary representing a directory tree. Can be NoneType.
+                    Data keys should be absolute paths from the root and values
+                    should be subdirectory trees or an ExtraNetworksItem.
+
+            Returns:
+                If data is not None: HTML string
+                Else: None
+            """
+            if not data:
+                return None
+
+            # Lists for storing <li> items html for directories and files separately.
+            _dir_li = []
+            _file_li = []
+
+            for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])):
+                if isinstance(v, (ExtraNetworksItem,)):
+                    _file_li.append(self.create_tree_file_item_html(tabname, k, v.item))
+                else:
+                    _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v)))
+
+            # Directories should always be displayed before files so we order them here.
+            return "".join(_dir_li) + "".join(_file_li)
+
+        # Add each root directory to the tree.
+        for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])):
+            item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v))
+            # Only add non-empty entries to the tree.
+            if item_html is not None:
+                res += item_html
+
+        return f"<ul class='tree-list tree-list--tree'>{res}</ul>"
+
+    def create_card_view_html(self, tabname: str) -> str:
+        """Generates HTML for the network Card View section for a tab.
+
+        This HTML goes into the `extra-networks-pane.html` <div> with
+        `id='{tabname}_{extra_networks_tabname}_cards`.
+
+        Args:
+            tabname: The name of the active tab.
+
+        Returns:
+            HTML formatted string.
+        """
+        res = ""
+        for item in self.items.values():
+            res += self.create_item_html(tabname, item, self.card_tpl)
+
+        if res == "":
+            dirs = "".join([f"<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+            res = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+        return res
+
+    def create_html(self, tabname):
+        """Generates an HTML string for the current pane.
+
+        The generated HTML uses `extra-networks-pane.html` as a template.
+
+        Args:
+            tabname: The name of the active tab.
+
+        Returns:
+            HTML formatted string.
+        """
+        self.lister.reset()
+        self.metadata = {}
+        self.items = {x["name"]: x for x in self.list_items()}
+        # Populate the instance metadata for each item.
+        for item in self.items.values():
+            metadata = item.get("metadata")
+            if metadata:
+                self.metadata[item["name"]] = metadata
+
+            if "user_metadata" not in item:
+                self.read_user_metadata(item)
+
+        data_sortdir = shared.opts.extra_networks_card_order
+        data_sortmode = shared.opts.extra_networks_card_order_field.lower().replace("sort", "").replace(" ", "_").rstrip("_").strip()
+        data_sortkey = f"{data_sortmode}-{data_sortdir}-{len(self.items)}"
+        tree_view_btn_extra_class = ""
+        tree_view_div_extra_class = "hidden"
+        if shared.opts.extra_networks_tree_view_default_enabled:
+            tree_view_btn_extra_class = "extra-network-control--enabled"
+            tree_view_div_extra_class = ""
+
+        return self.pane_tpl.format(
+            **{
+                "tabname": tabname,
+                "extra_networks_tabname": self.extra_networks_tabname,
+                "data_sortmode": data_sortmode,
+                "data_sortkey": data_sortkey,
+                "data_sortdir": data_sortdir,
+                "tree_view_btn_extra_class": tree_view_btn_extra_class,
+                "tree_view_div_extra_class": tree_view_div_extra_class,
+                "tree_html": self.create_tree_view_html(tabname),
+                "items_html": self.create_card_view_html(tabname),
+            }
+        )
+
+    def create_item(self, name, index=None):
+        raise NotImplementedError()
+
+    def list_items(self):
+        raise NotImplementedError()
+
+    def allowed_directories_for_previews(self):
+        return []
 
     def get_sort_keys(self, path):
         """
         List of default keys used for sorting in the UI.
         """
         pth = Path(path)
-        stat = pth.stat()
+        mtime, ctime = self.lister.mctime(path)
         return {
-            "date_created": int(stat.st_ctime or 0),
-            "date_modified": int(stat.st_mtime or 0),
+            "date_created": int(mtime),
+            "date_modified": int(ctime),
             "name": pth.name.lower(),
             "path": str(pth.parent).lower(),
         }
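The `_build_tree` helper added in this hunk turns a nested dictionary of paths into nested `<ul>`/`<li>` markup, recursing into sub-directories and always listing directories ahead of files. A rough standalone sketch of the same idea; the dictionary shape and class names below are simplified stand-ins, not the webui's real `get_tree` output or markup:

```python
import html

# Hypothetical input: directories map to sub-dicts, files map to None.
tree = {
    "models/Lora": {
        "characters": {"alice.safetensors": None},
        "style.safetensors": None,
    },
}

def build_tree_html(data):
    """Recursively render a nested dict as <ul>/<li>, directories before files."""
    dirs, files = [], []
    for name, child in sorted(data.items()):
        if child is None:
            files.append(f"<li class='file'>{html.escape(name)}</li>")
        else:
            dirs.append(
                f"<li class='dir'>{html.escape(name)}<ul>{build_tree_html(child)}</ul></li>"
            )
    return "".join(dirs) + "".join(files)

print(f"<ul>{build_tree_html(tree)}</ul>")
```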
@@ -292,10 +567,10 @@ class ExtraNetworksPage:
         Find a preview PNG for a given path (without extension) and call link_preview on it.
         """
 
-        potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
+        potential_files = sum([[f"{path}.{ext}", f"{path}.preview.{ext}"] for ext in allowed_preview_extensions()], [])
 
         for file in potential_files:
-            if os.path.isfile(file):
+            if self.lister.exists(file):
                 return self.link_preview(file)
 
         return None
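`os.path.isfile` is swapped for `self.lister.exists` here and in `find_description` below; the `lister` object (its implementation is not shown in this diff) evidently caches listings so that probing many candidate preview and description files does not hit the filesystem once per check. A minimal sketch of that kind of cache, using invented names rather than the webui's actual API:

```python
import os

class CachedLister:
    """Caches each directory's contents so repeated existence checks are cheap."""

    def __init__(self):
        self._dir_cache = {}

    def reset(self):
        # Drop cached listings, e.g. before rebuilding a page.
        self._dir_cache.clear()

    def exists(self, path):
        directory, name = os.path.split(path)
        if directory not in self._dir_cache:
            try:
                self._dir_cache[directory] = set(os.listdir(directory))
            except OSError:
                self._dir_cache[directory] = set()
        return name in self._dir_cache[directory]

lister = CachedLister()
print(lister.exists("models/Lora/style.safetensors"))  # one listdir per directory
```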
@@ -305,6 +580,9 @@ class ExtraNetworksPage:
         Find and read a description file for a given path (without extension).
         """
         for file in [f"{path}.txt", f"{path}.description.txt"]:
+            if not self.lister.exists(file):
+                continue
+
             try:
                 with open(file, "r", encoding="utf-8", errors="replace") as f:
                     return f.read()
@@ -360,10 +638,7 @@ def pages_in_preferred_order(pages):
 
     return sorted(pages, key=lambda x: tab_scores[x.name])
 
 
 def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
-    from modules.ui import switch_values_symbol
-
     ui = ExtraNetworksUi()
     ui.pages = []
     ui.pages_contents = []
@@ -373,62 +648,53 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
 
     related_tabs = []
 
+    button_refresh = gr.Button("Refresh", elem_id=f"{tabname}_extra_refresh_internal", visible=False)
+
     for page in ui.stored_extra_pages:
-        with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
-            with gr.Column(elem_id=f"{tabname}_{page.id_page}_prompts", elem_classes=["extra-page-prompts"]):
+        with gr.Tab(page.title, elem_id=f"{tabname}_{page.extra_networks_tabname}", elem_classes=["extra-page"]) as tab:
+            with gr.Column(elem_id=f"{tabname}_{page.extra_networks_tabname}_prompts", elem_classes=["extra-page-prompts"]):
                 pass
 
-            elem_id = f"{tabname}_{page.id_page}_cards_html"
+            elem_id = f"{tabname}_{page.extra_networks_tabname}_cards_html"
             page_elem = gr.HTML('Loading...', elem_id=elem_id)
             ui.pages.append(page_elem)
-
-            page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[])
 
             editor = page.create_user_metadata_editor(ui, tabname)
             editor.create_ui()
             ui.user_metadata_editors.append(editor)
 
             related_tabs.append(tab)
 
-    edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
-    dropdown_sort = gr.Dropdown(choices=['Path', 'Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
-    button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
-    button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
-    checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
-
-    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
-    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
-
-    tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
+    ui.button_save_preview = gr.Button('Save preview', elem_id=f"{tabname}_save_preview", visible=False)
+    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f"{tabname}_preview_filename", visible=False)
 
     for tab in unrelated_tabs:
-        tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
+        tab.select(fn=None, _js=f"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}", inputs=[], outputs=[], show_progress=False)
 
     for page, tab in zip(ui.stored_extra_pages, related_tabs):
-        allow_prompt = "true" if page.allow_prompt else "false"
-        allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
-
-        jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}_prompts" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
-
-        tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
-
-    dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
+        jscode = (
+            "function(){{"
+            f"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()}, '{tabname}_{page.extra_networks_tabname}');"
+            f"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');"
+            "}}"
+        )
+        tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)
+
+    def create_html():
+        ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
 
     def pages_html():
         if not ui.pages_contents:
-            return refresh()
+            create_html()
 
         return ui.pages_contents
 
     def refresh():
         for pg in ui.stored_extra_pages:
             pg.refresh()
-
-        ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
-
+        create_html()
         return ui.pages_contents
 
-    interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages])
+    interface.load(fn=pages_html, inputs=[], outputs=ui.pages)
+    # NOTE: Event is manually fired in extraNetworks.js:extraNetworksTreeRefreshOnClick()
+    # button is unused and hidden at all times. Only used in order to fire this event.
     button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
 
     return ui
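The refresh path now funnels through a single `create_html()` helper: `pages_html()` builds the per-page HTML only the first time it is requested, and `refresh()` rebuilds it after telling every page to re-scan. A small sketch of that lazy-build-plus-invalidate pattern in isolation; the names here are illustrative, not the webui API:

```python
class PageCache:
    """Build expensive page content once, rebuild only on explicit refresh."""

    def __init__(self, pages):
        self.pages = pages      # callables that produce each page's HTML
        self.contents = []

    def _build(self):
        self.contents = [page() for page in self.pages]

    def get(self):
        if not self.contents:   # first request: build lazily
            self._build()
        return self.contents

    def refresh(self):
        self._build()           # invalidate and rebuild
        return self.contents

cache = PageCache([lambda: "<div>cards</div>", lambda: "<div>tree</div>"])
print(cache.get())      # builds on first access
print(cache.refresh())  # forces a rebuild
```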
@@ -478,5 +744,3 @@ def setup_ui(ui, gallery):
 
     for editor in ui.user_metadata_editors:
         editor.setup_ui(gallery)
-
-
Some files were not shown because too many files have changed in this diff.