Merge branch 'dev' into master

This commit is contained in:
AUTOMATIC1111 2023-07-13 15:21:39 +03:00 committed by GitHub
commit b7c5b30f14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 704 additions and 389 deletions

View File

@ -18,7 +18,7 @@ jobs:
# not to have GHA download an (at the time of writing) 4 GB cache # not to have GHA download an (at the time of writing) 4 GB cache
# of PyTorch and other dependencies. # of PyTorch and other dependencies.
- name: Install Ruff - name: Install Ruff
run: pip install ruff==0.0.265 run: pip install ruff==0.0.272
- name: Run Ruff - name: Run Ruff
run: ruff . run: ruff .
lint-js: lint-js:

View File

@ -42,7 +42,7 @@ jobs:
--no-half --no-half
--disable-opt-split-attention --disable-opt-split-attention
--use-cpu all --use-cpu all
--add-stop-route --api-server-stop
2>&1 | tee output.txt & 2>&1 | tee output.txt &
- name: Run tests - name: Run tests
run: | run: |
@ -50,7 +50,7 @@ jobs:
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
- name: Kill test server - name: Kill test server
if: always() if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/_stop && sleep 10 run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
- name: Show coverage - name: Show coverage
run: | run: |
python -m coverage combine .coverage* python -m coverage combine .coverage*

View File

@ -135,8 +135,11 @@ Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-w
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
## Documentation ## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
For the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki).
## Credits ## Credits
Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.

View File

@ -12,7 +12,7 @@ import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack from modules import shared, sd_hijack, devices
cached_ldsr_model: torch.nn.Module = None cached_ldsr_model: torch.nn.Module = None
@ -112,8 +112,7 @@ class LDSR:
gc.collect() gc.collect()
if torch.cuda.is_available: devices.torch_gc()
torch.cuda.empty_cache()
im_og = image im_og = image
width_og, height_og = im_og.size width_og, height_og = im_og.size
@ -150,8 +149,7 @@ class LDSR:
del model del model
gc.collect() gc.collect()
if torch.cuda.is_available: devices.torch_gc()
torch.cuda.empty_cache()
return a return a

View File

@ -1,7 +1,6 @@
import os import os
from basicsr.utils.download_util import load_file_from_url from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors from modules import shared, script_callbacks, errors
@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
if local_safetensors_path is not None and os.path.exists(local_safetensors_path): if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
model = local_safetensors_path model = local_safetensors_path
else: else:
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True) model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True) yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
try:
return LDSR(model, yaml) return LDSR(model, yaml)
except Exception:
errors.report("Error importing LDSR", exc_info=True)
return None
def do_upscale(self, img, path): def do_upscale(self, img, path):
try:
ldsr = self.load_model(path) ldsr = self.load_model(path)
if ldsr is None: except Exception:
print("NO LDSR!") errors.report(f"Failed loading LDSR model {path}", exc_info=True)
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)

View File

@ -443,7 +443,7 @@ def list_available_loras():
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower): for filename in candidates:
if os.path.isdir(filename): if os.path.isdir(filename):
continue continue

View File

@ -1,4 +1,3 @@
import os.path
import sys import sys
import PIL.Image import PIL.Image
@ -6,12 +5,11 @@ import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts from modules.shared import opts
@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers = [] scalers = []
add_model2 = True add_model2 = True
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
@ -87,11 +85,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def do_upscale(self, img: PIL.Image.Image, selected_file): def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache() devices.torch_gc()
try:
model = self.load_model(selected_file) model = self.load_model(selected_file)
if model is None: except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img return img
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
@ -111,7 +110,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output del torch_img, torch_output
torch.cuda.empty_cache() devices.torch_gc()
output = np_output.transpose((1, 2, 0)) # CHW to HWC output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB output = output[:, :, ::-1] # BGR to RGB
@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def load_model(self, path: str): def load_model(self, path: str):
device = devices.get_device_for('scunet') device = devices.get_device_for('scunet')
if "http" in path: if path.startswith("http"):
filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True) # TODO: this doesn't use `path` at all?
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True) model.load_state_dict(torch.load(filename), strict=True)
model.eval() model.eval()
for _, v in model.named_parameters(): for _, v in model.named_parameters():

View File

@ -1,34 +1,35 @@
import os import sys
import platform
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state from modules.shared import opts, state
from swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir = devices.get_device_for('swinir') device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
self.name = "SwinIR" self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ self.model_url = SWINIR_MODEL_URL
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x" self.model_name = "SwinIR 4x"
self.user_path = dirname self.user_path = dirname
super().__init__() super().__init__()
scalers = [] scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"]) model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files: for model in model_files:
if "http" in model: if model.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(model) name = modelloader.friendly_name(model)
@ -37,27 +38,39 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img, model_file):
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
current_config = (model_file, opts.SWIN_tile)
if use_compile and self._cached_model_config == current_config:
model = self._cached_model
else:
self._cached_model = None
try:
model = self.load_model(model_file) model = self.load_model(model_file)
if model is None: except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img return img
model = model.to(device_swinir, dtype=devices.dtype) model = model.to(device_swinir, dtype=devices.dtype)
if use_compile:
model = torch.compile(model)
self._cached_model = model
self._cached_model_config = current_config
img = upscale(img, model) img = upscale(img, model)
try: devices.torch_gc()
torch.cuda.empty_cache()
except Exception:
pass
return img return img
def load_model(self, path, scale=4): def load_model(self, path, scale=4):
if "http" in path: if path.startswith("http"):
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth") filename = modelloader.load_file_from_url(
filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True) url=path,
model_dir=self.model_download_path,
file_name=f"{self.model_name.replace(' ', '_')}.pth",
)
else: else:
filename = path filename = path
if filename is None or not os.path.exists(filename):
return None
if filename.endswith(".v2.pth"): if filename.endswith(".v2.pth"):
model = net2( model = Swin2SR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
@ -72,7 +85,7 @@ class UpscalerSwinIR(Upscaler):
) )
params = None params = None
else: else:
model = net( model = SwinIR(
upscale=scale, upscale=scale,
in_chans=3, in_chans=3,
img_size=64, img_size=64,
@ -172,6 +185,8 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_settings(on_ui_settings)

View File

@ -200,7 +200,8 @@ onUiLoaded(async() => {
canvas_hotkey_move: "KeyF", canvas_hotkey_move: "KeyF",
canvas_hotkey_overlap: "KeyO", canvas_hotkey_overlap: "KeyO",
canvas_disabled_functions: [], canvas_disabled_functions: [],
canvas_show_tooltip: true canvas_show_tooltip: true,
canvas_blur_prompt: false
}; };
const functionMap = { const functionMap = {
@ -608,6 +609,19 @@ onUiLoaded(async() => {
// Handle keydown events // Handle keydown events
function handleKeyDown(event) { function handleKeyDown(event) {
// Disable key locks to make pasting from the buffer work correctly
if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === "F5") {
return;
}
// before activating shortcut, ensure user is not actively typing in an input field
if (!hotkeysConfig.canvas_blur_prompt) {
if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {
return;
}
}
const hotkeyActions = { const hotkeyActions = {
[hotkeysConfig.canvas_hotkey_reset]: resetZoom, [hotkeysConfig.canvas_hotkey_reset]: resetZoom,
[hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,
@ -686,6 +700,20 @@ onUiLoaded(async() => {
// Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.
function handleMoveKeyDown(e) { function handleMoveKeyDown(e) {
// Disable key locks to make pasting from the buffer work correctly
if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === "F5") {
return;
}
// before activating shortcut, ensure user is not actively typing in an input field
if (!hotkeysConfig.canvas_blur_prompt) {
if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {
return;
}
}
if (e.code === hotkeysConfig.canvas_hotkey_move) { if (e.code === hotkeysConfig.canvas_hotkey_move) {
if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {
e.preventDefault(); e.preventDefault();

View File

@ -9,5 +9,6 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
})) }))

View File

@ -100,11 +100,12 @@ function keyupEditAttention(event) {
if (String(weight).length == 1) weight += ".0"; if (String(weight).length == 1) weight += ".0";
if (closeCharacter == ')' && weight == 1) { if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5); var endParenPos = text.substring(selectionEnd).indexOf(')');
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);
selectionStart--; selectionStart--;
selectionEnd--; selectionEnd--;
} else { } else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
} }
target.focus(); target.focus();

41
javascript/edit-order.js Normal file
View File

@ -0,0 +1,41 @@
/* alt+left/right moves text in prompt */
function keyupEditOrder(event) {
if (!opts.keyedit_move) return;
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
if (!event.altKey) return;
let isLeft = event.key == "ArrowLeft";
let isRight = event.key == "ArrowRight";
if (!isLeft && !isRight) return;
event.preventDefault();
let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
let text = target.value;
let items = text.split(",");
let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;
let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;
let range = indexEnd - indexStart + 1;
if (isLeft && indexStart > 0) {
items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));
target.value = items.join();
target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);
target.selectionEnd = items.slice(0, indexEnd).join().length;
} else if (isRight && indexEnd < items.length - 1) {
items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));
target.value = items.join();
target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;
target.selectionEnd = items.slice(0, indexEnd + 2).join().length;
}
event.preventDefault();
updateInput(target);
}
addEventListener('keydown', (event) => {
keyupEditOrder(event);
});

View File

@ -72,3 +72,21 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type)
} }
return [confirmed, config_state_name, config_restore_type]; return [confirmed, config_state_name, config_restore_type];
} }
function toggle_all_extensions(event) {
gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {
checkbox_el.checked = event.target.checked;
});
}
function toggle_extension() {
let all_extensions_toggled = true;
for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {
if (!checkbox_el.checked) {
all_extensions_toggled = false;
break;
}
}
gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;
}

View File

@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest from secrets import compare_digest
import modules.shared as shared import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
from modules.api import models from modules.api import models
from modules.shared import opts from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@ -22,7 +22,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_vae import vae_dict from modules.sd_vae import vae_dict
from modules.sd_models_config import find_checkpoint_config_near_filename from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
@ -30,13 +30,7 @@ from modules import devices
from typing import Dict, List, Any from typing import Dict, List, Any
import piexif import piexif
import piexif.helper import piexif.helper
from contextlib import closing
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
def script_name_to_index(name, scripts): def script_name_to_index(name, scripts):
@ -84,6 +78,8 @@ def encode_pil_to_base64(image):
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
if image.mode == "RGBA":
image = image.convert("RGB")
parameters = image.info.get('parameters', None) parameters = image.info.get('parameters', None)
exif_bytes = piexif.dump({ exif_bytes = piexif.dump({
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
@ -209,6 +205,11 @@ class Api:
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
if shared.cmd_opts.api_server_stop:
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])
self.default_script_arg_txt2img = [] self.default_script_arg_txt2img = []
self.default_script_arg_img2img = [] self.default_script_arg_img2img = []
@ -324,12 +325,12 @@ class Api:
args.pop('save_images', None) args.pop('save_images', None)
with self.queue_lock: with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
p.scripts = script_runner p.scripts = script_runner
p.outpath_grids = opts.outdir_txt2img_grids p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples p.outpath_samples = opts.outdir_txt2img_samples
shared.state.begin() shared.state.begin(job="scripts_txt2img")
if selectable_scripts is not None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
@ -380,13 +381,13 @@ class Api:
args.pop('save_images', None) args.pop('save_images', None)
with self.queue_lock: with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
p.init_images = [decode_base64_to_image(x) for x in init_images] p.init_images = [decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner p.scripts = script_runner
p.outpath_grids = opts.outdir_img2img_grids p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples p.outpath_samples = opts.outdir_img2img_samples
shared.state.begin() shared.state.begin(job="scripts_img2img")
if selectable_scripts is not None: if selectable_scripts is not None:
p.script_args = script_args p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
@ -517,6 +518,10 @@ class Api:
return options return options
def set_config(self, req: Dict[str, Any]): def set_config(self, req: Dict[str, Any]):
checkpoint_name = req.get("sd_model_checkpoint", None)
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
raise RuntimeError(f"model {checkpoint_name!r} not found")
for k, v in req.items(): for k, v in req.items():
shared.opts.set(k, v) shared.opts.set(k, v)
@ -598,44 +603,42 @@ class Api:
def create_embedding(self, args: dict): def create_embedding(self, args: dict):
try: try:
shared.state.begin() shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}") return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}") return models.TrainResponse(info=f"create embedding error: {e}")
finally:
shared.state.end()
def create_hypernetwork(self, args: dict): def create_hypernetwork(self, args: dict):
try: try:
shared.state.begin() shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}") return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e: except AssertionError as e:
shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}") return models.TrainResponse(info=f"create hypernetwork error: {e}")
finally:
shared.state.end()
def preprocess(self, args: dict): def preprocess(self, args: dict):
try: try:
shared.state.begin() shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end() shared.state.end()
return models.PreprocessResponse(info='preprocess complete') return models.PreprocessResponse(info='preprocess complete')
except KeyError as e: except KeyError as e:
shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}") return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e: except Exception as e:
shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: {e}") return models.PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e: finally:
shared.state.end() shared.state.end()
return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict): def train_embedding(self, args: dict):
try: try:
shared.state.begin() shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations apply_optimizations = shared.opts.training_xattention_optimizations
error = None error = None
filename = '' filename = ''
@ -648,15 +651,15 @@ class Api:
finally: finally:
if not apply_optimizations: if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg: except Exception as msg:
shared.state.end()
return models.TrainResponse(info=f"train embedding error: {msg}") return models.TrainResponse(info=f"train embedding error: {msg}")
finally:
shared.state.end()
def train_hypernetwork(self, args: dict): def train_hypernetwork(self, args: dict):
try: try:
shared.state.begin() shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = [] shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations apply_optimizations = shared.opts.training_xattention_optimizations
error = None error = None
@ -674,9 +677,10 @@ class Api:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
shared.state.end() shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}") return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError: except Exception as exc:
return models.TrainResponse(info=f"train embedding error: {exc}")
finally:
shared.state.end() shared.state.end()
return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self): def get_memory(self):
try: try:
@ -716,3 +720,16 @@ class Api:
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive) uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
def kill_webui(self):
restart.stop_program()
def restart_webui(self):
if restart.is_restartable():
restart.restart_program()
return Response(status_code=501)
def stop_webui(request):
shared.state.server_command = "stop"
return Response("Stopping.")

View File

@ -274,10 +274,6 @@ class PromptStyleItem(BaseModel):
prompt: Optional[str] = Field(title="Prompt") prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt") negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
class EmbeddingItem(BaseModel): class EmbeddingItem(BaseModel):
step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")

View File

@ -1,3 +1,4 @@
from functools import wraps
import html import html
import threading import threading
import time import time
@ -18,6 +19,7 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
@wraps(func)
def f(*args, **kwargs): def f(*args, **kwargs):
# if the first argument is a string that says "task(...)", it is treated as a job id # if the first argument is a string that says "task(...)", it is treated as a job id
@ -28,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None id_task = None
with queue_lock: with queue_lock:
shared.state.begin() shared.state.begin(job=id_task)
progress.start_task(id_task) progress.start_task(id_task)
try: try:
@ -45,6 +47,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
@wraps(func)
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:

View File

@ -107,4 +107,5 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')

View File

@ -15,7 +15,6 @@ model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir) model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None codeformer = None
@ -100,7 +99,7 @@ def setup_model(dirname):
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output del output
torch.cuda.empty_cache() devices.torch_gc()
except Exception: except Exception:
errors.report('Failed inference for CodeFormer', exc_info=True) errors.report('Failed inference for CodeFormer', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
@ -123,9 +122,6 @@ def setup_model(dirname):
return restored_img return restored_img
global have_codeformer
have_codeformer = True
global codeformer global codeformer
codeformer = FaceRestorerCodeFormer(dirname) codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer) shared.face_restorers.append(codeformer)

View File

@ -15,13 +15,6 @@ def has_mps() -> bool:
else: else:
return mac_specific.has_mps return mac_specific.has_mps
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]:
return args[x + 1]
return None
def get_cuda_device_string(): def get_cuda_device_string():
from modules import shared from modules import shared
@ -56,11 +49,15 @@ def get_device_for(task):
def torch_gc(): def torch_gc():
if torch.cuda.is_available(): if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()): with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32(): def enable_tf32():
if torch.cuda.is_available(): if torch.cuda.is_available():

View File

@ -1,15 +1,13 @@
import os import sys
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict): def mod2normal(state_dict):
@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data) scalers.append(scaler_data)
for file in model_paths: for file in model_paths:
if "http" in file: if file.startswith("http"):
name = self.model_name name = self.model_name
else: else:
name = modelloader.friendly_name(file) name = modelloader.friendly_name(file)
@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data) self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model): def do_upscale(self, img, selected_model):
try:
model = self.load_model(selected_model) model = self.load_model(selected_model)
if model is None: except Exception as e:
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img return img
model.to(devices.device_esrgan) model.to(devices.device_esrgan)
img = esrgan_upscale(model, img) img = esrgan_upscale(model, img)
return img return img
def load_model(self, path: str): def load_model(self, path: str):
if "http" in path: if path.startswith("http"):
filename = load_file_from_url( # TODO: this doesn't use `path` at all?
filename = modelloader.load_file_from_url(
url=self.model_url, url=self.model_url,
model_dir=self.model_download_path, model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth", file_name=f"{self.model_name}.pth",
progress=True,
) )
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None:
print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

View File

@ -103,6 +103,9 @@ def activate(p, extra_network_data):
except Exception as e: except Exception as e:
errors.display(e, f"activating extra network {extra_network_name}") errors.display(e, f"activating extra network {extra_network_name}")
if p.scripts is not None:
p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)
def deactivate(p, extra_network_data): def deactivate(p, extra_network_data):
"""call deactivate for extra networks in extra_network_data in specified order, then call """call deactivate for extra networks in extra_network_data in specified order, then call

View File

@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
shared.state.begin() shared.state.begin(job="model-merge")
shared.state.job = 'model-merge'
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message

View File

@ -174,31 +174,6 @@ def send_image_and_dimensions(x):
return img, w, h return img, w, h
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
If the infotext has no hash, then a hypernet with the same name will be selected instead.
"""
hypernet_name = hypernet_name.lower()
if hypernet_hash is not None:
# Try to match the hash in the name
for hypernet_key in shared.hypernetworks.keys():
result = re_hypernet_hash.search(hypernet_key)
if result is not None and result[1] == hypernet_hash:
return hypernet_key
else:
# Fall back to a hypernet with the same name
for hypernet_key in shared.hypernetworks.keys():
if hypernet_key.lower().startswith(hypernet_name):
return hypernet_key
return None
def restore_old_hires_fix_params(res): def restore_old_hires_fix_params(res):
"""for infotexts that specify old First pass size parameter, convert it into """for infotexts that specify old First pass size parameter, convert it into
width, height, and hr scale""" width, height, and hr scale"""
@ -332,10 +307,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res return res
settings_map = {}
infotext_to_setting_name_mapping = [ infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ), ('Clip skip', 'CLIP_stop_at_last_layers', ),
('Conditional mask weight', 'inpainting_mask_weight'), ('Conditional mask weight', 'inpainting_mask_weight'),

View File

@ -25,7 +25,7 @@ def gfpgann():
return None return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]: if len(models) == 1 and models[0].startswith("http"):
model_file = models[0] model_file = models[0]
elif len(models) != 0: elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) latest_file = max(models, key=os.path.getctime)

View File

@ -3,6 +3,7 @@ import glob
import html import html
import os import os
import inspect import inspect
from contextlib import closing
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
import torch import torch
@ -353,17 +354,6 @@ def load_hypernetworks(names, multipliers=None):
shared.loaded_hypernetworks.append(hypernetwork) shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
if not search:
return None
search = search.lower()
applicable = [name for name in shared.hypernetworks if search in name.lower()]
if not applicable:
return None
applicable = sorted(applicable, key=lambda name: len(name))
return applicable[0]
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
@ -446,18 +436,6 @@ def statistics(data):
return total_information, recent_information return total_information, recent_information
def report_statistics(loss_info:dict):
keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
for key in keys:
try:
print("Loss statistics for file " + key)
info, recent = statistics(list(loss_info[key]))
print(info)
print(recent)
except Exception as e:
print(e)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
# Remove illegal characters from name. # Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- ")) name = "".join( x for x in name if (x.isalnum() or x in "._- "))
@ -734,6 +712,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
preview_text = p.prompt preview_text = p.prompt
with closing(p):
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None image = processed.images[0] if len(processed.images) > 0 else None
@ -770,7 +749,6 @@ Last saved image: {html.escape(last_saved_image)}<br/>
pbar.leave = False pbar.leave = False
pbar.close() pbar.close()
hypernetwork.eval() hypernetwork.eval()
#report_statistics(loss_dict)
sd_hijack_checkpoint.remove() sd_hijack_checkpoint.remove()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime import datetime
import pytz import pytz
@ -10,7 +12,7 @@ import re
import numpy as np import numpy as np
import piexif import piexif
import piexif.helper import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
import string import string
import json import json
import hashlib import hashlib
@ -139,6 +141,11 @@ class GridAnnotation:
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0): def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')
color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')
color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')
def wrap(drawing, text, font, line_length): def wrap(drawing, text, font, line_length):
lines = [''] lines = ['']
for word in text.split(): for word in text.split():
@ -168,9 +175,6 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
fnt = get_font(fontsize) fnt = get_font(fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4 pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
cols = im.width // width cols = im.width // width
@ -179,7 +183,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}' assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}' assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
calc_img = Image.new("RGB", (1, 1), "white") calc_img = Image.new("RGB", (1, 1), color_background)
calc_d = ImageDraw.Draw(calc_img) calc_d = ImageDraw.Draw(calc_img)
for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)): for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
@ -200,7 +204,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), "white") result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)
for row in range(rows): for row in range(rows):
for col in range(cols): for col in range(cols):
@ -302,10 +306,12 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
if ratio < src_ratio: if ratio < src_ratio:
fill_height = height // 2 - src_h // 2 fill_height = height // 2 - src_h // 2
if fill_height > 0:
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
elif ratio > src_ratio: elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2 fill_width = width // 2 - src_w // 2
if fill_width > 0:
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
@ -372,8 +378,8 @@ class FilenameGenerator:
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..] 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
'user': lambda self: self.p.user,
'vae_filename': lambda self: self.get_vae_filename(), 'vae_filename': lambda self: self.get_vae_filename(),
} }
default_time_format = '%Y%m%d%H%M%S' default_time_format = '%Y%m%d%H%M%S'
@ -497,13 +503,23 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None): def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):
"""
Saves image to filename, including geninfo as text information for generation info.
For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.
For JPG images, there's no dictionary and geninfo just replaces the EXIF description.
"""
if extension is None: if extension is None:
extension = os.path.splitext(filename)[1] extension = os.path.splitext(filename)[1]
image_format = Image.registered_extensions()[extension] image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png': if extension.lower() == '.png':
existing_pnginfo = existing_pnginfo or {}
if opts.enable_pnginfo:
existing_pnginfo[pnginfo_section_name] = geninfo
if opts.enable_pnginfo: if opts.enable_pnginfo:
pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data = PngImagePlugin.PngInfo()
for k, v in (existing_pnginfo or {}).items(): for k, v in (existing_pnginfo or {}).items():
@ -622,7 +638,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
""" """
temp_file_path = f"{filename_without_extension}.tmp" temp_file_path = f"{filename_without_extension}.tmp"
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo) save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
os.replace(temp_file_path, filename_without_extension + extension) os.replace(temp_file_path, filename_without_extension + extension)
@ -639,12 +655,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
oversize = image.width > opts.target_side_length or image.height > opts.target_side_length oversize = image.width > opts.target_side_length or image.height > opts.target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024): if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):
ratio = image.width / image.height ratio = image.width / image.height
resize_to = None
if oversize and ratio > 1: if oversize and ratio > 1:
image = image.resize((round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)), LANCZOS) resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)
elif oversize: elif oversize:
image = image.resize((round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)), LANCZOS) resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)
if resize_to is not None:
try:
# Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16
image = image.resize(resize_to, LANCZOS)
except Exception:
image = image.resize(resize_to)
try: try:
_atomically_save_image(image, fullfn_without_extension, ".jpg") _atomically_save_image(image, fullfn_without_extension, ".jpg")
except Exception as e: except Exception as e:
@ -662,8 +684,15 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn return fullfn, txt_fullfn
def read_info_from_image(image): IGNORED_INFO_KEYS = {
items = image.info or {} 'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
'icc_profile', 'chromaticity', 'photoshop',
}
def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
items = (image.info or {}).copy()
geninfo = items.pop('parameters', None) geninfo = items.pop('parameters', None)
@ -679,9 +708,7 @@ def read_info_from_image(image):
items['exif comment'] = exif_comment items['exif comment'] = exif_comment
geninfo = exif_comment geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', for field in IGNORED_INFO_KEYS:
'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',
'icc_profile', 'chromaticity']:
items.pop(field, None) items.pop(field, None)
if items.get("Software", None) == "NovelAI": if items.get("Software", None) == "NovelAI":

View File

@ -1,23 +1,26 @@
import os import os
from contextlib import closing
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
import gradio as gr
from modules import sd_samplers from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
from modules.images import save_image
import modules.shared as shared import modules.shared as shared
import modules.processing as processing import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import modules.scripts import modules.scripts
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0): def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
processing.fix_seed(p) processing.fix_seed(p)
images = shared.listfiles(input_dir) images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
is_inpaint_batch = False is_inpaint_batch = False
if inpaint_mask_dir: if inpaint_mask_dir:
@ -36,6 +39,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
state.job_count = len(images) * p.n_iter state.job_count = len(images) * p.n_iter
# extract "default" params to use in case getting png info fails
prompt = p.prompt
negative_prompt = p.negative_prompt
seed = p.seed
cfg_scale = p.cfg_scale
sampler_name = p.sampler_name
steps = p.steps
for i, image in enumerate(images): for i, image in enumerate(images):
state.job = f"{i+1} out of {len(images)}" state.job = f"{i+1} out of {len(images)}"
if state.skipped: if state.skipped:
@ -79,25 +90,45 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
mask_image = Image.open(mask_image_path) mask_image = Image.open(mask_image_path)
p.image_mask = mask_image p.image_mask = mask_image
if use_png_info:
try:
info_img = img
if png_info_dir:
info_img_path = os.path.join(png_info_dir, os.path.basename(image))
info_img = Image.open(info_img_path)
geninfo, _ = imgutil.read_info_from_image(info_img)
parsed_parameters = parse_generation_parameters(geninfo)
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
except Exception:
parsed_parameters = {}
p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
p.seed = int(parsed_parameters.get("Seed", seed))
p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
p.steps = int(parsed_parameters.get("Steps", steps))
proc = modules.scripts.scripts_img2img.run(p, *args) proc = modules.scripts.scripts_img2img.run(p, *args)
if proc is None: if proc is None:
proc = process_images(p) proc = process_images(p)
for n, processed_image in enumerate(proc.images): for n, processed_image in enumerate(proc.images):
filename = image_path.name filename = image_path.stem
infotext = proc.infotext(p, n)
relpath = os.path.dirname(os.path.relpath(image, input_dir))
if n > 0: if n > 0:
left, right = os.path.splitext(filename) filename += f"-{n}"
filename = f"{left}-{n}{right}"
if not save_normally: if not save_normally:
os.makedirs(output_dir, exist_ok=True) os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
if processed_image.mode == 'RGBA': if processed_image.mode == 'RGBA':
processed_image = processed_image.convert("RGB") processed_image = processed_image.convert("RGB")
processed_image.save(os.path.join(output_dir, filename)) save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
@ -180,16 +211,19 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
p.scripts = modules.scripts.scripts_img2img p.scripts = modules.scripts.scripts_img2img
p.script_args = args p.script_args = args
p.user = request.username
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if mask: if mask:
p.extra_generation_params["Mask blur"] = mask_blur p.extra_generation_params["Mask blur"] = mask_blur
with closing(p):
if is_batch: if is_batch:
assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled"
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by) process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)
processed = Processed(p, [], p.seed, "") processed = Processed(p, [], p.seed, "")
else: else:
@ -197,8 +231,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()

View File

@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image): def interrogate(self, pil_image):
res = "" res = ""
shared.state.begin() shared.state.begin(job="interrogate")
shared.state.job = 'interrogate'
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()

View File

@ -142,15 +142,15 @@ def git_clone(url, dir, name, commithash=None):
if commithash is None: if commithash is None:
return return
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip()
if current_hash == commithash: if current_hash == commithash:
return return
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True)
if commithash is not None: if commithash is not None:
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")

View File

@ -1,12 +1,19 @@
import logging
import torch import torch
import platform import platform
from modules.sd_hijack_utils import CondFunc from modules.sd_hijack_utils import CondFunc
from packaging import version from packaging import version
log = logging.getLogger(__name__)
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool: def check_for_mps() -> bool:
if version.parse(torch.__version__) <= version.parse("2.0.1"):
if not getattr(torch, 'has_mps', False): if not getattr(torch, 'has_mps', False):
return False return False
try: try:
@ -14,9 +21,25 @@ def check_for_mps() -> bool:
return True return True
except Exception: except Exception:
return False return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
has_mps = check_for_mps() has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
from modules.shared import state
if state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection")
return
from torch.mps import empty_cache
empty_cache()
except Exception:
log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs): def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps': if input.device.type == 'mps':

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import shutil import shutil
import importlib import importlib
@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
def load_file_from_url(
url: str,
*,
model_dir: str,
progress: bool = True,
file_name: str | None = None,
) -> str:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
"""
os.makedirs(model_dir, exist_ok=True)
if not file_name:
parts = urlparse(url)
file_name = os.path.basename(parts.path)
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
from torch.hub import download_url_to_file
download_url_to_file(url, cached_file, progress=progress)
return cached_file
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
""" """
A one-and done loader to try finding the desired models in specified directories. A one-and done loader to try finding the desired models in specified directories.
@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0: if model_url is not None and len(output) == 0:
if download_name is not None: if download_name is not None:
from basicsr.utils.download_util import load_file_from_url output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
dl = load_file_from_url(model_url, places[0], True, download_name)
output.append(dl)
else: else:
output.append(model_url) output.append(model_url)
@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str): def friendly_name(file: str):
if "http" in file: if file.startswith("http"):
file = urlparse(file).path file = urlparse(file).path
file = os.path.basename(file) file = os.path.basename(file)

View File

@ -38,17 +38,3 @@ for d, must_exist, what, options in path_dirs:
else: else:
sys.path.append(d) sys.path.append(d)
paths[what] = d paths[what] = d
class Prioritize:
def __init__(self, name):
self.name = name
self.path = None
def __enter__(self):
self.path = sys.path.copy()
sys.path = [paths[self.name]] + sys.path
def __exit__(self, exc_type, exc_val, exc_tb):
sys.path = self.path
self.path = None

View File

@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc() devices.torch_gc()
shared.state.begin() shared.state.begin(job="extras")
shared.state.job = 'extras'
image_data = [] image_data = []
image_names = [] image_names = []
@ -54,7 +53,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
for image, name in zip(image_data, image_names): for image, name in zip(image_data, image_names):
shared.state.textinfo = name shared.state.textinfo = name
existing_pnginfo = image.info or {} parameters, existing_pnginfo = images.read_info_from_image(image)
if parameters:
existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))

View File

@ -184,6 +184,8 @@ class StableDiffusionProcessing:
self.uc = None self.uc = None
self.c = None self.c = None
self.user = None
@property @property
def sd_model(self): def sd_model(self):
return shared.sd_model return shared.sd_model
@ -549,7 +551,7 @@ def program_version():
return res return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
@ -573,7 +575,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None), "Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip, "Clip skip": None if clip_skip <= 1 else clip_skip,
@ -585,13 +587,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params, **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None, "Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None,
} }
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed: def process_images(p: StableDiffusionProcessing) -> Processed:
@ -602,7 +606,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try: try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None) p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights() sd_models.reload_model_weights()
@ -663,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else: else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
@ -824,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
text = infotext() text = infotext(use_main_prompt=True)
infotexts.insert(0, text) infotexts.insert(0, text)
if opts.enable_pnginfo: if opts.enable_pnginfo:
grid.info["parameters"] = text grid.info["parameters"] = text
@ -832,7 +836,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1 index_of_first_image = 1
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
if not p.disable_extra_networks and p.extra_network_data: if not p.disable_extra_networks and p.extra_network_data:
extra_networks.deactivate(p, p.extra_network_data) extra_networks.deactivate(p, p.extra_network_data)
@ -1074,6 +1078,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
if self.scripts is not None:
self.scripts.before_hr(self)
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

View File

@ -2,7 +2,6 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if not self.enable: if not self.enable:
return img return img
try:
info = self.load_model(path) info = self.load_model(path)
if not os.path.exists(info.local_data_path): except Exception:
print(f"Unable to load RealESRGAN model: {info.name}") errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img return img
upsampler = RealESRGANer( upsampler = RealESRGANer(
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return image return image
def load_model(self, path): def load_model(self, path):
try: for scaler in self.scalers:
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) if scaler.data_path == path:
if scaler.local_data_path.startswith("http"):
if info is None: scaler.local_data_path = modelloader.load_file_from_url(
print(f"Unable to find model info: {path}") scaler.data_path,
return None model_dir=self.model_download_path,
)
if info.local_data_path.startswith("http"): if not os.path.exists(scaler.local_data_path):
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
return scaler
return info raise ValueError(f"Unable to find model info: {path}")
except Exception:
errors.report("Error making Real-ESRGAN models list", exc_info=True)
return None
def load_models(self, _): def load_models(self, _):
return get_realesrgan_models(self) return get_realesrgan_models(self)

View File

@ -1,6 +1,7 @@
import os import os
import re import re
import sys import sys
import inspect
from collections import namedtuple from collections import namedtuple
import gradio as gr import gradio as gr
@ -116,6 +117,21 @@ class Script:
pass pass
def after_extra_networks_activate(self, p, *args, **kwargs):
"""
Calledafter extra networks activation, before conds calculation
allow modification of the network after extra networks activation been applied
won't be call if p.disable_extra_networks
**kwargs will have those items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
- extra_network_data - list of ExtraNetworkParams for current stage
"""
pass
def process_batch(self, p, *args, **kwargs): def process_batch(self, p, *args, **kwargs):
""" """
Same as process(), but called for every batch. Same as process(), but called for every batch.
@ -186,6 +202,11 @@ class Script:
return f'script_{tabname}{title}_{item_id}' return f'script_{tabname}{title}_{item_id}'
def before_hr(self, p, *args):
"""
This function is called before hires fix start.
"""
pass
current_basedir = paths.script_path current_basedir = paths.script_path
@ -249,7 +270,7 @@ def load_scripts():
def register_scripts_from_module(module): def register_scripts_from_module(module):
for script_class in module.__dict__.values(): for script_class in module.__dict__.values():
if type(script_class) != type: if not inspect.isclass(script_class):
continue continue
if issubclass(script_class, Script): if issubclass(script_class, Script):
@ -483,6 +504,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)
def after_extra_networks_activate(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.after_extra_networks_activate(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)
def process_batch(self, p, **kwargs): def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
@ -548,6 +577,15 @@ class ScriptRunner:
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
def before_hr(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_hr(p, *script_args)
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
scripts_txt2img: ScriptRunner = None scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None scripts_img2img: ScriptRunner = None
scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None

View File

@ -23,7 +23,8 @@ model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
checkpoints_list = {} checkpoints_list = {}
checkpoint_alisases = {} checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases # for compatibility with old name
checkpoints_loaded = collections.OrderedDict() checkpoints_loaded = collections.OrderedDict()
@ -66,7 +67,7 @@ class CheckpointInfo:
def register(self): def register(self):
checkpoints_list[self.title] = self checkpoints_list[self.title] = self
for id in self.ids: for id in self.ids:
checkpoint_alisases[id] = self checkpoint_aliases[id] = self
def calculate_shorthash(self): def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
@ -112,7 +113,7 @@ def checkpoint_tiles():
def list_models(): def list_models():
checkpoints_list.clear() checkpoints_list.clear()
checkpoint_alisases.clear() checkpoint_aliases.clear()
cmd_ckpt = shared.cmd_opts.ckpt cmd_ckpt = shared.cmd_opts.ckpt
if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt): if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
@ -136,7 +137,7 @@ def list_models():
def get_closet_checkpoint_match(search_string): def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_alisases.get(search_string, None) checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info
@ -166,7 +167,7 @@ def select_checkpoint():
"""Raises `FileNotFoundError` if no checkpoints are found.""" """Raises `FileNotFoundError` if no checkpoints are found."""
model_checkpoint = shared.opts.sd_model_checkpoint model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info
@ -247,7 +248,12 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
_, extension = os.path.splitext(checkpoint_file) _, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location or devices.get_optimal_device_name() device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
if not shared.opts.disable_mmap_load_safetensors:
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
else: else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@ -585,7 +591,6 @@ def unload_model_weights(sd_model=None, info=None):
sd_model = None sd_model = None
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
torch.cuda.empty_cache()
print(f"Unloaded weights {timer.summary()}.") print(f"Unloaded weights {timer.summary()}.")

View File

@ -1,9 +1,11 @@
import datetime import datetime
import json import json
import os import os
import re
import sys import sys
import threading import threading
import time import time
import logging
import gradio as gr import gradio as gr
import torch import torch
@ -18,6 +20,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional from typing import Optional
log = logging.getLogger(__name__)
demo = None demo = None
parser = cmd_args.parser parser = cmd_args.parser
@ -144,12 +148,15 @@ class State:
def request_restart(self) -> None: def request_restart(self) -> None:
self.interrupt() self.interrupt()
self.server_command = "restart" self.server_command = "restart"
log.info("Received restart request")
def skip(self): def skip(self):
self.skipped = True self.skipped = True
log.info("Received skip request")
def interrupt(self): def interrupt(self):
self.interrupted = True self.interrupted = True
log.info("Received interrupt request")
def nextjob(self): def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@ -173,7 +180,7 @@ class State:
return obj return obj
def begin(self): def begin(self, job: str = "(unknown)"):
self.sampling_step = 0 self.sampling_step = 0
self.job_count = -1 self.job_count = -1
self.processing_has_refined_job_count = False self.processing_has_refined_job_count = False
@ -187,10 +194,13 @@ class State:
self.interrupted = False self.interrupted = False
self.textinfo = None self.textinfo = None
self.time_start = time.time() self.time_start = time.time()
self.job = job
devices.torch_gc() devices.torch_gc()
log.info("Starting job %s", job)
def end(self): def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = "" self.job = ""
self.job_count = 0 self.job_count = 0
@ -311,6 +321,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"font": OptionInfo("", "Font for image grids that have text"),
"grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}),
"grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}),
"grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
@ -376,6 +390,7 @@ options_templates.update(options_section(('system', "System"), {
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
@ -470,7 +485,6 @@ options_templates.update(options_section(('ui', "User interface"), {
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"),
@ -481,6 +495,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
@ -493,6 +508,7 @@ options_templates.update(options_section(('ui', "User interface"), {
options_templates.update(options_section(('infotext', "Infotext"), { options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"),
"infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'> "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'>
@ -817,8 +833,12 @@ mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start() mem_mon.start()
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
def listfiles(dirname): def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=str.lower) if not x.startswith(".")] filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)] return [file for file in filenames if os.path.isfile(file)]
@ -843,8 +863,11 @@ def walk_files(path, allowed_extensions=None):
if allowed_extensions is not None: if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions) allowed_extensions = set(allowed_extensions)
for root, _, files in os.walk(path, followlinks=True): items = list(os.walk(path, followlinks=True))
for filename in files: items = sorted(items, key=lambda x: natural_sort_key(x[0]))
for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None: if allowed_extensions is not None:
_, ext = os.path.splitext(filename) _, ext = os.path.splitext(filename)
if ext not in allowed_extensions: if ext not in allowed_extensions:

View File

@ -2,11 +2,51 @@ import datetime
import json import json
import os import os
saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file", "gradient_step", "latent_sampling_method"} saved_params_shared = {
saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} "batch_size",
saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} "clip_grad_mode",
"clip_grad_value",
"create_image_every",
"data_root",
"gradient_step",
"initial_step",
"latent_sampling_method",
"learn_rate",
"log_directory",
"model_hash",
"model_name",
"num_of_dataset_images",
"steps",
"template_file",
"training_height",
"training_width",
}
saved_params_ti = {
"embedding_name",
"num_vectors_per_token",
"save_embedding_every",
"save_image_with_stored_embedding",
}
saved_params_hypernet = {
"activation_func",
"add_layer_norm",
"hypernetwork_name",
"layer_structure",
"save_hypernetwork_every",
"use_dropout",
"weight_init",
}
saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} saved_params_previews = {
"preview_cfg_scale",
"preview_height",
"preview_negative_prompt",
"preview_prompt",
"preview_sampler_index",
"preview_seed",
"preview_steps",
"preview_width",
}
def save_settings_to_file(log_directory, all_params): def save_settings_to_file(log_directory, all_params):

View File

@ -7,7 +7,7 @@ from modules import paths, shared, images, deepbooru
from modules.textual_inversion import autocrop from modules.textual_inversion import autocrop
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
try: try:
if process_caption: if process_caption:
shared.interrogator.load() shared.interrogator.load()

View File

@ -1,5 +1,6 @@
import os import os
from collections import namedtuple from collections import namedtuple
from contextlib import closing
import torch import torch
import tqdm import tqdm
@ -584,6 +585,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
preview_text = p.prompt preview_text = p.prompt
with closing(p):
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None image = processed.images[0] if len(processed.images) > 0 else None

View File

@ -1,13 +1,15 @@
from contextlib import closing
import modules.scripts import modules.scripts
from modules import sd_samplers, processing from modules import sd_samplers, processing
from modules.generation_parameters_copypaste import create_override_settings_dict from modules.generation_parameters_copypaste import create_override_settings_dict
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
@ -48,16 +50,17 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
p.scripts = modules.scripts.scripts_txt2img p.scripts = modules.scripts.scripts_txt2img
p.script_args = args p.script_args = args
p.user = request.username
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
with closing(p):
processed = modules.scripts.scripts_txt2img.run(p, *args) processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None: if processed is None:
processed = processing.process_images(p) processed = processing.process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()

View File

@ -155,7 +155,7 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
img = Image.open(image) img = Image.open(image)
filename = os.path.basename(image) filename = os.path.basename(image)
left, _ = os.path.splitext(filename) left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a')) print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8'))
return [gr.update(), None] return [gr.update(), None]
@ -733,6 +733,10 @@ def create_ui():
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
with gr.Accordion("PNG info", open=False):
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch] img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
@ -773,7 +777,7 @@ def create_ui():
selected_scale_tab = gr.State(value=0) selected_scale_tab = gr.State(value=0)
with gr.Tabs(): with gr.Tabs():
with gr.Tab(label="Resize to") as tab_scale_to: with gr.Tab(label="Resize to", elem_id="img2img_tab_resize_to") as tab_scale_to:
with FormRow(): with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4): with gr.Column(elem_id="img2img_column_size", scale=4):
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
@ -782,7 +786,7 @@ def create_ui():
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn") detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
with gr.Tab(label="Resize by") as tab_scale_by: with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale") scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
with FormRow(): with FormRow():
@ -934,6 +938,9 @@ def create_ui():
img2img_batch_output_dir, img2img_batch_output_dir,
img2img_batch_inpaint_mask_dir, img2img_batch_inpaint_mask_dir,
override_settings, override_settings,
img2img_batch_use_png_info,
img2img_batch_png_info_props,
img2img_batch_png_info_dir,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,

View File

@ -138,7 +138,10 @@ def extension_table():
<table id="extensions"> <table id="extensions">
<thead> <thead>
<tr> <tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th> <th>
<input class="gr-check-radio gr-checkbox all_extensions_toggle" type="checkbox" {'checked="checked"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange="toggle_all_extensions(event)" />
<abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr>
</th>
<th>URL</th> <th>URL</th>
<th>Branch</th> <th>Branch</th>
<th>Version</th> <th>Version</th>
@ -170,7 +173,7 @@ def extension_table():
code += f""" code += f"""
<tr> <tr>
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label{style}><input class="gr-check-radio gr-checkbox extension_toggle" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''} onchange="toggle_extension(event)" />{html.escape(ext.name)}</label></td>
<td>{remote}</td> <td>{remote}</td>
<td>{ext.branch}</td> <td>{ext.branch}</td>
<td>{version_link}</td> <td>{version_link}</td>
@ -421,9 +424,19 @@ sort_ordering = [
(False, lambda x: x.get('name', 'z')), (False, lambda x: x.get('name', 'z')),
(True, lambda x: x.get('name', 'z')), (True, lambda x: x.get('name', 'z')),
(False, lambda x: 'z'), (False, lambda x: 'z'),
(True, lambda x: x.get('commit_time', '')),
(True, lambda x: x.get('created_at', '')),
(True, lambda x: x.get('stars', 0)),
] ]
def get_date(info: dict, key):
try:
return datetime.strptime(info.get(key), "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d")
except (ValueError, TypeError):
return ''
def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""): def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=""):
extlist = available_extensions["extensions"] extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
@ -448,7 +461,10 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):
name = ext.get("name", "noname") name = ext.get("name", "noname")
stars = int(ext.get("stars", 0))
added = ext.get('added', 'unknown') added = ext.get('added', 'unknown')
update_time = get_date(ext, 'commit_time')
create_time = get_date(ext, 'created_at')
url = ext.get("url", None) url = ext.get("url", None)
description = ext.get("description", "") description = ext.get("description", "")
extension_tags = ext.get("tags", []) extension_tags = ext.get("tags", [])
@ -475,7 +491,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
code += f""" code += f"""
<tr> <tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td> <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td> <td>{html.escape(description)}<p class="info">
<span class="date_added">Update: {html.escape(update_time)} Added: {html.escape(added)} Created: {html.escape(create_time)}</span><span class="star_count">stars: <b>{stars}</b></a></p></td>
<td>{install_code}</td> <td>{install_code}</td>
</tr> </tr>
@ -559,7 +576,7 @@ def create_ui():
with gr.Row(): with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")
with gr.Row(): with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False) search_extensions_text = gr.Text(label="Search").style(container=False)
@ -568,9 +585,9 @@ def create_ui():
available_extensions_table = gr.HTML() available_extensions_table = gr.HTML()
refresh_available_extensions_button.click( refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags, sort_column], inputs=[available_extensions_index, hide_tags, sort_column],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result, search_extensions_text], outputs=[available_extensions_index, available_extensions_table, hide_tags, search_extensions_text, install_result],
) )
install_extension_button.click( install_extension_button.click(

View File

@ -30,8 +30,8 @@ def fetch_file(filename: str = ""):
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
ext = os.path.splitext(filename)[1].lower() ext = os.path.splitext(filename)[1].lower()
if ext not in (".png", ".jpg", ".jpeg", ".webp"): if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
raise ValueError(f"File cannot be fetched: {filename}. Only png and jpg and webp.") raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
# would profit from returning 304 # would profit from returning 304
return FileResponse(filename, headers={"Accept-Ranges": "bytes"}) return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
@ -90,8 +90,8 @@ class ExtraNetworksPage:
subdirs = {} subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]: for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for root, dirs, _ in os.walk(parentdir, followlinks=True): for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):
for dirname in dirs: for dirname in sorted(dirs, key=shared.natural_sort_key):
x = os.path.join(root, dirname) x = os.path.join(root, dirname)
if not os.path.isdir(x): if not os.path.isdir(x):

View File

@ -260,8 +260,15 @@ class UiSettings:
component = self.component_dict[k] component = self.component_dict[k]
info = opts.data_labels[k] info = opts.data_labels[k]
change_handler = component.release if hasattr(component, 'release') else component.change if isinstance(component, gr.Textbox):
change_handler( methods = [component.submit, component.blur]
elif hasattr(component, 'release'):
methods = [component.release]
else:
methods = [component.change]
for method in methods:
method(
fn=lambda value, k=k: self.run_settings_single(value, key=k), fn=lambda value, k=k: self.run_settings_single(value, key=k),
inputs=[component], inputs=[component],
outputs=[component, self.text_settings], outputs=[component, self.text_settings],

View File

@ -704,11 +704,24 @@ table.popup-table .link{
margin: 0; margin: 0;
} }
#available_extensions .date_added{ #available_extensions .info{
opacity: 0.85; margin: 0.5em 0;
display: flex;
margin-top: auto;
opacity: 0.80;
font-size: 90%; font-size: 90%;
} }
#available_extensions .date_added{
margin-right: auto;
display: inline-block;
}
#available_extensions .star_count{
margin-left: auto;
display: inline-block;
}
/* replace original footer with ours */ /* replace original footer with ours */
footer { footer {

View File

@ -11,13 +11,24 @@ import json
from threading import Thread from threading import Thread
from typing import Iterable from typing import Iterable
from fastapi import FastAPI, Response from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from packaging import version from packaging import version
import logging import logging
# We can't use cmd_opts for this because it will not have been initialized at this point.
log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
if log_level:
log_level = getattr(logging, log_level.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors, devices # noqa: F401 from modules import paths, timer, import_hook, errors, devices # noqa: F401
@ -32,7 +43,7 @@ warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvisi
startup_timer.record("import torch") startup_timer.record("import torch")
import gradio import gradio # noqa: F401
startup_timer.record("import gradio") startup_timer.record("import gradio")
import ldm.modules.encoders.modules # noqa: F401 import ldm.modules.encoders.modules # noqa: F401
@ -359,12 +370,11 @@ def api_only():
modules.script_callbacks.app_started_callback(None, app) modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.") print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
port=cmd_opts.port if cmd_opts.port else 7861,
def stop_route(request): root_path = f"/{cmd_opts.subpath}"
shared.state.server_command = "stop" )
return Response("Stopping.")
def webui(): def webui():
@ -403,9 +413,8 @@ def webui():
"docs_url": "/docs", "docs_url": "/docs",
"redoc_url": "/redoc", "redoc_url": "/redoc",
}, },
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
) )
if cmd_opts.add_stop_route:
app.add_route("/_stop", stop_route, methods=["POST"])
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False
@ -436,11 +445,6 @@ def webui():
timer.startup_record = startup_timer.dump() timer.startup_record = startup_timer.dump()
print(f"Startup time: {startup_timer.summary()}.") print(f"Startup time: {startup_timer.summary()}.")
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
try: try:
while True: while True:
server_command = shared.state.wait_for_server_command(timeout=5) server_command = shared.state.wait_for_server_command(timeout=5)

View File

@ -4,26 +4,28 @@
# change the variables in webui-user.sh instead # # change the variables in webui-user.sh instead #
################################################# #################################################
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
# If run from macOS, load defaults from webui-macos-env.sh # If run from macOS, load defaults from webui-macos-env.sh
if [[ "$OSTYPE" == "darwin"* ]]; then if [[ "$OSTYPE" == "darwin"* ]]; then
if [[ -f webui-macos-env.sh ]] if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
then then
source ./webui-macos-env.sh source "$SCRIPT_DIR"/webui-macos-env.sh
fi fi
fi fi
# Read variables from webui-user.sh # Read variables from webui-user.sh
# shellcheck source=/dev/null # shellcheck source=/dev/null
if [[ -f webui-user.sh ]] if [[ -f "$SCRIPT_DIR"/webui-user.sh ]]
then then
source ./webui-user.sh source "$SCRIPT_DIR"/webui-user.sh
fi fi
# Set defaults # Set defaults
# Install directory without trailing slash # Install directory without trailing slash
if [[ -z "${install_dir}" ]] if [[ -z "${install_dir}" ]]
then then
install_dir="$(pwd)" install_dir="$SCRIPT_DIR"
fi fi
# Name of the subdirectory (defaults to stable-diffusion-webui) # Name of the subdirectory (defaults to stable-diffusion-webui)
@ -131,6 +133,10 @@ case "$gpu_info" in
;; ;;
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
;; ;;
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
export TORCH_COMMAND="pip install --pre torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 --index-url https://download.pytorch.org/whl/nightly/rocm5.5"
# Navi 3 needs at least 5.5 which is only on the nightly chain
;;
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"