diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 35802a53c..9c2ff3138 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -45,6 +45,8 @@ body: attributes: label: Commit where the problem happens description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI) + validations: + required: true - type: dropdown id: platforms attributes: diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 000000000..f58c94a9b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: WebUI Community Support + url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions + about: Please ask and answer questions here. diff --git a/.gitignore b/.gitignore index f9c3357cf..8fa05852a 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode +/extensions diff --git a/README.md b/README.md index 859a91b6f..1a0e4f6a2 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - One click install and run script (but you still must install python and git) - Outpainting - Inpainting +- Color Sketch - Prompt Matrix - Stable Diffusion Upscale - Attention, specify parts of text that the model should pay more attention to @@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - have as many embeddings as you want and use any names you like for them - use multiple embeddings with different numbers of vectors per token - works with half precision floating point numbers + - train embeddings on 8GB (also reports of 6GB working) - Extras tab with: - GFPGAN, neural network that fixes faces - CodeFormer, face restoration tool as an alternative to GFPGAN @@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Interrupt processing at any time - 4GB video card support (also reports of 2GB working) - Correct seeds for batches -- Prompt length validation - - get length of prompt in tokens as you type - - get a warning after generation if some text was truncated +- Live prompt token length validation - Generation parameters - parameters you used to generate images are saved with that image - in PNG chunks for PNG, in EXIF for JPEG - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI - can be disabled in settings + - drag and drop an image/text-parameters to promptbox +- Read Generation Parameters Button, loads parameters in promptbox to UI - Settings page - Running arbitrary python code from UI (must run with --allow-code to enable) - Mouseover hints for most UI elements @@ -59,10 +61,10 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - CLIP interrogator, a button that tries to guess prompt from an image - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway - Batch Processing, process a group of files using img2img -- Img2img Alternative +- Img2img Alternative, reverse Euler method of cross attention control - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions - Reloading checkpoints on the fly -- Checkpoint Merger, a tab that allows you to merge two checkpoints into one +- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once - separate prompts using uppercase `AND` @@ -70,14 +72,35 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) +- History tab: view, direct and delete images conveniently within the UI +- Generate forever option +- Training tab + - hypernetworks and embeddings options + - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime) +- Clip skip +- Use Hypernetworks +- Use VAEs +- Estimated completion time in progress bar +- API +- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. +- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) + +## Where are Aesthetic Gradients?!?! +Aesthetic Gradients are now an extension. You can install it using git: + +```commandline +git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients +``` + +After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart +the UI. The interface for Aesthetic Gradients should appear exactly the same as it was. ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. -Alternatively, use Google Colab: +Alternatively, use online services (like Google Colab): -- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl) -- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh). +- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services) ### Automatic Installation on Windows 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" diff --git a/extensions/put extension here.txt b/extensions/put extension here.txt new file mode 100644 index 000000000..e69de29bb diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 96f1c00d0..66f26a22a 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -3,12 +3,12 @@ let currentWidth = null; let currentHeight = null; let arFrameTimeout = setTimeout(function(){},0); -function dimensionChange(e,dimname){ +function dimensionChange(e, is_width, is_height){ - if(dimname == 'Width'){ + if(is_width){ currentWidth = e.target.value*1.0 } - if(dimname == 'Height'){ + if(is_height){ currentHeight = e.target.value*1.0 } @@ -18,22 +18,13 @@ function dimensionChange(e,dimname){ return; } - var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200') - if(img2imgMode){ - img2imgMode=img2imgMode.innerText - }else{ - return; - } - - var redrawImage = gradioApp().querySelector('div[data-testid=image] img'); - var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img') - var targetElement = null; - if(img2imgMode=='img2img' && redrawImage){ - targetElement = redrawImage; - }else if(img2imgMode=='Inpaint' && inpaintImage){ - targetElement = inpaintImage; + var tabIndex = get_tab_index('mode_img2img') + if(tabIndex == 0){ + targetElement = gradioApp().querySelector('div[data-testid=image] img'); + } else if(tabIndex == 1){ + targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); } if(targetElement){ @@ -98,22 +89,20 @@ onUiUpdate(function(){ var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200")) if(inImg2img){ let inputs = gradioApp().querySelectorAll('input'); - inputs.forEach(function(e){ - let parentLabel = e.parentElement.querySelector('label') - if(parentLabel && parentLabel.innerText){ - if(!e.classList.contains('scrollwatch')){ - if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){ - e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} ) - e.classList.add('scrollwatch') - } - if(parentLabel.innerText == 'Width'){ - currentWidth = e.value*1.0 - } - if(parentLabel.innerText == 'Height'){ - currentHeight = e.value*1.0 - } - } - } + inputs.forEach(function(e){ + var is_width = e.parentElement.id == "img2img_width" + var is_height = e.parentElement.id == "img2img_height" + + if((is_width || is_height) && !e.classList.contains('scrollwatch')){ + e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} ) + e.classList.add('scrollwatch') + } + if(is_width){ + currentWidth = e.value*1.0 + } + if(is_height){ + currentHeight = e.value*1.0 + } }) } }); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 070cf255f..3ed1cb3c6 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; const imgWrap = target.closest('[data-testid="image"]'); - if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) { + if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { return; } e.stopPropagation(); diff --git a/javascript/images_history.js b/javascript/images_history.js index f7d052c3d..c9aa76f83 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -17,14 +17,6 @@ var images_history_click_image = function(){ images_history_set_image_info(this); } -var images_history_click_tab = function(){ - var tabs_box = gradioApp().getElementById("images_history_tab"); - if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click(); - tabs_box.classList.add(this.getAttribute("tabname")) - } -} - function images_history_disabled_del(){ gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ btn.setAttribute('disabled','disabled'); @@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){ var parent = item.parentElement; tagname = tagname.toUpperCase() while(parent.tagName != tagname){ - console.log(parent.tagName, tagname) parent = parent.parentElement; } return parent; @@ -88,15 +79,15 @@ function images_history_set_image_info(button){ } -function images_history_get_current_img(tabname, image_path, files){ +function images_history_get_current_img(tabname, img_index, files){ return [ - gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), - image_path, + tabname, + gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), files ]; } -function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){ +function images_history_delete(del_num, tabname, image_index){ image_index = parseInt(image_index); var tab = gradioApp().getElementById(tabname + '_images_history'); var set_btn = tab.querySelector(".images_history_set_index"); @@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i } }); var img_num = buttons.length / 2; + del_num = Math.min(img_num - image_index, del_num) if (img_num <= del_num){ setTimeout(function(tabname){ gradioApp().getElementById(tabname + '_images_history_renew_page').click(); @@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i } else { var next_img for (var i = 0; i < del_num; i++){ - if (image_index + i < image_index + img_num){ - buttons[image_index + i].style.display = 'none'; - buttons[image_index + img_num + 1].style.display = 'none'; - next_img = image_index + i + 1 - } + buttons[image_index + i].style.display = 'none'; + buttons[image_index + i + img_num].style.display = 'none'; + next_img = image_index + i + 1 } var bnt; if (next_img >= img_num){ - btn = buttons[image_index - del_num]; + btn = buttons[image_index - 1]; } else { btn = buttons[next_img]; } setTimeout(function(btn){btn.click()}, 30, btn); } images_history_disabled_del(); - return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index]; + } -function images_history_turnpage(img_path, page_index, image_index, tabname){ +function images_history_turnpage(tabname){ + gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); buttons.forEach(function(elem) { elem.style.display = 'block'; - }) - return [img_path, page_index, image_index, tabname]; + }) } function images_history_enable_del_buttons(){ @@ -147,60 +137,64 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page') - if (load_txt2img_button){ + var tabnames = gradioApp().getElementById("images_history_tabnames_list") + if (tabnames){ + images_history_tab_list = tabnames.querySelector("textarea").value.split(",") for (var i in images_history_tab_list ){ - tab = images_history_tab_list[i]; + var tab = images_history_tab_list[i]; gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); - gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); - + gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); + gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px"); } - var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); - tabs_box.setAttribute("id", "images_history_tab"); - var tab_btns = tabs_box.querySelectorAll("button"); - for (var i in images_history_tab_list){ - var tabname = images_history_tab_list[i] - tab_btns[i].setAttribute("tabname", tabname); - // this refreshes history upon tab switch - // until the history is known to work well, which is not the case now, we do not do this at startup - //tab_btns[i].addEventListener('click', images_history_click_tab); - } - tabs_box.classList.add(images_history_tab_list[0]); - - // same as above, at page load - //load_txt2img_button.click(); + //preload + if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){ + var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); + tabs_box.setAttribute("id", "images_history_tab"); + var tab_btns = tabs_box.querySelectorAll("button"); + for (var i in images_history_tab_list){ + var tabname = images_history_tab_list[i] + tab_btns[i].setAttribute("tabname", tabname); + tab_btns[i].addEventListener('click', function(){ + var tabs_box = gradioApp().getElementById("images_history_tab"); + if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { + gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); + tabs_box.classList.add(this.getAttribute("tabname")) + } + }); + } + tab_btns[0].click() + } } else { setTimeout(images_history_init, 500); } } -var images_history_tab_list = ["txt2img", "img2img", "extras"]; +var images_history_tab_list = ""; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - for (var i in images_history_tab_list ){ - let tabname = images_history_tab_list[i] - var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); - buttons.forEach(function(bnt){ - bnt.addEventListener('click', images_history_click_image, true); - }); + if (images_history_tab_list != ""){ + for (var i in images_history_tab_list ){ + let tabname = images_history_tab_list[i] + var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); + buttons.forEach(function(bnt){ + bnt.addEventListener('click', images_history_click_image, true); + }); - // same as load_txt2img_button.click() above - /* - var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); - if (cls_btn){ - cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); - }, false); - }*/ + var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); + if (cls_btn){ + cls_btn.addEventListener('click', function(){ + gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + }, false); + } - } + } + } }); - mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); - + mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); }); diff --git a/modules/devices.py b/modules/devices.py index eb4225834..dc1f3cddd 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + from modules import shared + + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") @@ -34,7 +45,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/extras.py b/modules/extras.py index b853fa5b2..22c5a1c12 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ if input_dir == '': return outputs, "Please select an input directory.", '' - image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)] + image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)] for img in image_list: - image = Image.open(img) + try: + image = Image.open(img) + except Exception: + continue imageArr.append(image) imageNameArr.append(img) else: @@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] + + if opts.use_original_name_batch and image_name != None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' - images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, - forced_filename=image_name if opts.use_original_name_batch else None) + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if opts.enable_pnginfo: image.info = existing_pnginfo diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0f0414499..f73647dab 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -4,13 +4,22 @@ import gradio as gr from modules.shared import script_path from modules import shared -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") type_of_gr_update = type(gr.update()) +def quote(text): + if ',' not in str(text): + return text + + text = str(text) + text = text.replace('\\', '\\\\') + text = text.replace('"', '\\"') + return f'"{text}"' + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): else: try: valtype = type(output.value) - val = valtype(v) + + if valtype == bool and v == "False": + val = False + else: + val = valtype(v) + res.append(gr.update(value=val)) except Exception: res.append(gr.update()) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index e493f3669..b7a040389 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -41,12 +41,12 @@ class HypernetworkModule(torch.nn.Module): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) # Add an activation func - if activation_func == "linear": + if activation_func == "linear" or activation_func is None: pass elif activation_func in self.activation_dict: linears.append(self.activation_dict[activation_func]()) else: - raise NotImplementedError( + raise RuntimeError( "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" ) @@ -65,7 +65,7 @@ class HypernetworkModule(torch.nn.Module): self.load_state_dict(state_dict) else: for layer in self.linear: - if isinstance(layer, torch.nn.Linear): + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer.weight.data.normal_(mean=0.0, std=0.01) layer.bias.data.zero_() @@ -93,7 +93,7 @@ class HypernetworkModule(torch.nn.Module): def trainables(self): layer_structure = [] for layer in self.linear: - if isinstance(layer, torch.nn.Linear): + if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm: layer_structure += [layer.weight, layer.bias] return layer_structure @@ -272,6 +272,9 @@ def stack_conds(conds): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + # images allows training previews to have infotext. Importing it at the top causes a circular import problem. + from modules import images + assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -314,6 +317,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" + forced_filename = "" ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -353,7 +357,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log pbar.set_description(f"loss: {mean_loss:.7f}") if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { @@ -362,7 +368,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + forced_filename = f'{hypernetwork_name}-{hypernetwork.step}' + last_saved_image = os.path.join(images_dir, forced_filename) optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -398,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if image is not None: shared.state.current_image = image - image.save(last_saved_image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename) last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -408,7 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log Loss: {mean_loss:.7f}
Step: {hypernetwork.step}
Last prompt: {html.escape(entries[0].cond_text)}
-Last saved embedding: {html.escape(last_saved_file)}
+Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" @@ -417,6 +424,9 @@ Last saved image: {html.escape(last_saved_image)}
hypernetwork.sd_checkpoint = checkpoint.hash hypernetwork.sd_checkpoint_name = checkpoint.model_name + # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). + hypernetwork.name = hypernetwork_name + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') hypernetwork.save(filename) return hypernetwork, filename diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 5f6f17b6d..2b472d870 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,9 +9,13 @@ from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): + # Remove illegal characters from name. + name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" if type(layer_structure) == str: layer_structure = [float(x.strip()) for x in layer_structure.split(",")] diff --git a/modules/images_history.py b/modules/images_history.py index 46b23e569..bc5cf11f0 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -1,183 +1,424 @@ import os import shutil -import sys +import time +import hashlib +import gradio +system_bak_path = "webui_log_and_bak" +custom_tab_name = "custom fold" +faverate_tab_name = "favorites" +tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] +def is_valid_date(date): + try: + time.strptime(date, "%Y%m%d") + return True + except: + return False -def traverse_all_files(output_dir, image_list, curr_dir=None): - curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir) +def reduplicative_file_move(src, dst): + def same_name_file(basename, path): + name, ext = os.path.splitext(basename) + f_list = os.listdir(path) + max_num = 0 + for f in f_list: + if len(f) <= len(basename): + continue + f_ext = f[-len(ext):] if len(ext) > 0 else "" + if f[:len(name)] == name and f_ext == ext: + if f[len(name)] == "(" and f[-len(ext)-1] == ")": + number = f[len(name)+1:-len(ext)-1] + if number.isdigit(): + if int(number) > max_num: + max_num = int(number) + return f"{name}({max_num + 1}){ext}" + name = os.path.basename(src) + save_name = os.path.join(dst, name) + if not os.path.exists(save_name): + shutil.move(src, dst) + else: + name = same_name_file(name, dst) + shutil.move(src, os.path.join(dst, name)) + +def traverse_all_files(curr_path, image_list, all_type=False): try: f_list = os.listdir(curr_path) except: - if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": - image_list.append(curr_dir) + if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"): + image_list.append(curr_path) return image_list for file in f_list: - file = file if curr_dir is None else os.path.join(curr_dir, file) - file_path = os.path.join(curr_path, file) - if file[-4:] == ".txt": + file = os.path.join(curr_path, file) + if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): pass - elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: + elif os.path.isfile(file) and file[-10:].rfind(".") > 0: image_list.append(file) else: - image_list = traverse_all_files(output_dir, image_list, file) + image_list = traverse_all_files(file, image_list) return image_list +def auto_sorting(dir_name): + bak_path = os.path.join(dir_name, system_bak_path) + if not os.path.exists(bak_path): + os.mkdir(bak_path) + log_file = None + files_list = [] + f_list = os.listdir(dir_name) + for file in f_list: + if file == system_bak_path: + continue + file_path = os.path.join(dir_name, file) + if not is_valid_date(file): + if file[-10:].rfind(".") > 0: + files_list.append(file_path) + else: + files_list = traverse_all_files(file_path, files_list, all_type=True) -def get_recent_images(dir_name, page_index, step, image_index, tabname): - page_index = int(page_index) - image_list = [] - if not os.path.exists(dir_name): - pass - elif os.path.isdir(dir_name): - image_list = traverse_all_files(dir_name, image_list) - image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) + for file in files_list: + date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file))) + file_path = os.path.dirname(file) + hash_path = hashlib.md5(file_path.encode()).hexdigest() + path = os.path.join(dir_name, date_str, hash_path) + if not os.path.exists(path): + os.makedirs(path) + if log_file is None: + log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") + log_file.write(f"{hash_path},{file_path}\n") + reduplicative_file_move(file, path) + + date_list = [] + f_list = os.listdir(dir_name) + for f in f_list: + if is_valid_date(f): + date_list.append(f) + elif f == system_bak_path: + continue + else: + try: + reduplicative_file_move(os.path.join(dir_name, f), bak_path) + except: + pass + + today = time.strftime("%Y%m%d",time.localtime(time.time())) + if today not in date_list: + date_list.append(today) + return sorted(date_list, reverse=True) + +def archive_images(dir_name, date_to): + filenames = [] + batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) + if batch_size <= 0: + batch_size = opts.images_history_num_per_page * 6 + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + date_to_bak = date_to + if False: #opts.images_history_reconstruct_directory: + date_list = auto_sorting(dir_name) + for date in date_list: + if date <= date_to: + path = os.path.join(dir_name, date) + if date == today and not os.path.exists(path): + continue + filenames = traverse_all_files(path, filenames) + if len(filenames) > batch_size: + break + filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) else: - print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr) - num = 48 if tabname != "extras" else 12 - max_page_index = len(image_list) // num + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num - image_list = image_list[idx_frm:idx_frm + num] - image_index = int(image_index) - if image_index < 0 or image_index > len(image_list) - 1: - current_file = None - hidden = None - else: - current_file = image_list[int(image_index)] - hidden = os.path.join(dir_name, current_file) - return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, "" + filenames = traverse_all_files(dir_name, filenames) + total_num = len(filenames) + tmparray = [(os.path.getmtime(file), file) for file in filenames ] + date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 + filenames = [] + date_list = {date_to:None} + date = time.strftime("%Y%m%d",time.localtime(time.time())) + for t, f in tmparray: + date = time.strftime("%Y%m%d",time.localtime(t)) + date_list[date] = None + if t <= date_stamp: + filenames.append((t, f ,date)) + date_list = sorted(list(date_list.keys()), reverse=True) + sort_array = sorted(filenames, key=lambda x:-x[0]) + if len(sort_array) > batch_size: + date = sort_array[batch_size][2] + filenames = [x[1] for x in sort_array] + else: + date = date_to if len(sort_array) == 0 else sort_array[-1][2] + filenames = [x[1] for x in sort_array] + filenames = [x[1] for x in sort_array if x[2]>= date] + num = len(filenames) + last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) + date = date[:4] + "/" + date[4:6] + "/" + date[6:8] + date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] + load_info = "
" + load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info += "
" + _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) + return ( + date_to, + load_info, + filenames, + 1, + image_list, + "", + "", + visible_num, + last_date_from, + gradio.update(visible=total_num > num) + ) - -def first_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, 1, 0, image_index, tabname) - - -def end_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, -1, 0, image_index, tabname) - - -def prev_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, -1, image_index, tabname) - - -def next_page_click(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 1, image_index, tabname) - - -def page_index_change(dir_name, page_index, image_index, tabname): - return get_recent_images(dir_name, page_index, 0, image_index, tabname) - - -def show_image_info(num, image_path, filenames): - # print(f"select image {num}") - file = filenames[int(num)] - return file, num, os.path.join(image_path, file) - - -def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): +def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": return filenames, delete_num else: delete_num = int(delete_num) + visible_num = int(visible_num) + image_index = int(image_index) index = list(filenames).index(name) i = 0 new_file_list = [] for name in filenames: if i >= index and i < index + delete_num: - path = os.path.join(dir_name, name) - if os.path.exists(path): - print(f"Delete file {path}") - os.remove(path) - txt_file = os.path.splitext(path)[0] + ".txt" + if os.path.exists(name): + if visible_num == image_index: + new_file_list.append(name) + i += 1 + continue + print(f"Delete file {name}") + os.remove(name) + visible_num -= 1 + txt_file = os.path.splitext(name)[0] + ".txt" if os.path.exists(txt_file): os.remove(txt_file) else: - print(f"Not exists file {path}") + print(f"Not exists file {name}") else: new_file_list.append(name) i += 1 - return new_file_list, 1 + return new_file_list, 1, visible_num +def save_image(file_name): + if file_name is not None and os.path.exists(file_name): + shutil.copy(file_name, opts.outdir_save) + +def get_recent_images(page_index, step, filenames): + page_index = int(page_index) + num_of_imgs_per_page = int(opts.images_history_num_per_page) + max_page_index = len(filenames) // num_of_imgs_per_page + 1 + page_index = max_page_index if page_index == -1 else page_index + step + page_index = 1 if page_index < 1 else page_index + page_index = max_page_index if page_index > max_page_index else page_index + idx_frm = (page_index - 1) * num_of_imgs_per_page + image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] + length = len(filenames) + visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page + visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num + return page_index, image_list, "", "", visible_num + +def loac_batch_click(date_to): + if date_to is None: + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + else: + return None, [] +def forward_click(last_date_from, date_to_recorder): + if len(date_to_recorder) == 0: + return None, [] + if last_date_from == date_to_recorder[-1]: + date_to_recorder = date_to_recorder[:-1] + if len(date_to_recorder) == 0: + return None, [] + return date_to_recorder[-1], date_to_recorder[:-1] + +def backward_click(last_date_from, date_to_recorder): + if last_date_from is None or last_date_from == "": + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: + date_to_recorder.append(last_date_from) + return last_date_from, date_to_recorder + + +def first_page_click(page_index, filenames): + return get_recent_images(1, 0, filenames) + +def end_page_click(page_index, filenames): + return get_recent_images(-1, 0, filenames) + +def prev_page_click(page_index, filenames): + return get_recent_images(page_index, -1, filenames) + +def next_page_click(page_index, filenames): + return get_recent_images(page_index, 1, filenames) + +def page_index_change(page_index, filenames): + return get_recent_images(page_index, 0, filenames) + +def show_image_info(tabname_box, num, page_index, filenames): + file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] + tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" + return file, tm, num, file + +def enable_page_buttons(): + return gradio.update(visible=True) + +def change_dir(img_dir, date_to): + warning = None + try: + if os.path.exists(img_dir): + try: + f = os.listdir(img_dir) + except: + warning = f"'{img_dir} is not a directory" + else: + warning = "The directory is not exist" + except: + warning = "The format of the directory is incorrect" + if warning is None: + today = time.strftime("%Y%m%d",time.localtime(time.time())) + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) + else: + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - if opts.outdir_samples != "": - dir_name = opts.outdir_samples - elif tabname == "txt2img": + custom_dir = False + if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples + elif tabname == faverate_tab_name: + dir_name = opts.outdir_save else: - return - with gr.Row(): - renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - with gr.Row(elem_id=tabname + "_images_history"): - with gr.Row(): - with gr.Column(scale=2): - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - with gr.Column(): - with gr.Row(): - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(label="File Name", interactive=False) - with gr.Row(): + custom_dir = True + dir_name = None + + if not custom_dir: + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + + with gr.Column() as page_panel: + with gr.Row(): + with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: + load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) + with gr.Column(scale=4): + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(visible=False, scale=1) as batch_panel: + with gr.Row(): + forward = gr.Button('Prev batch') + backward = gr.Button('Next batch') + with gr.Column(scale=3): + load_info = gr.HTML(visible=not custom_dir) + with gr.Row(visible=False) as warning: + warning_box = gr.Textbox("Message", interactive=False) + + with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: + with gr.Column(scale=2): + with gr.Row(visible=True) as turn_page_buttons: + #date_to = gr.Dropdown(label="Date to") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") + + with gr.Column(): + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) + gr.HTML("
") + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + img_file_time= gr.HTML() + with gr.Row(): + if tabname != faverate_tab_name: + save_btn = gr.Button('Collect') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + + # hiden items + with gr.Row(visible=False): + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + batch_date_to = gr.Textbox(label="Date to") + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() - img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) - tabname_box = gr.Textbox(tabname, visible=False) - image_index = gr.Textbox(value=-1, visible=False) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) - filenames = gr.State() - hidden = gr.Image(type="pil", visible=False) - info1 = gr.Textbox(visible=False) - info2 = gr.Textbox(visible=False) + img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) - # turn pages - gallery_inputs = [img_path, page_index, image_index, tabname_box] - gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name] + #change batch + change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] + + batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) + batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) - # page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) + load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) + + + #delete + delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) + delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) + if tabname != faverate_tab_name: + save_btn.click(save_image, inputs=[img_file_name], outputs=None) + + #turn page + gallery_inputs = [page_index, filenames] + gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] + first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) + page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) + renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) + + first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) + img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - - # pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + -def create_history_tabs(gr, opts, run_pnginfo, switch_dict): +def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): + global opts; + opts = sys_opts + loads_files_num = int(opts.images_history_num_per_page) + num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) + if cmp_ops.browse_all_images: + tabs_list.append(custom_tab_name) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - with gr.Tab("txt2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_txt2img: - show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) - with gr.Tab("img2img history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) - with gr.Tab("extras history"): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: - show_images_history(gr, opts, "extras", run_pnginfo, switch_dict) + for tab in tabs_list: + with gr.Tab(tab): + with gr.Blocks(analytics_enabled=False) : + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) + gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) + gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) + return images_history diff --git a/modules/img2img.py b/modules/img2img.py index 241267745..8d9f7cf98 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/interrogate.py b/modules/interrogate.py index 64b91eb46..65b05d34c 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -28,9 +28,11 @@ class InterrogateModels: clip_preprocess = None categories = None dtype = None + running_on_cpu = None def __init__(self, content_dir): self.categories = [] + self.running_on_cpu = devices.device_interrogate == torch.device("cpu") if os.path.exists(content_dir): for filename in os.listdir(content_dir): @@ -53,7 +55,11 @@ class InterrogateModels: def load_clip_model(self): import clip - model, preprocess = clip.load(clip_model_name) + if self.running_on_cpu: + model, preprocess = clip.load(clip_model_name, device="cpu") + else: + model, preprocess = clip.load(clip_model_name) + model.eval() model = model.to(devices.device_interrogate) @@ -62,14 +68,14 @@ class InterrogateModels: def load(self): if self.blip_model is None: self.blip_model = self.load_blip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() - if not shared.cmd_opts.no_half: + if not shared.cmd_opts.no_half and not self.running_on_cpu: self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.to(devices.device_interrogate) diff --git a/modules/lowvram.py b/modules/lowvram.py index 7eba1349c..f327c3dfc 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,9 +1,8 @@ import torch -from modules.devices import get_optimal_device +from modules import devices module_in_gpu = None cpu = torch.device("cpu") -device = gpu = get_optimal_device() def send_everything_to_cpu(): @@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram): if module_in_gpu is not None: module_in_gpu.to(cpu) - module.to(gpu) + module.to(devices.device) module_in_gpu = module # see below for register_forward_pre_hook; @@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram): # send the model to GPU. Then put modules back. the modules will be in CPU. stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None - sd_model.to(device) + sd_model.to(devices.device) sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored # register hooks for those the first two models @@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram): # so that only one of them is in GPU at a time stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None - sd_model.model.to(device) + sd_model.model.to(devices.device) diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored # install hooks for bits of third model diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c3..372489f7c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -104,6 +104,12 @@ class StableDiffusionProcessing(): self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + self.scripts = None + self.script_args = None + self.all_prompts = None + self.all_seeds = None + self.all_subseeds = None + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -304,7 +310,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -318,7 +324,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" @@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - all_prompts = p.prompt + p.all_prompts = p.prompt else: - all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_prompts = p.batch_size * p.n_iter * [p.prompt] if type(seed) == list: - all_seeds = seed + p.all_seeds = seed else: - all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] if type(subseed) == list: - all_subseeds = subseed + p.all_subseeds = subseed else: - all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + if p.scripts is not None: + p.scripts.run_alwayson_scripts(p) + infotexts = [] output_images = [] with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): - p.init(all_prompts, all_seeds, all_subseeds) + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) if state.job_count == -1: state.job_count = p.n_iter @@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break - prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] - seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] - subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] if (len(prompts) == 0): break @@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): @@ -540,17 +549,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + def create_dummy_mask(self, x, width=None, height=None): + if self.sampler.conditioning_key in {'hybrid', 'concat'}: + height = height or self.height + width = width or self.width + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + else: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) + + return image_conditioning def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x)) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height)) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] @@ -587,7 +616,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples)) return samples @@ -613,6 +642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.inpainting_mask_invert = inpainting_mask_invert self.mask = None self.nmask = None + self.image_conditioning = None def init(self, all_prompts, all_seeds, all_subseeds): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) @@ -714,10 +744,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + if self.sampler.conditioning_key in {'hybrid', 'concat'}: + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + else: + self.image_conditioning = torch.zeros( + self.init_latent.shape[0], 5, 1, 1, + dtype=self.init_latent.dtype, + device=self.init_latent.device + ) + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py new file mode 100644 index 000000000..866b7acdf --- /dev/null +++ b/modules/script_callbacks.py @@ -0,0 +1,42 @@ + +callbacks_model_loaded = [] +callbacks_ui_tabs = [] + + +def clear_callbacks(): + callbacks_model_loaded.clear() + callbacks_ui_tabs.clear() + + +def model_loaded_callback(sd_model): + for callback in callbacks_model_loaded: + callback(sd_model) + + +def ui_tabs_callback(): + res = [] + + for callback in callbacks_ui_tabs: + res += callback() or [] + + return res + + +def on_model_loaded(callback): + """register a function to be called when the stable diffusion model is created; the model is + passed as an argument""" + callbacks_model_loaded.append(callback) + + +def on_ui_tabs(callback): + """register a function to be called when the UI is creating new tabs. + The function must either return a None, which means no new tabs to be added, or a list, where + each element is a tuple: + (gradio_component, title, elem_id) + + gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks) + title is tab text displayed to user in the UI + elem_id is HTML id for the tab + """ + callbacks_ui_tabs.append(callback) + diff --git a/modules/scripts.py b/modules/scripts.py index 1039fa9cd..9323af3e3 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,86 +1,175 @@ import os import sys import traceback +from collections import namedtuple import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared +from modules import shared, paths, script_callbacks + +AlwaysVisible = object() + class Script: filename = None args_from = None args_to = None + alwayson = False + + infotext_fields = None + """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when + parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example + """ - # The title of the script. This is what will be displayed in the dropdown menu. def title(self): + """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" + raise NotImplementedError() - # How the script is displayed in the UI. See https://gradio.app/docs/#components - # for the different UI components you can use and how to create them. - # Most UI components can return a value, such as a boolean for a checkbox. - # The returned values are passed to the run method as parameters. def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned componenbts will be passed to run() and process() functions. + """ + pass - # Determines when the script should be shown in the dropdown menu via the - # returned value. As an example: - # is_img2img is True if the current tab is img2img, and False if it is txt2img. - # Thus, return is_img2img to only show the script on the img2img tab. def show(self, is_img2img): + """ + is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + + This function should return: + - False if the script should not be shown in UI at all + - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - script.AlwaysVisible if the script should be shown in UI at all times + """ + return True - # This is where the additional processing is implemented. The parameters include - # self, the model object "p" (a StableDiffusionProcessing class, see - # processing.py), and the parameters returned by the ui method. - # Custom functions can be defined here, and additional libraries can be imported - # to be used in processing. The return value should be a Processed object, which is - # what is returned by the process_images method. - def run(self, *args): + def run(self, p, *args): + """ + This function is called if the script has been selected in the script dropdown. + It must do all processing and return the Processed object with results, same as + one returned by processing.process_images. + + Usually the processing is done by calling the processing.process_images function. + + args contains all values returned by components from ui() + """ + raise NotImplementedError() - # The description method is currently unused. - # To add a description that appears when hovering over the title, amend the "titles" - # dict in script.js to include the script title (returned by title) as a key, and - # your description as the value. + def process(self, p, *args): + """ + This function is called before processing begins for AlwaysVisible scripts. + scripts. You can modify the processing object (p) here, inject hooks, etc. + """ + + pass + def describe(self): + """unused""" return "" +current_basedir = paths.script_path + + +def basedir(): + """returns the base directory for the current script. For scripts in the main scripts directory, + this is the main directory (where webui.py resides), and for scripts in extensions directory + (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic) + """ + return current_basedir + + scripts_data = [] +ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) -def load_scripts(basedir): - if not os.path.exists(basedir): - return +def list_scripts(scriptdirname, extension): + scripts_list = [] - for filename in sorted(os.listdir(basedir)): - path = os.path.join(basedir, filename) + basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(basedir): + for filename in sorted(os.listdir(basedir)): + scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - if os.path.splitext(path)[1].lower() != '.py': + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + for dirname in sorted(os.listdir(extdir)): + dirpath = os.path.join(extdir, dirname) + scriptdirpath = os.path.join(dirpath, scriptdirname) + + if not os.path.isdir(scriptdirpath): + continue + + for filename in sorted(os.listdir(scriptdirpath)): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) + + scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + + return scripts_list + + +def list_files_with_name(filename): + res = [] + + dirs = [paths.script_path] + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + + for dirpath in dirs: + if not os.path.isdir(dirpath): continue - if not os.path.isfile(path): - continue + path = os.path.join(dirpath, filename) + if os.path.isfile(filename): + res.append(path) + return res + + +def load_scripts(): + global current_basedir + scripts_data.clear() + script_callbacks.clear_callbacks() + + scripts_list = list_scripts("scripts", ".py") + + syspath = sys.path + + for scriptfile in sorted(scripts_list): try: - with open(path, "r", encoding="utf8") as file: + if scriptfile.basedir != paths.script_path: + sys.path = [scriptfile.basedir] + sys.path + current_basedir = scriptfile.basedir + + with open(scriptfile.path, "r", encoding="utf8") as file: text = file.read() from types import ModuleType - compiled = compile(text, path, 'exec') - module = ModuleType(filename) + compiled = compile(text, scriptfile.path, 'exec') + module = ModuleType(scriptfile.filename) exec(compiled, module.__dict__) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append((script_class, path)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) except Exception: - print(f"Error loading script: {filename}", file=sys.stderr) + print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + finally: + sys.path = syspath + current_basedir = paths.script_path + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: @@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.selectable_scripts = [] + self.alwayson_scripts = [] self.titles = [] + self.infotext_fields = [] def setup_ui(self, is_img2img): - for script_class, path in scripts_data: + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path - if not script.show(is_img2img): - continue + visibility = script.show(is_img2img) - self.scripts.append(script) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True - inputs = [dropdown] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] - for script in self.scripts: + inputs = [None] + inputs_alwayson = [True] + + def create_script_ui(script, inputs, inputs_alwayson): script.args_from = len(inputs) script.args_to = len(inputs) controls = wrap_call(script.ui, script.filename, "ui", is_img2img) if controls is None: - continue + return for control in controls: control.custom_script_source = os.path.basename(script.filename) - control.visible = False + if not script.alwayson: + control.visible = False + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields inputs += controls + inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) + for script in self.alwayson_scripts: + with gr.Group(): + create_script_ui(script, inputs, inputs_alwayson) + + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True + inputs[0] = dropdown + + for script in self.selectable_scripts: + create_script_ui(script, inputs, inputs_alwayson) + def select_script(script_index): - if 0 < script_index <= len(self.scripts): - script = self.scripts[script_index-1] + if 0 < script_index <= len(self.selectable_scripts): + script = self.selectable_scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] def init_field(title): if title == 'None': return script_index = self.titles.index(title) - script = self.scripts[script_index] + script = self.selectable_scripts[script_index] for i in range(script.args_from, script.args_to): inputs[i].visible = True @@ -164,7 +277,7 @@ class ScriptRunner: if script_index == 0: return None - script = self.scripts[script_index-1] + script = self.selectable_scripts[script_index-1] if script is None: return None @@ -176,7 +289,16 @@ class ScriptRunner: return processed - def reload_sources(self): + def run_alwayson_scripts(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process(p, *script_args) + except Exception: + print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: args_from = script.args_from @@ -186,9 +308,12 @@ class ScriptRunner: from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + module = cache.get(filename, None) + if module is None: + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + cache[filename] = module for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -197,19 +322,22 @@ class ScriptRunner: self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to + scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + cache = {} + scripts_txt2img.reload_sources(cache) + scripts_img2img.reload_sources(cache) -def reload_scripts(basedir): +def reload_scripts(): global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 984b35c40..0f10828ed 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + def apply_optimizations(): undo_optimizations() @@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens = remade_tokens[:last_comma] length = len(remade_tokens) - + rem = int(math.ceil(length / 75)) * 75 - length remade_tokens += [id_end] * rem + reloc_tokens multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) @@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): use_old = opts.use_old_emphasis_implementation if use_old: @@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - + if use_old: self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) - + z = None i = 0 while max(map(len, remade_batch_tokens)) != 0: @@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if fix[0] == i: fixes.append(fix[1]) self.hijack.fixes.append(fixes) - + tokens = [] multipliers = [] for j in range(len(remade_batch_tokens)): @@ -333,19 +333,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - + remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 - + return z - - + def process_tokens(self, remade_batch_tokens, batch_multipliers): if not opts.use_old_emphasis_implementation: remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - + tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) @@ -385,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module): for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: emb = embedding.vec - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) vecs.append(tensor) diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py new file mode 100644 index 000000000..fd92a3354 --- /dev/null +++ b/modules/sd_hijack_inpainting.py @@ -0,0 +1,331 @@ +import torch + +from einops import repeat +from omegaconf import ListConfig + +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddim import DDIMSampler, noise_like + +# ================================================================================================= +# Monkey patch DDIMSampler methods from RunwayML repo directly. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py +# ================================================================================================= +@torch.no_grad() +def sample_ddim(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + +@torch.no_grad() +def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + +# ================================================================================================= +# Monkey patch PLMSSampler methods. +# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes. +# Adapted from: +# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py +# ================================================================================================= +@torch.no_grad() +def sample_plms(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + +# ================================================================================================= +# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py +# ================================================================================================= + +@torch.no_grad() +def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + + +def should_hijack_inpainting(checkpoint_info): + return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + + +def do_inpainting_hijack(): + ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + + ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index eae22e872..f9b3063d3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,8 +7,9 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) @@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) + + if should_hijack_inpainting(checkpoint_info): + # Hardcoded config for now... + sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + sd_config.model.params.use_ema = False + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.unet_config.params.in_channels = 9 + + # Create a "fake" config with a different name so that we know to unload it when switching models. + checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + + do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -222,6 +238,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -234,9 +253,9 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b3..f58a29b9f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def number_of_needed_noises(self, p): return 0 @@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException + # Have to unwrap the inpainting conditioning here to perform pre-processing + image_conditioning = None + if isinstance(cond, dict): + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) @@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + # Wrap the image conditioning back up since the DDIM code can accept the dict directly. + # Note that they need to be lists because it just concatenates them later. + if image_conditioning is not None: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) if self.mask is not None: @@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) self.initialize(p) @@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler: x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) self.init_latent = x + self.last_latent = x self.step = 0 + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + + samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): self.initialize(p) self.init_latent = None + self.last_latent = x self.step = 0 steps = steps or p.steps + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + # existing code fails with certain step counts, like 9 try: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module): self.init_latent = None self.step = 0 - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise InterruptedException @@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) @@ -306,6 +334,8 @@ class KDiffusionSampler: self.config = None self.last_latent = None + self.conditioning_key = sd_model.model.conditioning_key + def callback_state(self, d): step = d['i'] latent = d["denoised"] @@ -361,7 +391,7 @@ class KDiffusionSampler: return extra_params_kwargs - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) if p.sampler_noise_scheduler_override: @@ -388,12 +418,18 @@ class KDiffusionSampler: extra_params_kwargs['sigmas'] = sigma_sched self.model_wrap_cfg.init_latent = x + self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): steps = steps or p.steps if p.sampler_noise_scheduler_override: @@ -414,7 +450,13 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples diff --git a/modules/shared.py b/modules/shared.py index faede8214..5d83971ea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,6 +3,7 @@ import datetime import json import os import sys +from collections import OrderedDict import gradio as gr import tqdm @@ -78,6 +79,8 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) +parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) cmd_opts = parser.parse_args() restricted_opts = [ @@ -249,7 +252,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), @@ -315,6 +318,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section(('images-history', "Images Browser"), { + #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), + "images_history_preload": OptionInfo(False, "Preload images at startup"), + "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), + "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), + "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), + +})) class Options: data = None @@ -387,6 +398,8 @@ sd_upscalers = [] sd_model = None +clip_model = None + progress_print_out = sys.stdout diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 23bb4b6a6..5b1c5002f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -83,7 +83,7 @@ class PersonalizedBase(Dataset): self.dataset.append(entry) - assert len(self.dataset) > 1, "No images have been found in the dataset." + assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset) * repeats // batch_size self.initial_indexes = np.arange(len(self.dataset)) @@ -91,7 +91,7 @@ class PersonalizedBase(Dataset): self.shuffle() def shuffle(self): - self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()] def create_text(self, filename_text): text = random.choice(self.lines) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index 898ce3b3b..ea653806f 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -5,6 +5,7 @@ import zlib from PIL import Image, PngImagePlugin, ImageDraw, ImageFont from fonts.ttf import Roboto import torch +from modules.shared import opts class EmbeddingEncoder(json.JSONEncoder): @@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t from math import cos image = srcimage.copy() - + fontsize = 32 if textfont is None: try: textfont = ImageFont.truetype(opts.font or Roboto, fontsize) @@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size)) draw = ImageDraw.Draw(image) - fontsize = 32 + font = ImageFont.truetype(textfont, fontsize) padding = 10 diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 886cf0c3b..33eaddb63 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -1,5 +1,6 @@ import os from PIL import Image, ImageOps +import math import platform import sys import tqdm @@ -11,7 +12,7 @@ if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru -def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): try: if process_caption: shared.interrogator.load() @@ -21,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ db_opts[deepbooru.OPT_INCLUDE_RANKS] = False deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts) - preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio) finally: @@ -33,11 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ -def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2): width = process_width height = process_height src = os.path.abspath(process_src) dst = os.path.abspath(process_dst) + split_threshold = max(0.0, min(1.0, split_threshold)) + overlap_ratio = max(0.0, min(0.9, overlap_ratio)) assert src != dst, 'same directory specified as source and destination' @@ -48,7 +51,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - def save_pic_with_caption(image, index): + def save_pic_with_caption(image, index, existing_caption=None): caption = "" if process_caption: @@ -66,17 +69,49 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro basename = f"{index:05}-{subindex[0]}-{filename_part}" image.save(os.path.join(dst, f"{basename}.png")) + if preprocess_txt_action == 'prepend' and existing_caption: + caption = existing_caption + ' ' + caption + elif preprocess_txt_action == 'append' and existing_caption: + caption = caption + ' ' + existing_caption + elif preprocess_txt_action == 'copy' and existing_caption: + caption = existing_caption + + caption = caption.strip() + if len(caption) > 0: with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) subindex[0] += 1 - def save_pic(image, index): - save_pic_with_caption(image, index) + def save_pic(image, index, existing_caption=None): + save_pic_with_caption(image, index, existing_caption=existing_caption) if process_flip: - save_pic_with_caption(ImageOps.mirror(image), index) + save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption) + + def split_pic(image, inverse_xy): + if inverse_xy: + from_w, from_h = image.height, image.width + to_w, to_h = height, width + else: + from_w, from_h = image.width, image.height + to_w, to_h = width, height + h = from_h * to_w // from_w + if inverse_xy: + image = image.resize((h, to_w)) + else: + image = image.resize((to_w, h)) + + split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) + y_step = (h - to_h) / (split_count - 1) + for i in range(split_count): + y = int(y_step * i) + if inverse_xy: + splitted = image.crop((y, 0, y + to_h, to_w)) + else: + splitted = image.crop((0, y, to_w, y + to_h)) + yield splitted for index, imagefile in enumerate(tqdm.tqdm(files)): subindex = [0] @@ -86,31 +121,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro except Exception: continue + existing_caption = None + existing_caption_filename = os.path.splitext(filename)[0] + '.txt' + if os.path.exists(existing_caption_filename): + with open(existing_caption_filename, 'r', encoding="utf8") as file: + existing_caption = file.read() + if shared.state.interrupted: break - ratio = img.height / img.width - is_tall = ratio > 1.35 - is_wide = ratio < 1 / 1.35 + if img.height > img.width: + ratio = (img.width * height) / (img.height * width) + inverse_xy = False + else: + ratio = (img.height * width) / (img.width * height) + inverse_xy = True - if process_split and is_tall: - img = img.resize((width, height * img.height // img.width)) - - top = img.crop((0, 0, width, height)) - save_pic(top, index) - - bot = img.crop((0, img.height - height, width, img.height)) - save_pic(bot, index) - elif process_split and is_wide: - img = img.resize((width * img.width // img.height, height)) - - left = img.crop((0, 0, width, height)) - save_pic(left, index) - - right = img.crop((img.width - width, 0, img.width, height)) - save_pic(right, index) + if process_split and ratio < 1.0 and ratio <= split_threshold: + for splitted in split_pic(img, inverse_xy): + save_pic(splitted, index, existing_caption=existing_caption) else: img = images.resize_image(1, img, width, height) - save_pic(img, index) + save_pic(img, index, existing_caption=existing_caption) shared.state.nextjob() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562b..529ed3e22 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -153,7 +153,7 @@ class EmbeddingDatabase: return None, None -def create_embedding(name, num_vectors_per_token, init_text='*'): +def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): cond_model = shared.sd_model.cond_stage_model embedding_layer = cond_model.wrapped.transformer.text_model.embeddings @@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" + if not overwrite_old: + assert not os.path.exists(fn), f"file {fn} already exists" embedding = Embedding(vec, name) embedding.step = 0 @@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc loss.backward() optimizer.step() + epoch_num = embedding.step // len(ds) epoch_step = embedding.step - (epoch_num * len(ds)) + 1 diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index 36881e7ad..e712284d1 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess from modules import sd_hijack, shared -def create_embedding(name, initialization_text, nvpt): - filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text) +def create_embedding(name, initialization_text, nvpt, overwrite_old): + filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() diff --git a/modules/txt2img.py b/modules/txt2img.py index 2381347f9..c9d5a0906 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,5 +1,6 @@ import modules.scripts -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ + StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, cmd_opts import modules.shared as shared import modules.processing as processing @@ -35,6 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) @@ -53,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed.images = [] return processed.images, generation_info_js, plaintext_to_html(processed.info) - diff --git a/modules/ui.py b/modules/ui.py index d4b32c058..cd1185527 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -22,9 +22,14 @@ import piexif import torch from PIL import Image, PngImagePlugin -from modules import localization, sd_hijack, sd_models +import gradio as gr +import gradio.utils +import gradio.routes + +from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -from modules.shared import cmd_opts, opts, restricted_opts + +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags @@ -43,6 +48,11 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img +import modules.textual_inversion.ui +import modules.hypernetworks.ui + +import modules.images_history as img_his + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -593,27 +603,29 @@ def apply_setting(key, value): return value +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn = refresh, - inputs = [], - outputs = [refresh_component] - ) - return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -711,6 +723,7 @@ def create_ui(wrap_gradio_gpu_call): firstphase_width, firstphase_height, ] + custom_inputs, + outputs=[ txt2img_gallery, generation_info, @@ -787,6 +800,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), + *modules.scripts.scripts_txt2img.infotext_fields ] txt2img_preview_params = [ @@ -854,8 +868,8 @@ def create_ui(wrap_gradio_gpu_call): sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) @@ -1052,6 +1066,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), + *modules.scripts.scripts_img2img.infotext_fields ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1174,12 +1189,12 @@ def create_ui(wrap_gradio_gpu_call): ) #images history images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields + "fn": modules.generation_parameters_copypaste.connect_paste, + "t2i": txt2img_paste_fields, + "i2i": img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1212,6 +1227,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name = gr.Textbox(label="Name") initialization_text = gr.Textbox(label="Initialization text", value="*") nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") with gr.Row(): with gr.Column(scale=3): @@ -1227,6 +1243,8 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) with gr.Row(): with gr.Column(scale=3): @@ -1240,13 +1258,18 @@ def create_ui(wrap_gradio_gpu_call): process_dst = gr.Textbox(label='Destination directory') process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images into two') + process_split = gr.Checkbox(label='Split oversized images') process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1254,15 +1277,24 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): run_preprocess = gr.Button(value="Preprocess", variant='primary') + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + with gr.Tab(label="Train"): - gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") + gr.HTML(value="

Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]

") with gr.Row(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + with gr.Row(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") + batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -1296,6 +1328,7 @@ def create_ui(wrap_gradio_gpu_call): new_embedding_name, initialization_text, nvpt, + overwrite_old_embedding, ], outputs=[ train_embedding_name, @@ -1309,6 +1342,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + overwrite_old_hypernetwork, new_hypernetwork_layer_structure, new_hypernetwork_activation_func, new_hypernetwork_add_layer_norm, @@ -1329,10 +1363,13 @@ def create_ui(wrap_gradio_gpu_call): process_dst, process_width, process_height, + preprocess_txt_action, process_flip, process_split, process_caption, - process_caption_deepbooru + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, ], outputs=[ ti_output, @@ -1345,7 +1382,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_embedding_name, - learn_rate, + embedding_learn_rate, batch_size, dataset_directory, log_directory, @@ -1370,7 +1407,7 @@ def create_ui(wrap_gradio_gpu_call): _js="start_training_textual_inversion", inputs=[ train_hypernetwork_name, - learn_rate, + hypernetwork_learn_rate, batch_size, dataset_directory, log_directory, @@ -1491,10 +1528,10 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + oldval = opts.data.get(key, None) if cmd_opts.hide_ui_dir_config and key in restricted_opts: return gr.update(value=oldval), opts.dumpjson() - oldval = opts.data.get(key, None) opts.data[key] = value if oldval != value: @@ -1600,19 +1637,24 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), - (settings_interface, "Settings", "settings"), ] - with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: - css = file.read() + interfaces += script_callbacks.ui_tabs_callback() + + interfaces += [(settings_interface, "Settings", "settings")] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + with open(cssfile, "r", encoding="utf8") as file: + css += file.read() + "\n" if os.path.exists(os.path.join(script_path, "user.css")): with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - usercss = file.read() - css += usercss + css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: css += css_hide_progressbar @@ -1835,9 +1877,9 @@ def load_javascript(raw_response): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' - jsdir = os.path.join(script_path, "javascript") - for filename in sorted(os.listdir(jsdir)): - with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: + scripts_list = modules.scripts.list_scripts("javascript", ".js") + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" if cmd_opts.theme is not None: @@ -1855,6 +1897,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, - gradio.routes.templates.TemplateResponse) +reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript() diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index a6468e09a..2afd4aa51 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -172,54 +172,54 @@ class Script(scripts.Script): if down > 0: down = target_h - init_img.height - up - init_image = p.init_images[0] - - state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0) - - def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): + def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False): is_horiz = is_left or is_right is_vert = is_top or is_bottom pixels_horiz = expand_pixels if is_horiz else 0 pixels_vert = expand_pixels if is_vert else 0 - res_w = init.width + pixels_horiz - res_h = init.height + pixels_vert - process_res_w = math.ceil(res_w / 64) * 64 - process_res_h = math.ceil(res_h / 64) * 64 + images_to_process = [] + output_images = [] + for n in range(count): + res_w = init[n].width + pixels_horiz + res_h = init[n].height + pixels_vert + process_res_w = math.ceil(res_w / 64) * 64 + process_res_h = math.ceil(res_h / 64) * 64 - img = Image.new("RGB", (process_res_w, process_res_h)) - img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) - mask = Image.new("RGB", (process_res_w, process_res_h), "white") - draw = ImageDraw.Draw(mask) - draw.rectangle(( - expand_pixels + mask_blur if is_left else 0, - expand_pixels + mask_blur if is_top else 0, - mask.width - expand_pixels - mask_blur if is_right else res_w, - mask.height - expand_pixels - mask_blur if is_bottom else res_h, - ), fill="black") + img = Image.new("RGB", (process_res_w, process_res_h)) + img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0)) + mask = Image.new("RGB", (process_res_w, process_res_h), "white") + draw = ImageDraw.Draw(mask) + draw.rectangle(( + expand_pixels + mask_blur if is_left else 0, + expand_pixels + mask_blur if is_top else 0, + mask.width - expand_pixels - mask_blur if is_right else res_w, + mask.height - expand_pixels - mask_blur if is_bottom else res_h, + ), fill="black") - np_image = (np.asarray(img) / 255.0).astype(np.float64) - np_mask = (np.asarray(mask) / 255.0).astype(np.float64) - noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) - out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") + np_image = (np.asarray(img) / 255.0).astype(np.float64) + np_mask = (np.asarray(mask) / 255.0).astype(np.float64) + noised = get_matched_noise(np_image, np_mask, noise_q, color_variation) + output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")) - target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width - target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height + target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width + target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height + p.width = target_width if is_horiz else img.width + p.height = target_height if is_vert else img.height - crop_region = ( - 0 if is_left else out.width - target_width, - 0 if is_top else out.height - target_height, - target_width if is_left else out.width, - target_height if is_top else out.height, - ) + crop_region = ( + 0 if is_left else output_images[n].width - target_width, + 0 if is_top else output_images[n].height - target_height, + target_width if is_left else output_images[n].width, + target_height if is_top else output_images[n].height, + ) + mask = mask.crop(crop_region) + p.image_mask = mask - image_to_process = out.crop(crop_region) - mask = mask.crop(crop_region) + image_to_process = output_images[n].crop(crop_region) + images_to_process.append(image_to_process) - p.width = target_width if is_horiz else img.width - p.height = target_height if is_vert else img.height - p.init_images = [image_to_process] - p.image_mask = mask + p.init_images = images_to_process latent_mask = Image.new("RGB", (p.width, p.height), "white") draw = ImageDraw.Draw(latent_mask) @@ -232,31 +232,52 @@ class Script(scripts.Script): p.latent_mask = latent_mask proc = process_images(p) - proc_img = proc.images[0] if initial_seed_and_info[0] is None: initial_seed_and_info[0] = proc.seed initial_seed_and_info[1] = proc.info - out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height)) - out = out.crop((0, 0, res_w, res_h)) - return out + for n in range(count): + output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height)) + output_images[n] = output_images[n].crop((0, 0, res_w, res_h)) - img = init_image + return output_images - if left > 0: - img = expand(img, left, is_left=True) - if right > 0: - img = expand(img, right, is_right=True) - if up > 0: - img = expand(img, up, is_top=True) - if down > 0: - img = expand(img, down, is_bottom=True) + batch_count = p.n_iter + batch_size = p.batch_size + p.n_iter = 1 + state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)) + all_processed_images = [] - res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1]) + for i in range(batch_count): + imgs = [init_img] * batch_size + state.job = f"Batch {i + 1} out of {batch_count}" + + if left > 0: + imgs = expand(imgs, batch_size, left, is_left=True) + if right > 0: + imgs = expand(imgs, batch_size, right, is_right=True) + if up > 0: + imgs = expand(imgs, batch_size, up, is_top=True) + if down > 0: + imgs = expand(imgs, batch_size, down, is_bottom=True) + + all_processed_images += imgs + + all_images = all_processed_images + + combined_grid_image = images.image_grid(all_processed_images) + unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple + if opts.return_grid and not unwanted_grid_because_of_img_count: + all_images = [combined_grid_image] + all_processed_images + + res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1]) if opts.samples_save: - images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p) + for img in all_processed_images: + images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p) + + if opts.grid_save and not unwanted_grid_because_of_img_count: + images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p) return res - diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a1..eff0c942a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) + p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): diff --git a/webui.py b/webui.py index 177bef744..b1deca1b0 100644 --- a/webui.py +++ b/webui.py @@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -79,9 +80,9 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + modules.scripts.load_scripts() - shared.sd_model = modules.sd_models.load_model() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) @@ -118,7 +119,8 @@ def api_only(): api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -def webui(launch_api=False): +def webui(): + launch_api = cmd_opts.api initialize() while 1: @@ -144,7 +146,7 @@ def webui(launch_api=False): sd_samplers.set_samplers() print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) print('Refreshing Model List') @@ -158,4 +160,4 @@ if __name__ == "__main__": if cmd_opts.nowebui: api_only() else: - webui(cmd_opts.api) + webui()