Merge branch 'dev' into Branch_AddNewFilenameGen

This commit is contained in:
AUTOMATIC1111 2023-04-29 13:10:46 +03:00 committed by GitHub
commit 87535fcf29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
72 changed files with 1994 additions and 1367 deletions

View File

@ -18,7 +18,7 @@ jobs:
cache-dependency-path: |
**/requirements*txt
- name: Run tests
run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()

2
.gitignore vendored
View File

@ -32,4 +32,4 @@ notification.mp3
/extensions
/test/stdout.txt
/test/stderr.txt
/cache.json
/cache.json*

View File

@ -13,9 +13,9 @@ A browser interface based on Gradio library for Stable Diffusion.
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- a man in a `((tuxedo))` - will pay more attention to tuxedo
- a man in a `(tuxedo:1.21)` - alternative syntax
- select text and press `Ctrl+Up` or `Ctrl+Down` to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters
- Textual Inversion
@ -28,7 +28,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler
- ESRGAN, neural network upscaler with a lot of third party models
- SwinIR and Swin2SR([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers
- LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
@ -46,7 +46,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- drag and drop an image/text-parameters to promptbox
- Read Generation Parameters Button, loads parameters in promptbox to UI
- Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Running arbitrary python code from UI (must run with `--allow-code` to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Tiling support, a checkbox to create images that can be tiled like textures
@ -69,7 +69,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args)
- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI
- Generate forever option
- Training tab
@ -78,11 +78,11 @@ A browser interface based on Gradio library for Stable Diffusion.
- Clip skip
- Hypernetworks
- Loras (same as Hypernetworks but more pretty)
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt.
- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt
- Can select to load a different VAE from settings screen
- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
@ -91,7 +91,6 @@ A browser interface based on Gradio library for Stable Diffusion.
- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
-
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@ -101,7 +100,7 @@ Alternatively, use online services (like Google Colab):
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
@ -121,6 +120,7 @@ sudo pacman -S wget git python3
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
```
3. Run `webui.sh`.
4. Check `webui-user.sh` for options.
### Installation on Apple Silicon
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
@ -159,4 +159,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
- (You)

View File

@ -4,8 +4,8 @@ channels:
- defaults
dependencies:
- python=3.10
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
- pip=23.0
- cudatoolkit=11.8
- pytorch=2.0
- torchvision=0.15
- numpy=1.23

View File

@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -2,20 +2,34 @@ import glob
import os
import re
import torch
from typing import Union
from modules import shared, devices, sd_models, errors
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}
suffix_conversion = {
"attentions": {},
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
}
def convert_diffusers_name_to_compvis(key):
def match(match_list, regex):
def convert_diffusers_name_to_compvis(key, is_sd2):
def match(match_list, regex_text):
regex = re_compiled.get(regex_text)
if regex is None:
regex = re.compile(regex_text)
re_compiled[regex_text] = regex
r = re.match(regex, key)
if not r:
return False
@ -26,16 +40,33 @@ def convert_diffusers_name_to_compvis(key):
m = []
if match(m, re_unet_down_blocks):
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, re_unet_mid_blocks):
return f"diffusion_model_middle_block_1_{m[1]}"
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
if match(m, re_unet_up_blocks):
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
if is_sd2:
if 'mlp_fc1' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
elif 'mlp_fc2' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
else:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
if match(m, re_text_block):
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
return key
@ -101,15 +132,22 @@ def load_lora(name, filename):
sd = sd_models.read_state_dict(filename)
keys_failed_to_match = []
keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
for key_diffusers, weight in sd.items():
fullkey = convert_diffusers_name_to_compvis(key_diffusers)
key, lora_key = fullkey.split(".", 1)
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
if sd_module is None:
keys_failed_to_match.append(key_diffusers)
m = re_x_proj.match(key)
if m:
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
if sd_module is None:
keys_failed_to_match[key_diffusers] = key
continue
lora_module = lora.modules.get(key, None)
@ -123,15 +161,21 @@ def load_lora(name, filename):
if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
with torch.no_grad():
module.weight.copy_(weight)
module.to(device=devices.device, dtype=devices.dtype)
module.to(device=devices.cpu, dtype=devices.dtype)
if lora_key == "lora_up.weight":
lora_module.up = module
@ -177,28 +221,120 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora)
def lora_forward(module, input, res):
if len(loaded_loras) == 0:
return res
def lora_calc_updown(lora, module, target):
with torch.no_grad():
up = module.up.weight.to(target.device, dtype=target.dtype)
down = module.down.weight.to(target.device, dtype=target.dtype)
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
updown = up @ down
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return updown
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing.
If not, restores orginal weights from backup and alters weights according to loras.
"""
lora_layer_name = getattr(self, 'lora_layer_name', None)
if lora_layer_name is None:
return
current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else:
weights_backup = self.weight.to(devices.cpu, copy=True)
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
self.weight.copy_(weights_backup)
return res
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None and hasattr(self, 'weight'):
self.weight += lora_calc_updown(lora, module, self.weight)
continue
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
self.in_proj_weight += updown_qkv
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
continue
if module is None:
continue
print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names)
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input):
return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
def lora_Linear_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
def lora_Conv2d_forward(self, input):
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
def lora_MultiheadAttention_forward(self, *args, **kwargs):
lora_apply_weights(self)
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
lora_reset_cached_weight(self)
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras():
@ -211,7 +347,7 @@ def list_available_loras():
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
for filename in sorted(candidates):
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue

View File

@ -9,7 +9,11 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
def before_ui():
@ -20,11 +24,27 @@ def before_ui():
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
@ -32,7 +52,5 @@ script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))

View File

@ -5,11 +5,15 @@ import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from scunet_model_arch import SCUNet as net
from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
tile = opts.SCUNET_tile
h, w = img.height, img.width
np_img = np.array(img)
np_img = np_img[:, :, ::-1] # RGB to BGR
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
if tile > h or tile > w:
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str):
device = devices.get_device_for('scunet')
@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = model.to(device)
return model

View File

@ -1,110 +1,42 @@
// Stable Diffusion WebUI - Bracket checker
// Version 1.0
// By Hingashi no Florin/Bwin4L
// By Hingashi no Florin/Bwin4L & @akx
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
function checkBrackets(evt, textArea, counterElt) {
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
function checkBrackets(textArea, counterElt) {
var counts = {};
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
counts[bracket] = (counts[bracket] || 0) + 1;
});
var errors = [];
openBracketRegExp = /\(/g;
closeBracketRegExp = /\)/g;
openSquareBracketRegExp = /\[/g;
closeSquareBracketRegExp = /\]/g;
openCurlyBracketRegExp = /\{/g;
closeCurlyBracketRegExp = /\}/g;
totalOpenBracketMatches = 0;
totalCloseBracketMatches = 0;
totalOpenSquareBracketMatches = 0;
totalCloseSquareBracketMatches = 0;
totalOpenCurlyBracketMatches = 0;
totalCloseCurlyBracketMatches = 0;
openBracketMatches = textArea.value.match(openBracketRegExp);
if(openBracketMatches) {
totalOpenBracketMatches = openBracketMatches.length;
}
closeBracketMatches = textArea.value.match(closeBracketRegExp);
if(closeBracketMatches) {
totalCloseBracketMatches = closeBracketMatches.length;
}
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
if(openSquareBracketMatches) {
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
}
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
if(closeSquareBracketMatches) {
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
}
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
if(openCurlyBracketMatches) {
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
}
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
if(closeCurlyBracketMatches) {
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
}
if(totalOpenBracketMatches != totalCloseBracketMatches) {
if(!counterElt.title.includes(errorStringParen)) {
counterElt.title += errorStringParen;
function checkPair(open, close, kind) {
if (counts[open] !== counts[close]) {
errors.push(
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
);
}
} else {
counterElt.title = counterElt.title.replace(errorStringParen, '');
}
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
if(!counterElt.title.includes(errorStringSquare)) {
counterElt.title += errorStringSquare;
}
} else {
counterElt.title = counterElt.title.replace(errorStringSquare, '');
}
checkPair('(', ')', 'round brackets');
checkPair('[', ']', 'square brackets');
checkPair('{', '}', 'curly brackets');
counterElt.title = errors.join('\n');
counterElt.classList.toggle('error', errors.length !== 0);
}
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
if(!counterElt.title.includes(errorStringCurly)) {
counterElt.title += errorStringCurly;
}
} else {
counterElt.title = counterElt.title.replace(errorStringCurly, '');
}
function setupBracketChecking(id_prompt, id_counter) {
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter)
if(counterElt.title != '') {
counterElt.classList.add('error');
} else {
counterElt.classList.remove('error');
if (textarea && counter) {
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
}
}
function setupBracketChecking(id_prompt, id_counter){
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
var counter = gradioApp().getElementById(id_counter)
textarea.addEventListener("input", function(evt){
checkBrackets(evt, textarea, counter)
});
}
var shadowRootLoaded = setInterval(function() {
var shadowRoot = document.querySelector('gradio-app').shadowRoot;
if(! shadowRoot) return false;
var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea');
if(shadowTextArea.length < 1) return false;
clearInterval(shadowRootLoaded);
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
setupBracketChecking('img2img_prompt', 'imgimg_token_counter')
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
}, 1000);
onUiLoaded(function () {
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
});

View File

@ -1,4 +1,4 @@
<div class='card' {preview_html} onclick={card_clicked}>
<div class='card' style={style} onclick={card_clicked}>
{metadata_button}
<div class='actions'>

View File

@ -635,4 +635,30 @@ SOFTWARE.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
</pre>
<h2><a href="https://github.com/explosion/curated-transformers/blob/main/LICENSE">Curated transformers</a></h2>
<small>The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated transformers</small>
<pre>
The MIT License (MIT)
Copyright (C) 2021 ExplosionAI GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
</pre>

View File

@ -12,7 +12,7 @@ function dimensionChange(e, is_width, is_height){
currentHeight = e.target.value*1.0
}
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
var inImg2img = gradioApp().querySelector("#tab_img2img").style.display == "block";
if(!inImg2img){
return;
@ -22,7 +22,7 @@ function dimensionChange(e, is_width, is_height){
var tabIndex = get_tab_index('mode_img2img')
if(tabIndex == 0){ // img2img
targetElement = gradioApp().querySelector('div[data-testid=image] img');
targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');
} else if(tabIndex == 1){ //Sketch
targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');
} else if(tabIndex == 2){ // Inpaint
@ -30,7 +30,7 @@ function dimensionChange(e, is_width, is_height){
} else if(tabIndex == 3){ // Inpaint sketch
targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');
}
if(targetElement){
@ -38,7 +38,7 @@ function dimensionChange(e, is_width, is_height){
if(!arPreviewRect){
arPreviewRect = document.createElement('div')
arPreviewRect.id = "imageARPreview";
gradioApp().getRootNode().appendChild(arPreviewRect)
gradioApp().appendChild(arPreviewRect)
}
@ -91,23 +91,26 @@ onUiUpdate(function(){
if(arPreviewRect){
arPreviewRect.style.display = 'none';
}
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
if(inImg2img){
let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e){
var is_width = e.parentElement.id == "img2img_width"
var is_height = e.parentElement.id == "img2img_height"
var tabImg2img = gradioApp().querySelector("#tab_img2img");
if (tabImg2img) {
var inImg2img = tabImg2img.style.display == "block";
if(inImg2img){
let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e){
var is_width = e.parentElement.id == "img2img_width"
var is_height = e.parentElement.id == "img2img_height"
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
e.classList.add('scrollwatch')
}
if(is_width){
currentWidth = e.value*1.0
}
if(is_height){
currentHeight = e.value*1.0
}
})
}
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
e.classList.add('scrollwatch')
}
if(is_width){
currentWidth = e.value*1.0
}
if(is_height){
currentHeight = e.value*1.0
}
})
}
}
});

View File

@ -43,7 +43,7 @@ contextMenuInit = function(){
})
gradioApp().getRootNode().appendChild(contextMenu)
gradioApp().appendChild(contextMenu)
let menuWidth = contextMenu.offsetWidth + 4;
let menuHeight = contextMenu.offsetHeight + 4;
@ -161,14 +161,6 @@ addContextMenuEventListener = initResponse[2];
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)
setTimeout(function(){rollbutton.click()},300)
}
)
})();
//End example Context Menu Items

View File

@ -1,6 +1,6 @@
function keyupEditAttention(event){
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
if (! target.matches("[id*='_toprow'] [id*='_prompt'] textarea")) return;
if (! (event.metaKey || event.ctrlKey)) return;
let isPlus = event.key == "ArrowUp"
@ -17,7 +17,7 @@ function keyupEditAttention(event){
// Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@ -27,7 +27,7 @@ function keyupEditAttention(event){
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1);
@ -43,10 +43,28 @@ function keyupEditAttention(event){
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
function selectCurrentWord(){
if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
// seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++;
}
// If the user hasn't selected anything, let's select their current parenthesis block
if(! selectCurrentParenthesisBlock('<', '>')){
selectCurrentParenthesisBlock('(', ')')
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
// If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord();
}
event.preventDefault();
@ -81,7 +99,13 @@ function keyupEditAttention(event){
weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0"
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
selectionStart--;
selectionEnd--;
} else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
}
target.focus();
target.value = text;
@ -93,4 +117,4 @@ function keyupEditAttention(event){
addEventListener('keydown', (event) => {
keyupEditAttention(event);
});
});

View File

@ -1,5 +1,5 @@
function extensions_apply(_, _){
function extensions_apply(_, _, disable_all){
var disable = []
var update = []
@ -13,10 +13,10 @@ function extensions_apply(_, _){
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)]
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
}
function extensions_check(){
function extensions_check(_, _){
var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){

View File

@ -139,3 +139,41 @@ function extraNetworksShowMetadata(text){
popup(elem);
}
function requestGet(url, data, handler, errorHandler){
var xhr = new XMLHttpRequest();
var args = Object.keys(data).map(function(k){ return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]) }).join('&')
xhr.open("GET", url + "?" + args, true);
xhr.onreadystatechange = function () {
if (xhr.readyState === 4) {
if (xhr.status === 200) {
try {
var js = JSON.parse(xhr.responseText);
handler(js)
} catch (error) {
console.error(error);
errorHandler()
}
} else{
errorHandler()
}
}
};
var js = JSON.stringify(data);
xhr.send(js);
}
function extraNetworksRequestMetadata(event, extraPage, cardName){
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){
extraNetworksShowMetadata(data.metadata)
} else{
showError()
}
}, showError)
event.stopPropagation()
}

View File

@ -16,7 +16,7 @@ onUiUpdate(function(){
let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
let selectedTab = gradioApp().querySelector('#tabs div button')?.innerText
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
});

View File

@ -18,11 +18,10 @@ titles = {
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style",
"\u{1f5d1}": "Clear prompt",
"\u{1f5d1}\ufe0f": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"\u{1f4d2}": "Paste available values into the field",
"\u{1f3b4}": "Show extra networks",
"\u{1f3b4}": "Show/hide extra networks",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
@ -40,8 +39,7 @@ titles = {
"Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
"Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
@ -71,8 +69,10 @@ titles = {
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.",
"Loops": "How many times to repeat processing an image and using it as input for the next iteration",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
"Loops": "How many times to process an image. Each output is used as the input of the next loop. If set to 1, behavior will be as if this script were not used.",
"Final denoising strength": "The denoising strength for the final loop of each image in the batch.",
"Denoising strength curve": "The denoising curve controls the rate of denoising strength change each loop. Aggressive: Most of the change will happen towards the start of the loops. Linear: Change will be constant through all loops. Lazy: Most of the change will happen towards the end of the loops.",
"Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",

View File

@ -32,13 +32,7 @@ function negmod(n, m) {
function updateOnBackgroundChange() {
const modalImage = gradioApp().getElementById("modalImage")
if (modalImage && modalImage.offsetParent) {
let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
let currentButton = null
allcurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
currentButton = elem;
}
})
let currentButton = selected_gallery_button();
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
modalImage.src = currentButton.children[0].src;
@ -50,22 +44,10 @@ function updateOnBackgroundChange() {
}
function modalImageSwitch(offset) {
var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
var galleryButtons = []
allgalleryButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
galleryButtons.push(elem);
}
})
var galleryButtons = all_gallery_buttons();
if (galleryButtons.length > 1) {
var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
var currentButton = null
allcurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
currentButton = elem;
}
})
var currentButton = selected_gallery_button();
var result = -1
galleryButtons.forEach(function(v, i) {
@ -136,37 +118,29 @@ function modalKeyHandler(event) {
}
}
function showGalleryImage() {
setTimeout(function() {
fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')
function setupImageForLightbox(e) {
if (e.dataset.modded)
return;
if (fullImg_preview != null) {
fullImg_preview.forEach(function function_name(e) {
if (e.dataset.modded)
return;
e.dataset.modded = true;
if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer'
e.style.userSelect='none'
e.dataset.modded = true;
e.style.cursor='pointer'
e.style.userSelect='none'
var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
var isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1
// For Firefox, listening on click first switched to next image then shows the lightbox.
// If you know how to fix this without switching to mousedown event, please.
// For other browsers the event is click to make it possiblr to drag picture.
var event = isFirefox ? 'mousedown' : 'click'
// For Firefox, listening on click first switched to next image then shows the lightbox.
// If you know how to fix this without switching to mousedown event, please.
// For other browsers the event is click to make it possiblr to drag picture.
var event = isFirefox ? 'mousedown' : 'click'
e.addEventListener(event, function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
showModal(evt)
}, true);
}
});
}
e.addEventListener(event, function (evt) {
if(!opts.js_modal_lightbox || evt.button != 0) return;
modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
evt.preventDefault()
showModal(evt)
}, true);
}, 100);
}
function modalZoomSet(modalImage, enable) {
@ -199,21 +173,21 @@ function modalTileImageToggle(event) {
}
function galleryImageHandler(e) {
if (e && e.parentElement.tagName == 'BUTTON') {
//if (e && e.parentElement.tagName == 'BUTTON') {
e.onclick = showGalleryImage;
}
//}
}
onUiUpdate(function() {
fullImg_preview = gradioApp().querySelectorAll('img.w-full')
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
if (fullImg_preview != null) {
fullImg_preview.forEach(galleryImageHandler);
fullImg_preview.forEach(setupImageForLightbox);
}
updateOnBackgroundChange();
})
document.addEventListener("DOMContentLoaded", function() {
const modalFragment = document.createDocumentFragment();
//const modalFragment = document.createDocumentFragment();
const modal = document.createElement('div')
modal.onclick = closeModal;
modal.id = "lightboxModal";
@ -277,9 +251,12 @@ document.addEventListener("DOMContentLoaded", function() {
modal.appendChild(modalNext)
try {
gradioApp().appendChild(modal);
} catch (e) {
gradioApp().body.appendChild(modal);
}
gradioApp().getRootNode().appendChild(modal)
document.body.appendChild(modalFragment);
document.body.appendChild(modal);
});

View File

@ -15,7 +15,7 @@ onUiUpdate(function(){
}
}
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] img.h-full.w-full.overflow-hidden');
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] div[id$="_results"] .thumbnail-item > img');
if (galleryPreviews == null) return;

View File

@ -1,78 +1,13 @@
// code related to showing and updating progressbar shown as the image is being made
galleries = {}
storedGallerySelections = {}
galleryObservers = {}
function rememberGallerySelection(id_gallery){
storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
}
function getGallerySelectedIndex(id_gallery){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
let currentlySelectedIndex = -1
galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
return currentlySelectedIndex
}
// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
function check_gallery(id_gallery){
let gallery = gradioApp().getElementById(id_gallery)
// if gallery has no change, no need to setting up observer again.
if (gallery && galleries[id_gallery] !== gallery){
galleries[id_gallery] = gallery;
if(galleryObservers[id_gallery]){
galleryObservers[id_gallery].disconnect();
}
storedGallerySelections[id_gallery] = -1
galleryObservers[id_gallery] = new MutationObserver(function (){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
prevSelectedIndex = storedGallerySelections[id_gallery]
storedGallerySelections[id_gallery] = -1
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
// automatically re-open previously selected index (if exists)
activeElement = gradioApp().activeElement;
let scrollX = window.scrollX;
let scrollY = window.scrollY;
galleryButtons[prevSelectedIndex].click();
showGalleryImage();
// When the gallery button is clicked, it gains focus and scrolls itself into view
// We need to scroll back to the previous position
setTimeout(function (){
window.scrollTo(scrollX, scrollY);
}, 50);
if(activeElement){
// i fought this for about an hour; i don't know why the focus is lost or why this helps recover it
// if someone has a better solution please by all means
setTimeout(function (){
activeElement.focus({
preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it
})
}, 1);
}
}
})
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
}
}
onUiUpdate(function(){
check_gallery('txt2img_gallery')
check_gallery('img2img_gallery')
})
function request(url, data, handler, errorHandler){
var xhr = new XMLHttpRequest();
var url = url;
@ -203,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
return
}
if(elapsedFromStart > 5 && !res.queued && !res.active){
if(elapsedFromStart > 40 && !res.queued && !res.active){
removeProgressBar()
return
}

View File

@ -7,9 +7,31 @@ function set_theme(theme){
}
}
function all_gallery_buttons() {
var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small');
var visibleGalleryButtons = [];
allGalleryButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
visibleGalleryButtons.push(elem);
}
})
return visibleGalleryButtons;
}
function selected_gallery_button() {
var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected');
var visibleCurrentButton = null;
allCurrentButtons.forEach(function(elem) {
if (elem.parentElement.offsetParent) {
visibleCurrentButton = elem;
}
})
return visibleCurrentButton;
}
function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
var buttons = all_gallery_buttons();
var button = selected_gallery_button();
var result = -1
buttons.forEach(function(v, i){ if(v==button) { result = i } })
@ -18,14 +40,18 @@ function selected_gallery_index(){
}
function extract_image_from_gallery(gallery){
if(gallery.length == 1){
return [gallery[0]]
if (gallery.length == 0){
return [null];
}
if (gallery.length == 1){
return [gallery[0]];
}
index = selected_gallery_index()
if (index < 0 || index >= gallery.length){
return [null]
// Use the first image in the gallery as the default
index = 0;
}
return [gallery[index]];
@ -86,7 +112,7 @@ function get_tab_index(tabId){
var res = 0
gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button').forEach(function(button, i){
if(button.className.indexOf('bg-white') != -1)
if(button.className.indexOf('selected') != -1)
res = i
})
@ -255,7 +281,6 @@ onUiUpdate(function(){
}
prompt.parentElement.insertBefore(counter, prompt)
counter.classList.add("token-counter")
prompt.parentElement.style.position = "relative"
promptTokecountUpdateFuncs[id] = function(){ update_token_counter(id_button); }

View File

@ -5,24 +5,25 @@ import sys
import importlib.util
import shlex
import platform
import argparse
import json
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ui-settings-file", type=str, default='config.json')
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
args, _ = parser.parse_known_args(sys.argv)
from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir
script_path = os.path.dirname(__file__)
data_path = os.getcwd()
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)
args, _ = cmd_args.parser.parse_known_args()
dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False
dir_repos = "repositories"
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
def check_python_version():
@ -70,23 +71,6 @@ def commit_hash():
return stored_commit_hash
def extract_arg(args, name):
return [x for x in args if x != name], name in args
def extract_opt(args, name):
opt = None
is_present = False
if name in args:
is_present = True
idx = args.index(name)
del args[idx]
if idx < len(args) and args[idx][0] != "-":
opt = args[idx]
del args[idx]
return args, is_present, opt
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
@ -137,12 +121,12 @@ def run_python(code, desc=None, errdesc=None):
return run(f'"{python}" -c "{code}"', desc, errdesc)
def run_pip(args, desc=None):
def run_pip(args, desc=None, live=False):
if skip_install:
return
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
def check_run_python(code):
@ -222,26 +206,29 @@ def list_extensions(settings_file):
print(e, file=sys.stderr)
disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none')
return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
if disable_all_extensions != 'none':
return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
if not os.path.isdir(dir_extensions):
if not os.path.isdir(extensions_dir):
return
for dirname_extension in list_extensions(settings_file):
run_extension_installer(os.path.join(dir_extensions, dirname_extension))
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
def prepare_environment():
global skip_install
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
@ -252,27 +239,13 @@ def prepare_environment():
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args)
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv
if not skip_python_version_check:
if not args.skip_python_version_check:
check_python_version()
commit = commit_hash()
@ -280,10 +253,10 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
if not skip_torch_cuda_test:
if not args.skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
if not is_installed("gfpgan"):
@ -295,10 +268,10 @@ def prepare_environment():
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers:
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
@ -307,7 +280,7 @@ def prepare_environment():
elif platform.system() == "Linux":
run_pip(f"install {xformers_package}", "xformers")
if not is_installed("pyngrok") and ngrok:
if not is_installed("pyngrok") and args.ngrok:
run_pip("install pyngrok", "ngrok")
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
@ -323,22 +296,22 @@ def prepare_environment():
if not os.path.isfile(requirements_file):
requirements_file = os.path.join(script_path, requirements_file)
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
run_pip(f"install -r \"{requirements_file}\"", "requirements")
run_extensions_installers(settings_file=args.ui_settings_file)
if update_check:
if args.update_check:
version_check(commit)
if update_all_extensions:
git_pull_recursive(os.path.join(data_path, dir_extensions))
if args.update_all_extensions:
git_pull_recursive(extensions_dir)
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)
if run_tests:
exitcode = tests(test_dir)
if args.tests and not args.no_tests:
exitcode = tests(args.tests)
exit(exitcode)
@ -352,6 +325,8 @@ def tests(test_dir):
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
if "--no-tests" not in sys.argv:
sys.argv.append("--no-tests")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")

Binary file not shown.

View File

@ -3,11 +3,14 @@ import io
import time
import datetime
import uvicorn
import gradio as gr
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
@ -18,7 +21,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@ -90,6 +93,16 @@ def encode_pil_to_base64(image):
return base64.b64encode(bytes_data)
def api_middleware(app: FastAPI):
rich_available = True
try:
import anyio # importing just so it can be placed on silent list
import starlette # importing just so it can be placed on silent list
from rich.console import Console
console = Console()
except:
import traceback
rich_available = False
@app.middleware("http")
async def log_and_time(req: Request, call_next):
ts = time.time()
@ -110,6 +123,36 @@ def api_middleware(app: FastAPI):
))
return res
def handle_exception(request: Request, e: Exception):
err = {
"error": type(e).__name__,
"detail": vars(e).get('detail', ''),
"body": vars(e).get('body', ''),
"errors": str(e),
}
print(f"API error: {request.method}: {request.url} {err}")
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
if rich_available:
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
else:
traceback.print_exc()
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))
@app.middleware("http")
async def exception_handling(request: Request, call_next):
try:
return await call_next(request)
except Exception as e:
return handle_exception(request, e)
@app.exception_handler(Exception)
async def fastapi_exception_handler(request: Request, e: Exception):
return handle_exception(request, e)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, e: HTTPException):
return handle_exception(request, e)
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
@ -150,8 +193,13 @@ class Api:
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
@ -185,7 +233,7 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx]
def init_script_args(self, request, selectable_scripts, selectable_idx, script_runner):
def init_default_script_args(self, script_runner):
#find max idx from the scripts in runner and generate a none array to init script_args
last_arg_index = 1
for script in script_runner.scripts:
@ -193,13 +241,24 @@ class Api:
last_arg_index = script.args_to
# None everywhere except position 0 to initialize script args
script_args = [None]*last_arg_index
script_args[0] = 0
# get default values
with gr.Blocks(): # will throw errors calling ui function without this
for script in script_runner.scripts:
if script.ui(script.is_img2img):
ui_default_values = []
for elem in script.ui(script.is_img2img):
ui_default_values.append(elem.value)
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
script_args = default_script_args.copy()
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
script_args[0] = selectable_idx + 1
else:
# when [0] = 0 no selectable script to run
script_args[0] = 0
# Now check for always on scripts
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
@ -212,7 +271,9 @@ class Api:
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
# always on script with no arg should always run so you don't really need to add them to the requests
if "args" in request.alwayson_scripts[alwayson_script_name]:
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
# min between arg length in scriptrunner and arg length in the request
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
@ -220,6 +281,8 @@ class Api:
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
@ -235,7 +298,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
script_args = self.init_script_args(txt2imgreq, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
@ -272,6 +335,8 @@ class Api:
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui()
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
@ -289,7 +354,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
script_args = self.init_script_args(img2imgreq, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
@ -331,16 +396,11 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
image_list = reqDict.pop('imageList', [])
image_folder = [decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
@ -412,6 +472,16 @@ class Api:
return {}
def unloadapi(self):
unload_model_weights()
return {}
def reloadapi(self):
reload_model_weights()
return {}
def skip(self):
shared.state.skip()

103
modules/cmd_args.py Normal file
View File

@ -0,0 +1,103 @@
import argparse
import os
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
parser = argparse.ArgumentParser()
parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui
parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program")
parser.add_argument("--skip-python-version-check", action='store_true', help="launch.py argument: do not check python version")
parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly")
parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup")
parser.add_argument("--tests", type=str, default=None, help="launch.py argument: run tests in the specified directory")
parser.add_argument("--no-tests", action='store_true', help="launch.py argument: do not run tests even if --tests option is specified")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)

View File

@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed)
if device.type == 'mps':
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
if device.type == 'mps':
from modules.shared import opts
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)

View File

@ -5,17 +5,22 @@ import traceback
import time
import git
from modules import paths, shared
from modules import shared
from modules.paths_internal import extensions_dir, extensions_builtin_dir
extensions = []
extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
if not os.path.exists(extensions_dir):
os.makedirs(extensions_dir)
def active():
return [x for x in extensions if x.enabled]
if shared.opts.disable_all_extensions == "all":
return []
elif shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
class Extension:
@ -27,21 +32,29 @@ class Extension:
self.can_update = False
self.is_builtin = is_builtin
self.version = ''
self.remote = None
self.have_info_from_repo = False
def read_info_from_repo(self):
if self.have_info_from_repo:
return
self.have_info_from_repo = True
repo = None
try:
if os.path.exists(os.path.join(path, ".git")):
repo = git.Repo(path)
if os.path.exists(os.path.join(self.path, ".git")):
repo = git.Repo(self.path)
except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr)
print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare:
self.remote = None
else:
try:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
self.remote = next(repo.remote().urls, None)
head = repo.head.commit
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
self.version = f'{head.hexsha[:8]} ({ts})'
@ -89,7 +102,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
paths = []
if shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
extension_paths = []
for dirname in [extensions_dir, extensions_builtin_dir]:
if not os.path.isdir(dirname):
return
@ -99,9 +117,8 @@ def list_extensions():
if not os.path.isdir(path):
continue
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
for dirname, path, is_builtin in paths:
for dirname, path, is_builtin in extension_paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension)

View File

@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
if "RNG" not in res:
res["RNG"] = "GPU"
return res
@ -304,6 +308,7 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
('RNG', 'randn_source'),
]
@ -401,9 +406,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
button.click(
fn=paste_func,
_js=f"recalculate_prompts_{tabname}",
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
button.click(
fn=None,
_js=f"recalculate_prompts_{tabname}",
inputs=[],
outputs=[],
)

View File

@ -312,7 +312,7 @@ class Hypernetwork:
def list_hypernetworks(path):
res = {}
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):
name = os.path.splitext(os.path.basename(filename))[0]
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":

View File

@ -261,9 +261,12 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]
assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}"
if len(upscalers) == 0:
upscaler = shared.sd_upscalers[0]
print(f"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback")
else:
upscaler = upscalers[0]
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if im.width != w or im.height != h:
@ -350,6 +353,7 @@ class FilenameGenerator:
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
'hasprompt': lambda self, *args: self.hasprompt(*args), #accept formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
}
default_time_format = '%Y%m%d%H%M%S'
@ -662,6 +666,8 @@ Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}
def image_data(data):
import gradio as gr
try:
image = Image.open(io.BytesIO(data))
textinfo, _ = read_info_from_image(image)
@ -677,7 +683,7 @@ def image_data(data):
except Exception:
pass
return '', None
return gr.update(), None
def flatten(img, bgcolor):

View File

@ -151,13 +151,14 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur
if mask:
p.extra_generation_params["Mask blur"] = mask_blur
if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"

View File

@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
category_types = ["artists", "flavors", "mediums", "movements"]
try:
os.makedirs(tmpdir)
os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
os.remove(tmpdir)
os.removedirs(tmpdir)
class InterrogateModels:

View File

@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram):
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
# register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'):

View File

@ -1,4 +1,5 @@
import torch
import platform
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version
@ -32,6 +33,10 @@ if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
@ -49,4 +54,6 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
if version.parse(torch.__version__) == version.parse("2.0"):
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)

View File

@ -4,7 +4,6 @@ import shutil
import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path
@ -59,6 +58,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
from basicsr.utils.download_util import load_file_from_url
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:

View File

@ -1,16 +1,9 @@
import argparse
import os
import sys
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
import modules.safe
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)

22
modules/paths_internal.py Normal file
View File

@ -0,0 +1,22 @@
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
import argparse
import os
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser_pre.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")

View File

@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode == 1:
for img in image_folder:
image = Image.open(img)
if isinstance(img, Image.Image):
image = img
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
image_data.append(image)
image_names.append(os.path.splitext(img.orig_name)[0])
image_names.append(fn)
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'

View File

@ -3,6 +3,7 @@ import math
import os
import sys
import warnings
import hashlib
import torch
import numpy as np
@ -78,22 +79,28 @@ def apply_overlay(image, paste_loc, index, overlays):
def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
class StableDiffusionProcessing:
"""
@ -190,6 +197,14 @@ class StableDiffusionProcessing:
return conditioning_image
def unclip_image_conditioning(self, source_image):
c_adm = self.sd_model.embedder(source_image)
if self.sd_model.noise_augmentor is not None:
noise_level = 0 # TODO: Allow other noise levels?
c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
c_adm = torch.cat((c_adm, noise_level_emb), 1)
return c_adm
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
self.is_using_inpainting_conditioning = True
@ -241,6 +256,9 @@ class StableDiffusionProcessing:
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@ -459,6 +477,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": (opts.randn_source if opts.randn_source != "GPU" else None)
}
generation_params.update(p.extra_generation_params)
@ -622,8 +642,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
step_multiplier = 1
if not shared.opts.dont_fix_second_order_samplers_schedule:
try:
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
except:
pass
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@ -689,6 +715,22 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
if opts.return_mask_composite:
output_images.append(image_mask_composite)
del x_samples_ddim
devices.torch_gc()
@ -974,6 +1016,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
image = images.flatten(img, opts.img2img_background_color)
if crop_region is None and self.resize_mode != 3:

View File

@ -1,6 +1,5 @@
# this code is adapted from the script contributed by anon from /h/
import io
import pickle
import collections
import sys
@ -12,11 +11,9 @@ import _codecs
import zipfile
import re
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
out = _codecs.encode(*args)
return out
@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler):
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
return TypedStorage()
return TypedStorage(_internal=True)
def find_class(self, module, name):
if self.extra_handler is not None:

View File

@ -239,7 +239,15 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
for scriptfile in sorted(scripts_list):
def orderby(basedir):
# 1st webui, 2nd extensions-builtin, 3rd extensions
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
@ -513,6 +521,18 @@ def reload_scripts():
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
def IOComponent_init(self, *args, **kwargs):
if scripts_current is not None:
scripts_current.before_component(self, **kwargs)
@ -521,6 +541,8 @@ def IOComponent_init(self, *args, **kwargs):
res = original_IOComponent_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
script_callbacks.after_component_callback(self, **kwargs)
if scripts_current is not None:
@ -531,3 +553,15 @@ def IOComponent_init(self, *args, **kwargs):
original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init
def BlockContext_init(self, *args, **kwargs):
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
return res
original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.blocks.BlockContext.__init__ = BlockContext_init

View File

@ -109,7 +109,7 @@ class ScriptPostprocessingRunner:
inputs = []
for script in self.scripts_in_preferred_order():
with gr.Box() as group:
with gr.Row() as group:
self.create_script_ui(script, inputs)
script.group = group

View File

@ -337,7 +337,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q, k, v = q.float(), k.float(), v.float()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
@ -372,7 +372,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q, k, v = q.float(), k.float(), v.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(

View File

@ -67,7 +67,7 @@ def hijack_ddpm_edit():
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.1"):
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)

View File

@ -122,7 +122,7 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
for filename in sorted(model_list, key=str.lower):
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
@ -178,7 +178,7 @@ def select_checkpoint():
return checkpoint_info
chckpoint_dict_replacements = {
checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
@ -186,7 +186,7 @@ chckpoint_dict_replacements = {
def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items():
for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
@ -383,6 +383,14 @@ def repair_config(sd_config):
elif shared.cmd_opts.upcast_sampling:
sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
# For UnCLIP-L, override the hardcoded karlo directory
if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
@ -494,7 +502,7 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
try:
@ -517,3 +525,23 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
return sd_model
def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
timer = Timer()
if shared.sd_model:
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.")
return sd_model

View File

@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict):
def guess_model_config_from_state_dict(sd, filename):
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:
return config_unopenclip
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
if diffusion_model_input.shape[1] == 9:

View File

@ -60,3 +60,13 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
if opts.randn_source == "CPU":
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn

View File

@ -70,8 +70,13 @@ class VanillaStableDiffusionSampler:
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
uc_image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
if self.conditioning_key == "crossattn-adm":
image_conditioning = cond["c_adm"]
uc_image_conditioning = unconditional_conditioning["c_adm"]
else:
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
@ -98,8 +103,12 @@ class VanillaStableDiffusionSampler:
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
if self.conditioning_key == "crossattn-adm":
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
else:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
return x, ts, cond, unconditional_conditioning
@ -176,8 +185,12 @@ class VanillaStableDiffusionSampler:
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
if self.conditioning_key == "crossattn-adm":
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
else:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
@ -195,8 +208,12 @@ class VanillaStableDiffusionSampler:
# Wrap the conditioning models with additional image conditioning for inpainting model
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
if self.conditioning_key == "crossattn-adm":
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
else:
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

View File

@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module):
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else:
image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
cfg_denoiser_callback(denoiser_params)
@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module):
cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module):
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params)
@ -183,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
if x.device.type == 'mps':
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)

View File

@ -4,6 +4,7 @@ import json
import os
import sys
import time
import requests
from PIL import Image
import gradio as gr
@ -13,114 +14,22 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
from modules.paths import models_path, script_path, data_path
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
demo = None
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = cmd_args.parser
parser = argparse.ArgumentParser()
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--enable-insecure-extension-access", action='store_true', help="enable extensions tab regardless of other options")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None)
parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything')
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button")
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
script_loading.preload_extensions(extensions_dir, parser)
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
restricted_opts = {
"samples_filename_pattern",
"directories_filename_pattern",
@ -131,6 +40,7 @@ restricted_opts = {
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
"outdir_init_images"
}
ui_reorder_categories = [
@ -146,6 +56,21 @@ ui_reorder_categories = [
"scripts",
]
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"freddyaboulton/dracula_revamped",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/pakistan",
"dawood/microsoft_windows",
"ysharma/steampunk"
]
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
@ -332,6 +257,8 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"),
"save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"webp_lossless": OptionInfo(False, "Use lossless compression for webp images"),
"export_for_4chan": OptionInfo(True, "If the saved image file size is above the limit, or its either width or height are above the limit, save a downscaled copy as JPG"),
@ -343,6 +270,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
@ -358,6 +286,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
@ -373,6 +302,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -421,6 +352,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@ -428,6 +360,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
@ -448,12 +381,16 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, {"choices": ["cards", "thumbs"]}),
"extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"),
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
@ -468,11 +405,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
}))
options_templates.update(options_section(('ui', "Live previews"), {
@ -508,7 +447,8 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
}))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
"disabled_extensions": OptionInfo([], "Disable these extensions"),
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))
@ -685,6 +625,24 @@ clip_model = None
progress_print_out = sys.stdout
gradio_theme = gr.themes.Base()
def reload_gradio_theme(theme_name=None):
global gradio_theme
if not theme_name:
theme_name = opts.gradio_theme
if theme_name == "Default":
gradio_theme = gr.themes.Default()
else:
try:
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
except requests.exceptions.ConnectionError:
print("Can't access HuggingFace Hub, falling back to default Gradio theme")
gradio_theme = gr.themes.Default()
class TotalTQDM:
def __init__(self):
@ -726,7 +684,7 @@ mem_mon.start()
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]

View File

@ -152,7 +152,11 @@ class EmbeddingDatabase:
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
if data:
name = data.get('name', name)
else:
# if data is None, means this is not an embeding, just a preview image
return
elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:

View File

@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path
from modules.shared import opts, cmd_opts, restricted_opts
@ -70,17 +70,6 @@ def gr_show(visible=True):
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.wrap .z-20 svg { display:none!important; }
.wrap .z-20::before { content:"Loading..." }
.wrap.cover-bg .z-20::before { content:"" }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
.meta-text-center { display:none!important; }
"""
# Using constants for these since the variation selector isn't visible.
# Important that they exactly match script.js for tooltip to work.
random_symbol = '\U0001f3b2\ufe0f' # 🎲️
@ -89,7 +78,7 @@ paste_symbol = '\u2199\ufe0f' # ↙
refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
apply_style_symbol = '\U0001f4cb' # 📋
clear_prompt_symbol = '\U0001F5D1' # 🗑️
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅
@ -179,14 +168,13 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface):
with FormRow(elem_id=target_interface + '_seed_row'):
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
with gr.Group(elem_id=target_interface + '_subseed_show_box'):
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
@ -195,8 +183,8 @@ def create_seed_inputs(target_interface):
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
subseed.style(container=False)
random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
random_subseed = ToolButton(random_symbol, elem_id=target_interface + '_random_subseed')
reuse_subseed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
with FormRow(visible=False) as seed_extra_row_2:
@ -291,19 +279,19 @@ def create_toprow(is_img2img):
with gr.Row():
with gr.Column(scale=80):
with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
button_interrogate = None
button_deepbooru = None
if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"):
with gr.Column(scale=1, elem_classes="interrogate-col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
with gr.Row(elem_id=f"{id_part}_generate_box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
@ -325,9 +313,9 @@ def create_toprow(is_img2img):
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
negative_token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_negative_token_counter")
negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
clear_prompt_button.click(
@ -479,7 +467,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count")
@ -492,7 +482,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes":
with FormRow(elem_id="txt2img_checkboxes", variant="compact"):
with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
@ -586,7 +576,7 @@ def create_ui():
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
txt_prompt_img.change(
fn=modules.images.image_data,
@ -757,7 +747,9 @@ def create_ui():
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="img2img_column_batch"):
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
@ -774,7 +766,7 @@ def create_ui():
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
elif category == "checkboxes":
with FormRow(elem_id="img2img_checkboxes", variant="compact"):
with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces")
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
@ -904,7 +896,7 @@ def create_ui():
img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args)
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height])
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
img2img_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
@ -1212,7 +1204,7 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
@ -1491,11 +1483,33 @@ def create_ui():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
with gr.Row():
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
with gr.TabItem("Licenses"):
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
def unload_sd_weights():
modules.sd_models.unload_model_weights()
def reload_sd_weights():
modules.sd_models.reload_model_weights()
unload_sd_model.click(
fn=unload_sd_weights,
inputs=[],
outputs=[]
)
reload_sd_model.click(
fn=reload_sd_weights,
inputs=[],
outputs=[]
)
request_notifications.click(
fn=lambda: None,
@ -1541,22 +1555,6 @@ def create_ui():
(train_interface, "Train", "ti"),
]
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
if os.path.exists(os.path.join(data_path, "user.css")):
with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file:
css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
@ -1567,7 +1565,7 @@ def create_ui():
for _interface, label, _ifid in interfaces:
shared.tab_names.append(label)
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings", variant="compact"):
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
component = create_setting_component(k, is_quicksettings=True)
@ -1598,11 +1596,13 @@ def create_ui():
for i, k, item in quicksettings_list:
component = component_dict[k]
info = opts.data_labels[k]
component.change(
fn=lambda value, k=k: run_settings_single(value, key=k),
inputs=[component],
outputs=[component, text_settings],
show_progress=info.refresh is not None,
)
text_settings.change(
@ -1628,6 +1628,7 @@ def create_ui():
fn=get_settings_values,
inputs=[],
outputs=[component_dict[k] for k in component_keys],
queue=False,
)
def modelmerger(*args):
@ -1704,7 +1705,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:
@ -1750,25 +1751,60 @@ def create_ui():
return demo
def reload_javascript():
def webpath(fn):
if fn.startswith(script_path):
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
else:
web_path = os.path.abspath(fn)
return f'file={web_path}?{os.path.getmtime(fn)}'
def javascript_html():
script_js = os.path.join(script_path, "script.js")
head = f'<script type="text/javascript" src="file={os.path.abspath(script_js)}?{os.path.getmtime(script_js)}"></script>\n'
head = f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
inline = f"{localization.localization_js(shared.opts.localization)};"
if cmd_opts.theme is not None:
inline += f"set_theme('{cmd_opts.theme}');"
for script in modules.scripts.list_scripts("javascript", ".js"):
head += f'<script type="text/javascript" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
for script in modules.scripts.list_scripts("javascript", ".mjs"):
head += f'<script type="module" src="file={script.path}?{os.path.getmtime(script.path)}"></script>\n'
head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
head += f'<script type="text/javascript">{inline}</script>\n'
return head
def css_html():
head = ""
def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
for cssfile in modules.scripts.list_files_with_name("style.css"):
if not os.path.isfile(cssfile):
continue
head += stylesheet(cssfile)
if os.path.exists(os.path.join(data_path, "user.css")):
head += stylesheet(os.path.join(data_path, "user.css"))
return head
def reload_javascript():
js = javascript_html()
css = css_html()
def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
res.init_headers()
return res

View File

@ -125,12 +125,12 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None
with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}"):
open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}')
@ -145,11 +145,10 @@ Requested path was: {f}
)
if tabname != "extras":
with gr.Row():
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
@ -160,6 +159,7 @@ Requested path was: {f}
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[generation_info, html_info, html_info],
outputs=[html_info, html_info],
show_progress=False,
)
save.click(
@ -195,7 +195,7 @@ Requested path was: {f}
else:
html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
html_info = gr.HTML(elem_id=f'html_info_{tabname}')
html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []

View File

@ -1,58 +1,74 @@
import gradio as gr
class ToolButton(gr.Button, gr.components.FormComponent):
class FormComponent:
def get_expected_parent(self):
return gr.components.Form
gr.Dropdown.get_expected_parent = FormComponent.get_expected_parent
class ToolButton(FormComponent, gr.Button):
"""Small button with single emoji as text, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool", **kwargs)
def __init__(self, *args, **kwargs):
classes = kwargs.pop("elem_classes", [])
super().__init__(*args, elem_classes=["tool", *classes], **kwargs)
def get_block_name(self):
return "button"
class ToolButtonTop(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, with extra margin at top, fits inside gradio forms"""
def __init__(self, **kwargs):
super().__init__(variant="tool-top", **kwargs)
def get_block_name(self):
return "button"
class FormRow(gr.Row, gr.components.FormComponent):
class FormRow(FormComponent, gr.Row):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "row"
class FormGroup(gr.Group, gr.components.FormComponent):
class FormColumn(FormComponent, gr.Column):
"""Same as gr.Column but fits inside gradio forms"""
def get_block_name(self):
return "column"
class FormGroup(FormComponent, gr.Group):
"""Same as gr.Row but fits inside gradio forms"""
def get_block_name(self):
return "group"
class FormHTML(gr.HTML, gr.components.FormComponent):
class FormHTML(FormComponent, gr.HTML):
"""Same as gr.HTML but fits inside gradio forms"""
def get_block_name(self):
return "html"
class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
class FormColorPicker(FormComponent, gr.ColorPicker):
"""Same as gr.ColorPicker but fits inside gradio forms"""
def get_block_name(self):
return "colorpicker"
class DropdownMulti(gr.Dropdown):
class DropdownMulti(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but always multiselect"""
def __init__(self, **kwargs):
super().__init__(multiselect=True, **kwargs)
def get_block_name(self):
return "dropdown"
class DropdownEditable(FormComponent, gr.Dropdown):
"""Same as gr.Dropdown but allows editing value"""
def __init__(self, **kwargs):
super().__init__(allow_custom_value=True, **kwargs)
def get_block_name(self):
return "dropdown"

View File

@ -1,6 +1,5 @@
import json
import os.path
import shutil
import sys
import time
import traceback
@ -22,7 +21,7 @@ def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabled because of command line flags"
def apply_and_restart(disable_list, update_list):
def apply_and_restart(disable_list, update_list, disable_all):
check_access()
disabled = json.loads(disable_list)
@ -44,6 +43,7 @@ def apply_and_restart(disable_list, update_list):
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
shared.state.interrupt()
@ -64,6 +64,9 @@ def check_updates(id_task, disable_list):
try:
ext.check_updates()
except FileNotFoundError as e:
if 'FETCH_HEAD' not in str(e):
raise
except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@ -88,6 +91,8 @@ def extension_table():
"""
for ext in extensions.extensions:
ext.read_info_from_repo()
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update:
@ -95,9 +100,13 @@ def extension_table():
else:
ext_status = ext.status
style = ""
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
style = ' style="color: var(--primary-400)"'
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td>{remote}</td>
<td>{ext.version}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
@ -120,7 +129,7 @@ def normalize_git_url(url):
return url
def install_extension_from_url(dirname, url):
def install_extension_from_url(dirname, branch_name, url):
check_access()
assert url, 'No URL specified'
@ -141,22 +150,27 @@ def install_extension_from_url(dirname, url):
try:
shutil.rmtree(tmpdir, True)
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
if branch_name == '':
# if no branch is specified, use the default branch
with git.Repo.clone_from(url, tmpdir) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
else:
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
repo.remote().fetch()
for submodule in repo.submodules:
submodule.update()
try:
os.rename(tmpdir, target_dir)
except OSError as err:
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
# Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versitile shutil.move()
shutil.move(tmpdir, target_dir)
else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
raise(err)
raise err
import launch
launch.run_extension_installer(target_dir)
@ -167,12 +181,12 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True)
def install_extension_from_index(url, hide_tags, sort_column):
def install_extension_from_index(url, hide_tags, sort_column, filter_text):
ext_table, message = install_extension_from_url(None, url)
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
return code, ext_table, message
return code, ext_table, message, ''
def refresh_available_extensions(url, hide_tags, sort_column):
@ -186,11 +200,17 @@ def refresh_available_extensions(url, hide_tags, sort_column):
code, tags = refresh_available_extensions_from_data(hide_tags, sort_column)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
return url, code, gr.CheckboxGroup.update(choices=tags), '', ''
def refresh_available_extensions_for_tags(hide_tags, sort_column):
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column)
def refresh_available_extensions_for_tags(hide_tags, sort_column, filter_text):
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
return code, ''
def search_extensions(filter_text, hide_tags, sort_column):
code, _ = refresh_available_extensions_from_data(hide_tags, sort_column, filter_text)
return code, ''
@ -205,7 +225,7 @@ sort_ordering = [
]
def refresh_available_extensions_from_data(hide_tags, sort_column):
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@ -244,7 +264,12 @@ def refresh_available_extensions_from_data(hide_tags, sort_column):
hidden += 1
continue
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
if filter_text and filter_text.strip():
if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():
hidden += 1
continue
install_code = f"""<button onclick="install_extension_from_index(this, '{html.escape(url)}')" {"disabled=disabled" if existing else ""} class="lg secondary gradio-button custom-button">{"Install" if not existing else "Installed"}</button>"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
@ -281,16 +306,24 @@ def create_ui():
with gr.Row(elem_id="extensions_installed_top"):
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
info = gr.HTML()
html = ""
if shared.opts.disable_all_extensions != "none":
html = """
<span style="color: var(--primary-400);">
"Disable all extensions" was set, change it to "none" to load all extensions again
</span>
"""
info = gr.HTML(html)
extensions_table = gr.HTML(lambda: extension_table())
apply.click(
fn=apply_and_restart,
_js="extensions_apply",
inputs=[extensions_disabled_list, extensions_update_list],
inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],
outputs=[],
)
@ -312,42 +345,52 @@ def create_ui():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install, hide_tags, sort_column],
inputs=[extension_to_install, hide_tags, sort_column, search_extensions_text],
outputs=[available_extensions_table, extensions_table, install_result],
)
search_extensions_text.change(
fn=modules.ui.wrap_gradio_call(search_extensions, extra_outputs=[gr.update()]),
inputs=[search_extensions_text, hide_tags, sort_column],
outputs=[available_extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags, sort_column],
inputs=[hide_tags, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
sort_column.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags, sort_column],
inputs=[hide_tags, sort_column, search_extensions_text],
outputs=[available_extensions_table, install_result]
)
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
inputs=[install_dirname, install_branch, install_url],
outputs=[extensions_table, install_result],
)

View File

@ -2,8 +2,10 @@ import glob
import os.path
import urllib.parse
from pathlib import Path
from PIL import PngImagePlugin
from modules import shared
from modules.images import read_info_from_image
import gradio as gr
import json
import html
@ -22,21 +24,37 @@ def register_page(page):
allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
if ext not in (".png", ".jpg", ".webp"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
def get_metadata(page: str = "", item: str = ""):
from starlette.responses import JSONResponse
page = next(iter([x for x in extra_pages if x.name == page]), None)
if page is None:
return JSONResponse({})
metadata = page.metadata.get(item)
if metadata is None:
return JSONResponse({})
return JSONResponse({"metadata": metadata})
def add_pages_to_demo(app):
def fetch_file(filename: str = ""):
from starlette.responses import FileResponse
if not any([Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs]):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower()
if ext not in (".png", ".jpg", ".webp"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.")
# would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
class ExtraNetworksPage:
@ -45,6 +63,7 @@ class ExtraNetworksPage:
self.name = title.lower()
self.card_page = shared.html("extra-networks-card.html")
self.allow_negative_prompt = False
self.metadata = {}
def refresh(self):
pass
@ -66,6 +85,8 @@ class ExtraNetworksPage:
view = shared.opts.extra_networks_default_view
items_html = ''
self.metadata = {}
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
@ -86,12 +107,16 @@ class ExtraNetworksPage:
subdirs = {"": 1, **subdirs}
subdirs_html = "".join([f"""
<button class='gr-button gr-button-lg gr-button-secondary{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'>
{html.escape(subdir if subdir!="" else "all")}
</button>
""" for subdir in subdirs])
for item in self.list_items():
metadata = item.get("metadata")
if metadata:
self.metadata[item["name"]] = metadata
items_html += self.create_html_for_item(item, tabname)
if items_html == '':
@ -124,14 +149,16 @@ class ExtraNetworksPage:
if onclick is None:
onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'
height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else ''
width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else ''
background_image = f"background-image: url(\"{html.escape(preview)}\");" if preview else ''
metadata_button = ""
metadata = item.get("metadata")
if metadata:
metadata_onclick = '"' + html.escape(f"""extraNetworksShowMetadata({json.dumps(metadata)}); return false;""") + '"'
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick={metadata_onclick}></div>"
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
args = {
"preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
"style": f"'{height}{width}{background_image}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
@ -215,6 +242,7 @@ def create_ui(container, button, tabname):
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages:
with gr.Tab(page.title):
page_elem = gr.HTML(page.create_html(ui.tabname))
ui.pages.append(page_elem)
@ -226,10 +254,10 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
return is_visible, gr.update(visible=is_visible)
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
def refresh():
res = []
@ -264,6 +292,7 @@ def setup_ui(ui, gallery):
img_info = images[index if index >= 0 else 0]
image = image_from_url_text(img_info)
geninfo, items = read_info_from_image(image)
is_allowed = False
for extra_page in ui.stored_extra_pages:
@ -273,7 +302,12 @@ def setup_ui(ui, gallery):
assert is_allowed, f'writing to {filename} is not allowed'
image.save(filename)
if geninfo:
pnginfo_data = PngImagePlugin.PngInfo()
pnginfo_data.add_text('parameters', geninfo)
image.save(filename, pnginfo=pnginfo_data)
else:
image.save(filename)
return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]

View File

@ -13,7 +13,7 @@ def create_ui():
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")

View File

@ -1,10 +1,11 @@
astunparse
blendmodes
accelerate
basicsr
fonts
font-roboto
gfpgan
gradio==3.16.2
gradio==3.27
invisible-watermark
numpy
omegaconf
@ -30,3 +31,4 @@ GitPython
torchsde
safetensors
psutil
rich

View File

@ -1,15 +1,15 @@
blendmodes==2022
transformers==4.25.1
accelerate==0.12.0
accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.16.2
numpy==1.23.3
gradio==3.27
numpy==1.23.5
Pillow==9.4.0
realesrgan==0.3.0
torch
omegaconf==2.2.3
pytorch_lightning==1.7.6
pytorch_lightning==1.9.4
scikit-image==0.19.2
fonts
font-roboto
@ -25,6 +25,6 @@ lark==1.1.2
inflection==0.5.1
GitPython==3.1.30
torchsde==0.2.5
safetensors==0.2.7
safetensors==0.3.1
httpcore<=0.15
fastapi==0.94.0

View File

@ -1,7 +1,9 @@
function gradioApp() {
const elems = document.getElementsByTagName('gradio-app')
const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot
return !!gradioShadowRoot ? gradioShadowRoot : document;
const elem = elems.length == 0 ? document : elems[0]
if (elem !== document) elem.getElementById = function(id){ return document.getElementById(id) }
return elem.shadowRoot ? elem.shadowRoot : elem
}
function get_uiCurrentTab() {

View File

@ -1,9 +1,40 @@
import modules.scripts as scripts
import gradio as gr
import ast
import copy
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
def convertExpr2Expression(expr):
expr.lineno = 0
expr.col_offset = 0
result = ast.Expression(expr.value, lineno=0, col_offset = 0)
return result
def exec_with_return(code, module):
"""
like exec() but can return values
https://stackoverflow.com/a/52361938/5862977
"""
code_ast = ast.parse(code)
init_ast = copy.deepcopy(code_ast)
init_ast.body = code_ast.body[:-1]
last_ast = copy.deepcopy(code_ast)
last_ast.body = code_ast.body[-1:]
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
if type(last_ast.body[0]) == ast.Expr:
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
else:
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
class Script(scripts.Script):
def title(self):
@ -13,12 +44,23 @@ class Script(scripts.Script):
return cmd_opts.allow_code
def ui(self, is_img2img):
code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
example = """from modules.processing import process_images
return [code]
p.width = 768
p.height = 768
p.batch_size = 2
p.steps = 10
return process_images(p)
"""
def run(self, p, code):
code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
return [code, indent_level]
def run(self, p, code, indent_level):
assert cmd_opts.allow_code, '--allow-code option must be enabled'
display_result_data = [[], -1, ""]
@ -29,13 +71,20 @@ class Script(scripts.Script):
display_result_data[2] = i
from types import ModuleType
compiled = compile(code, '', 'exec')
module = ModuleType("testmodule")
module.__dict__.update(globals())
module.p = p
module.display = display
exec(compiled, module.__dict__)
indent = " " * indent_level
indented = code.replace('\n', '\n' + indent)
body = f"""def __webuitemp__():
{indent}{indented}
__webuitemp__()"""
result = exec_with_return(body, module)
if isinstance(result, Processed):
return result
return Processed(p, *display_result_data)

View File

@ -6,23 +6,21 @@ from tqdm import trange
import modules.scripts as scripts
import gradio as gr
from modules import processing, shared, sd_samplers, prompt_parser, sd_samplers_common
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
from modules import processing, shared, sd_samplers, sd_samplers_common
import torch
import k_diffusion as K
from PIL import Image
from torch import autocast
from einops import rearrange, repeat
def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = p.init_latent
s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(shared.sd_model)
if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps
@ -37,7 +35,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in)
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
@ -69,7 +67,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
x = p.init_latent
s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(shared.sd_model)
if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps
@ -84,7 +87,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
@ -125,7 +128,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img
def ui(self, is_img2img):
def ui(self, is_img2img):
info = gr.Markdown('''
* `CFG Scale` should be 2 or lower.
''')
@ -213,4 +216,3 @@ class Script(scripts.Script):
processed = processing.process_images(p)
return processed

View File

@ -1,14 +1,10 @@
import numpy as np
from tqdm import trange
import math
import modules.scripts as scripts
import gradio as gr
from modules import processing, shared, sd_samplers, images
import modules.scripts as scripts
from modules import deepbooru, images, processing, shared
from modules.processing import Processed
from modules.sd_samplers import samplers
from modules.shared import opts, cmd_opts, state
from modules import deepbooru
from modules.shared import opts, state
class Script(scripts.Script):
@ -20,39 +16,65 @@ class Script(scripts.Script):
def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor"))
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
append_interrogation = gr.Dropdown(label="Append interrogated prompt at each iteration", choices=["None", "CLIP", "DeepBooru"], value="None")
return [loops, denoising_strength_change_factor, append_interrogation]
return [loops, final_denoising_strength, denoising_curve, append_interrogation]
def run(self, p, loops, denoising_strength_change_factor, append_interrogation):
def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation):
processing.fix_seed(p)
batch_count = p.n_iter
p.extra_generation_params = {
"Denoising strength change factor": denoising_strength_change_factor,
"Final denoising strength": final_denoising_strength,
"Denoising curve": denoising_curve
}
p.batch_size = 1
p.n_iter = 1
output_images, info = None, None
info = None
initial_seed = None
initial_info = None
initial_denoising_strength = p.denoising_strength
grids = []
all_images = []
original_init_image = p.init_images
original_prompt = p.prompt
original_inpainting_fill = p.inpainting_fill
state.job_count = loops * batch_count
initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
for n in range(batch_count):
history = []
def calculate_denoising_strength(loop):
strength = initial_denoising_strength
if loops == 1:
return strength
progress = loop / (loops - 1)
if denoising_curve == "Aggressive":
strength = math.sin((progress) * math.pi * 0.5)
elif denoising_curve == "Lazy":
strength = 1 - math.cos((progress) * math.pi * 0.5)
else:
strength = progress
change = (final_denoising_strength - initial_denoising_strength) * strength
return initial_denoising_strength + change
history = []
for n in range(batch_count):
# Reset to original init image at the start of each batch
p.init_images = original_init_image
# Reset to original denoising strength
p.denoising_strength = initial_denoising_strength
last_image = None
for i in range(loops):
p.n_iter = 1
p.batch_size = 1
@ -72,26 +94,46 @@ class Script(scripts.Script):
processed = processing.process_images(p)
# Generation cancelled.
if state.interrupted:
break
if initial_seed is None:
initial_seed = processed.seed
initial_info = processed.info
init_img = processed.images[0]
p.init_images = [init_img]
p.seed = processed.seed + 1
p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
history.append(processed.images[0])
p.denoising_strength = calculate_denoising_strength(i + 1)
if state.skipped:
break
last_image = processed.images[0]
p.init_images = [last_image]
p.inpainting_fill = 1 # Set "masked content" to "original" for next loop.
if batch_count == 1:
history.append(last_image)
all_images.append(last_image)
if batch_count > 1 and not state.skipped and not state.interrupted:
history.append(last_image)
all_images.append(last_image)
p.inpainting_fill = original_inpainting_fill
if state.interrupted:
break
if len(history) > 1:
grid = images.image_grid(history, rows=1)
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
grids.append(grid)
all_images += history
if opts.return_grid:
all_images = grids + all_images
if opts.return_grid:
grids.append(grid)
all_images = grids + all_images
processed = Processed(p, all_images, initial_seed, initial_info)

View File

@ -4,8 +4,8 @@ import numpy as np
from modules import scripts_postprocessing, shared
import gradio as gr
from modules.ui_components import FormRow
from modules.ui_components import FormRow, ToolButton
from modules.ui import switch_values_symbol
upscale_cache = {}
@ -17,23 +17,29 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
def ui(self):
selected_tab = gr.State(value=0)
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
with gr.Column():
with FormRow():
with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow():
with gr.Column(elem_id="upscaling_column_size", scale=4):
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
with FormRow():
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
with FormRow():
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])

View File

@ -247,7 +247,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"
processed: Processed = cell(x, y, z)
processed: Processed = cell(x, y, z, ix, iy, iz)
if processed_result is None:
# Use our first processed result object as a template container to hold our full results
@ -374,16 +374,19 @@ class Script(scripts.Script):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
@ -401,54 +404,74 @@ class Script(scripts.Script):
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
xy_swap_args = [x_type, x_values, y_type, y_values]
xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
yz_swap_args = [y_type, y_values, z_type, z_values]
yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
xz_swap_args = [x_type, x_values, z_type, z_values]
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type):
axis = self.current_axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update()
return axis.choices() if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
def select_axis(x_type):
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
def select_axis(axis_type,axis_values_dropdown):
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
current_values = axis_values_dropdown
if has_choices:
choices = choices()
if isinstance(current_values,str):
current_values = current_values.split(",")
current_values = list(filter(lambda x: x in choices, current_values))
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params):
val_key = axis + " Values"
vals = params.get(val_key,"")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist)
self.infotext_fields = (
(x_type, "X Type"),
(x_values, "X Values"),
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
(y_type, "Y Type"),
(y_values, "Y Values"),
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
(z_type, "Z Type"),
(z_values, "Z Values"),
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
)
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
if not opts.return_grid:
p.batch_size = 1
def process_axis(opt, vals):
def process_axis(opt, vals, vals_dropdown):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.choices is not None:
valslist = vals_dropdown
else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.type == int:
valslist_ext = []
@ -506,15 +529,22 @@ class Script(scripts.Script):
return valslist
x_opt = self.current_axis_options[x_type]
xs = process_axis(x_opt, x_values)
if x_opt.choices is not None:
x_values = ",".join(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type]
ys = process_axis(y_opt, y_values)
if y_opt.choices is not None:
y_values = ",".join(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type]
zs = process_axis(z_opt, z_values)
if z_opt.choices is not None:
z_values = ",".join(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
@ -558,8 +588,6 @@ class Script(scripts.Script):
print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})")
shared.total_tqdm.updateTotal(total_steps)
grid_infotext = [None]
state.xyz_plot_x = AxisInfo(x_opt, xs)
state.xyz_plot_y = AxisInfo(y_opt, ys)
state.xyz_plot_z = AxisInfo(z_opt, zs)
@ -588,7 +616,9 @@ class Script(scripts.Script):
else:
second_axes_processed = 'y'
def cell(x, y, z):
grid_infotext = [None] * (1 + len(zs))
def cell(x, y, z, ix, iy, iz):
if shared.state.interrupted:
return Processed(p, [], p.seed, "")
@ -600,7 +630,9 @@ class Script(scripts.Script):
res = process_images(pc)
if grid_infotext[0] is None:
# Sets subgrid infotexts
subgrid_index = 1 + iz
if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:
pc.extra_generation_params = copy(pc.extra_generation_params)
pc.extra_generation_params['Script'] = self.title()
@ -616,6 +648,12 @@ class Script(scripts.Script):
if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds:
pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys])
grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)
# Sets main grid infotext
if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:
pc.extra_generation_params = copy(pc.extra_generation_params)
if z_opt.label != 'Nothing':
pc.extra_generation_params["Z Type"] = z_opt.label
pc.extra_generation_params["Z Values"] = z_values
@ -650,6 +688,9 @@ class Script(scripts.Script):
z_count = len(zs)
# Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)
processed.infotexts[:1+z_count] = grid_infotext[:1+z_count]
if not include_lone_images:
# Don't need sub-images anymore, drop from list:
processed.images = processed.images[:z_count+1]

851
style.css

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,7 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1

View File

@ -43,4 +43,7 @@
# Uncomment to enable accelerated launch
#export ACCELERATE="True"
# Uncomment to disable TCMalloc
#export NO_TCMALLOC="True"
###########################################

View File

@ -4,6 +4,7 @@ import time
import importlib
import signal
import re
import warnings
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@ -17,6 +18,11 @@ from modules import paths, timer, import_hook, errors
startup_timer = timer.Timer()
import torch
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")
import gradio
@ -64,11 +70,51 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def check_versions():
if shared.cmd_opts.skip_version_check:
return
expected_torch_version = "1.13.1"
expected_torch_version = "2.0.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
@ -81,7 +127,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
expected_xformers_version = "0.0.17"
if shared.xformers_available:
import xformers
@ -96,6 +142,8 @@ Use --skip-version-check commandline argument to disable this check.
def initialize():
fix_asyncio_event_loop_policy()
check_versions()
extensions.list_extensions()
@ -123,9 +171,6 @@ def initialize():
modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
@ -147,6 +192,7 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
@ -240,7 +286,7 @@ def webui():
shared.demo = modules.ui.create_ui()
startup_timer.record("create ui")
if cmd_opts.gradio_queue:
if not cmd_opts.no_gradio_queue:
shared.demo.queue(64)
gradio_auth_creds = []

View File

@ -23,7 +23,7 @@ fi
# Install directory without trailing slash
if [[ -z "${install_dir}" ]]
then
install_dir="/home/$(whoami)"
install_dir="${HOME}"
fi
# Name of the subdirectory (defaults to stable-diffusion-webui)
@ -113,13 +113,13 @@ case "$gpu_info" in
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
printf "\n%s\n" "${delimiter}"
;;
*)
*)
;;
esac
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
fi
fi
for preq in "${GIT}" "${python_cmd}"
do
@ -172,15 +172,30 @@ else
exit 1
fi
# Try using TCMalloc on Linux
prepare_tcmalloc() {
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
if [[ ! -z "${TCMALLOC}" ]]; then
echo "Using TCMalloc: ${TCMALLOC}"
export LD_PRELOAD="${TCMALLOC}"
else
printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
fi
fi
}
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
then
printf "\n%s\n" "${delimiter}"
printf "Accelerating launch.py..."
printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
else
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
printf "\n%s\n" "${delimiter}"
prepare_tcmalloc
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
fi