mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-21 21:55:01 +08:00
Merge branch 'master' of upstream
This commit is contained in:
commit
dcb45dfecf
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -45,6 +45,8 @@ body:
|
|||||||
attributes:
|
attributes:
|
||||||
label: Commit where the problem happens
|
label: Commit where the problem happens
|
||||||
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
id: platforms
|
id: platforms
|
||||||
attributes:
|
attributes:
|
||||||
|
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
blank_issues_enabled: false
|
||||||
|
contact_links:
|
||||||
|
- name: WebUI Community Support
|
||||||
|
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
|
||||||
|
about: Please ask and answer questions here.
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -27,4 +27,5 @@ __pycache__
|
|||||||
notification.mp3
|
notification.mp3
|
||||||
/SwinIR
|
/SwinIR
|
||||||
/textual_inversion
|
/textual_inversion
|
||||||
.vscode
|
.vscode
|
||||||
|
/extensions
|
||||||
|
39
README.md
39
README.md
@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- One click install and run script (but you still must install python and git)
|
- One click install and run script (but you still must install python and git)
|
||||||
- Outpainting
|
- Outpainting
|
||||||
- Inpainting
|
- Inpainting
|
||||||
|
- Color Sketch
|
||||||
- Prompt Matrix
|
- Prompt Matrix
|
||||||
- Stable Diffusion Upscale
|
- Stable Diffusion Upscale
|
||||||
- Attention, specify parts of text that the model should pay more attention to
|
- Attention, specify parts of text that the model should pay more attention to
|
||||||
@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- have as many embeddings as you want and use any names you like for them
|
- have as many embeddings as you want and use any names you like for them
|
||||||
- use multiple embeddings with different numbers of vectors per token
|
- use multiple embeddings with different numbers of vectors per token
|
||||||
- works with half precision floating point numbers
|
- works with half precision floating point numbers
|
||||||
|
- train embeddings on 8GB (also reports of 6GB working)
|
||||||
- Extras tab with:
|
- Extras tab with:
|
||||||
- GFPGAN, neural network that fixes faces
|
- GFPGAN, neural network that fixes faces
|
||||||
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
||||||
@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- Interrupt processing at any time
|
- Interrupt processing at any time
|
||||||
- 4GB video card support (also reports of 2GB working)
|
- 4GB video card support (also reports of 2GB working)
|
||||||
- Correct seeds for batches
|
- Correct seeds for batches
|
||||||
- Prompt length validation
|
- Live prompt token length validation
|
||||||
- get length of prompt in tokens as you type
|
|
||||||
- get a warning after generation if some text was truncated
|
|
||||||
- Generation parameters
|
- Generation parameters
|
||||||
- parameters you used to generate images are saved with that image
|
- parameters you used to generate images are saved with that image
|
||||||
- in PNG chunks for PNG, in EXIF for JPEG
|
- in PNG chunks for PNG, in EXIF for JPEG
|
||||||
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||||
- can be disabled in settings
|
- can be disabled in settings
|
||||||
|
- drag and drop an image/text-parameters to promptbox
|
||||||
|
- Read Generation Parameters Button, loads parameters in promptbox to UI
|
||||||
- Settings page
|
- Settings page
|
||||||
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||||
- Mouseover hints for most UI elements
|
- Mouseover hints for most UI elements
|
||||||
@ -59,10 +61,10 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- CLIP interrogator, a button that tries to guess prompt from an image
|
- CLIP interrogator, a button that tries to guess prompt from an image
|
||||||
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
||||||
- Batch Processing, process a group of files using img2img
|
- Batch Processing, process a group of files using img2img
|
||||||
- Img2img Alternative
|
- Img2img Alternative, reverse Euler method of cross attention control
|
||||||
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
||||||
- Reloading checkpoints on the fly
|
- Reloading checkpoints on the fly
|
||||||
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
|
- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
|
||||||
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
||||||
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
|
||||||
- separate prompts using uppercase `AND`
|
- separate prompts using uppercase `AND`
|
||||||
@ -70,14 +72,35 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
|
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
|
||||||
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
|
||||||
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
|
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
|
||||||
|
- History tab: view, direct and delete images conveniently within the UI
|
||||||
|
- Generate forever option
|
||||||
|
- Training tab
|
||||||
|
- hypernetworks and embeddings options
|
||||||
|
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
|
||||||
|
- Clip skip
|
||||||
|
- Use Hypernetworks
|
||||||
|
- Use VAEs
|
||||||
|
- Estimated completion time in progress bar
|
||||||
|
- API
|
||||||
|
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||||
|
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||||
|
|
||||||
|
## Where are Aesthetic Gradients?!?!
|
||||||
|
Aesthetic Gradients are now an extension. You can install it using git:
|
||||||
|
|
||||||
|
```commandline
|
||||||
|
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
|
||||||
|
```
|
||||||
|
|
||||||
|
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
|
||||||
|
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
|
|
||||||
Alternatively, use Google Colab:
|
Alternatively, use online services (like Google Colab):
|
||||||
|
|
||||||
- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
|
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||||
- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
|
|
||||||
|
|
||||||
### Automatic Installation on Windows
|
### Automatic Installation on Windows
|
||||||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
|
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
|
||||||
|
0
extensions/put extension here.txt
Normal file
0
extensions/put extension here.txt
Normal file
@ -3,12 +3,12 @@ let currentWidth = null;
|
|||||||
let currentHeight = null;
|
let currentHeight = null;
|
||||||
let arFrameTimeout = setTimeout(function(){},0);
|
let arFrameTimeout = setTimeout(function(){},0);
|
||||||
|
|
||||||
function dimensionChange(e,dimname){
|
function dimensionChange(e, is_width, is_height){
|
||||||
|
|
||||||
if(dimname == 'Width'){
|
if(is_width){
|
||||||
currentWidth = e.target.value*1.0
|
currentWidth = e.target.value*1.0
|
||||||
}
|
}
|
||||||
if(dimname == 'Height'){
|
if(is_height){
|
||||||
currentHeight = e.target.value*1.0
|
currentHeight = e.target.value*1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
|
|
||||||
if(img2imgMode){
|
|
||||||
img2imgMode=img2imgMode.innerText
|
|
||||||
}else{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
|
|
||||||
var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
|
|
||||||
|
|
||||||
var targetElement = null;
|
var targetElement = null;
|
||||||
|
|
||||||
if(img2imgMode=='img2img' && redrawImage){
|
var tabIndex = get_tab_index('mode_img2img')
|
||||||
targetElement = redrawImage;
|
if(tabIndex == 0){
|
||||||
}else if(img2imgMode=='Inpaint' && inpaintImage){
|
targetElement = gradioApp().querySelector('div[data-testid=image] img');
|
||||||
targetElement = inpaintImage;
|
} else if(tabIndex == 1){
|
||||||
|
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
|
||||||
}
|
}
|
||||||
|
|
||||||
if(targetElement){
|
if(targetElement){
|
||||||
@ -98,22 +89,20 @@ onUiUpdate(function(){
|
|||||||
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
|
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
|
||||||
if(inImg2img){
|
if(inImg2img){
|
||||||
let inputs = gradioApp().querySelectorAll('input');
|
let inputs = gradioApp().querySelectorAll('input');
|
||||||
inputs.forEach(function(e){
|
inputs.forEach(function(e){
|
||||||
let parentLabel = e.parentElement.querySelector('label')
|
var is_width = e.parentElement.id == "img2img_width"
|
||||||
if(parentLabel && parentLabel.innerText){
|
var is_height = e.parentElement.id == "img2img_height"
|
||||||
if(!e.classList.contains('scrollwatch')){
|
|
||||||
if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
|
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
|
||||||
e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
|
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
|
||||||
e.classList.add('scrollwatch')
|
e.classList.add('scrollwatch')
|
||||||
}
|
}
|
||||||
if(parentLabel.innerText == 'Width'){
|
if(is_width){
|
||||||
currentWidth = e.value*1.0
|
currentWidth = e.value*1.0
|
||||||
}
|
}
|
||||||
if(parentLabel.innerText == 'Height'){
|
if(is_height){
|
||||||
currentHeight = e.value*1.0
|
currentHeight = e.value*1.0
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
2
javascript/dragdrop.js
vendored
2
javascript/dragdrop.js
vendored
@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
|
|||||||
window.document.addEventListener('dragover', e => {
|
window.document.addEventListener('dragover', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const imgWrap = target.closest('[data-testid="image"]');
|
const imgWrap = target.closest('[data-testid="image"]');
|
||||||
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
|
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.stopPropagation();
|
e.stopPropagation();
|
||||||
|
@ -17,14 +17,6 @@ var images_history_click_image = function(){
|
|||||||
images_history_set_image_info(this);
|
images_history_set_image_info(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
var images_history_click_tab = function(){
|
|
||||||
var tabs_box = gradioApp().getElementById("images_history_tab");
|
|
||||||
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
|
||||||
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
|
|
||||||
tabs_box.classList.add(this.getAttribute("tabname"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function images_history_disabled_del(){
|
function images_history_disabled_del(){
|
||||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||||
btn.setAttribute('disabled','disabled');
|
btn.setAttribute('disabled','disabled');
|
||||||
@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){
|
|||||||
var parent = item.parentElement;
|
var parent = item.parentElement;
|
||||||
tagname = tagname.toUpperCase()
|
tagname = tagname.toUpperCase()
|
||||||
while(parent.tagName != tagname){
|
while(parent.tagName != tagname){
|
||||||
console.log(parent.tagName, tagname)
|
|
||||||
parent = parent.parentElement;
|
parent = parent.parentElement;
|
||||||
}
|
}
|
||||||
return parent;
|
return parent;
|
||||||
@ -88,15 +79,15 @@ function images_history_set_image_info(button){
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function images_history_get_current_img(tabname, image_path, files){
|
function images_history_get_current_img(tabname, img_index, files){
|
||||||
return [
|
return [
|
||||||
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
tabname,
|
||||||
image_path,
|
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
||||||
files
|
files
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
|
function images_history_delete(del_num, tabname, image_index){
|
||||||
image_index = parseInt(image_index);
|
image_index = parseInt(image_index);
|
||||||
var tab = gradioApp().getElementById(tabname + '_images_history');
|
var tab = gradioApp().getElementById(tabname + '_images_history');
|
||||||
var set_btn = tab.querySelector(".images_history_set_index");
|
var set_btn = tab.querySelector(".images_history_set_index");
|
||||||
@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
var img_num = buttons.length / 2;
|
var img_num = buttons.length / 2;
|
||||||
|
del_num = Math.min(img_num - image_index, del_num)
|
||||||
if (img_num <= del_num){
|
if (img_num <= del_num){
|
||||||
setTimeout(function(tabname){
|
setTimeout(function(tabname){
|
||||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||||
@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
|
|||||||
} else {
|
} else {
|
||||||
var next_img
|
var next_img
|
||||||
for (var i = 0; i < del_num; i++){
|
for (var i = 0; i < del_num; i++){
|
||||||
if (image_index + i < image_index + img_num){
|
buttons[image_index + i].style.display = 'none';
|
||||||
buttons[image_index + i].style.display = 'none';
|
buttons[image_index + i + img_num].style.display = 'none';
|
||||||
buttons[image_index + img_num + 1].style.display = 'none';
|
next_img = image_index + i + 1
|
||||||
next_img = image_index + i + 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
var bnt;
|
var bnt;
|
||||||
if (next_img >= img_num){
|
if (next_img >= img_num){
|
||||||
btn = buttons[image_index - del_num];
|
btn = buttons[image_index - 1];
|
||||||
} else {
|
} else {
|
||||||
btn = buttons[next_img];
|
btn = buttons[next_img];
|
||||||
}
|
}
|
||||||
setTimeout(function(btn){btn.click()}, 30, btn);
|
setTimeout(function(btn){btn.click()}, 30, btn);
|
||||||
}
|
}
|
||||||
images_history_disabled_del();
|
images_history_disabled_del();
|
||||||
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function images_history_turnpage(img_path, page_index, image_index, tabname){
|
function images_history_turnpage(tabname){
|
||||||
|
gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled');
|
||||||
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
||||||
buttons.forEach(function(elem) {
|
buttons.forEach(function(elem) {
|
||||||
elem.style.display = 'block';
|
elem.style.display = 'block';
|
||||||
})
|
})
|
||||||
return [img_path, page_index, image_index, tabname];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function images_history_enable_del_buttons(){
|
function images_history_enable_del_buttons(){
|
||||||
@ -147,60 +137,64 @@ function images_history_enable_del_buttons(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function images_history_init(){
|
function images_history_init(){
|
||||||
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
|
var tabnames = gradioApp().getElementById("images_history_tabnames_list")
|
||||||
if (load_txt2img_button){
|
if (tabnames){
|
||||||
|
images_history_tab_list = tabnames.querySelector("textarea").value.split(",")
|
||||||
for (var i in images_history_tab_list ){
|
for (var i in images_history_tab_list ){
|
||||||
tab = images_history_tab_list[i];
|
var tab = images_history_tab_list[i];
|
||||||
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
||||||
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
||||||
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
||||||
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
||||||
|
gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px");
|
||||||
}
|
}
|
||||||
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
|
||||||
tabs_box.setAttribute("id", "images_history_tab");
|
|
||||||
var tab_btns = tabs_box.querySelectorAll("button");
|
|
||||||
for (var i in images_history_tab_list){
|
|
||||||
var tabname = images_history_tab_list[i]
|
|
||||||
tab_btns[i].setAttribute("tabname", tabname);
|
|
||||||
|
|
||||||
// this refreshes history upon tab switch
|
//preload
|
||||||
// until the history is known to work well, which is not the case now, we do not do this at startup
|
if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){
|
||||||
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
||||||
}
|
tabs_box.setAttribute("id", "images_history_tab");
|
||||||
tabs_box.classList.add(images_history_tab_list[0]);
|
var tab_btns = tabs_box.querySelectorAll("button");
|
||||||
|
for (var i in images_history_tab_list){
|
||||||
// same as above, at page load
|
var tabname = images_history_tab_list[i]
|
||||||
//load_txt2img_button.click();
|
tab_btns[i].setAttribute("tabname", tabname);
|
||||||
|
tab_btns[i].addEventListener('click', function(){
|
||||||
|
var tabs_box = gradioApp().getElementById("images_history_tab");
|
||||||
|
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
||||||
|
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click();
|
||||||
|
tabs_box.classList.add(this.getAttribute("tabname"))
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
tab_btns[0].click()
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
setTimeout(images_history_init, 500);
|
setTimeout(images_history_init, 500);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var images_history_tab_list = ["txt2img", "img2img", "extras"];
|
var images_history_tab_list = "";
|
||||||
setTimeout(images_history_init, 500);
|
setTimeout(images_history_init, 500);
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
document.addEventListener("DOMContentLoaded", function() {
|
||||||
var mutationObserver = new MutationObserver(function(m){
|
var mutationObserver = new MutationObserver(function(m){
|
||||||
for (var i in images_history_tab_list ){
|
if (images_history_tab_list != ""){
|
||||||
let tabname = images_history_tab_list[i]
|
for (var i in images_history_tab_list ){
|
||||||
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
let tabname = images_history_tab_list[i]
|
||||||
buttons.forEach(function(bnt){
|
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
||||||
bnt.addEventListener('click', images_history_click_image, true);
|
buttons.forEach(function(bnt){
|
||||||
});
|
bnt.addEventListener('click', images_history_click_image, true);
|
||||||
|
});
|
||||||
|
|
||||||
// same as load_txt2img_button.click() above
|
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||||
/*
|
if (cls_btn){
|
||||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
cls_btn.addEventListener('click', function(){
|
||||||
if (cls_btn){
|
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||||
cls_btn.addEventListener('click', function(){
|
}, false);
|
||||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
}
|
||||||
}, false);
|
|
||||||
}*/
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
|
mutationObserver.observe(gradioApp(), { childList:true, subtree:true });
|
||||||
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
|
import sys, os, shlex
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||||
@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
|
|||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
|
|
||||||
|
def extract_device_id(args, name):
|
||||||
|
for x in range(len(args)):
|
||||||
|
if name in args[x]: return args[x+1]
|
||||||
|
return None
|
||||||
|
|
||||||
def get_optimal_device():
|
def get_optimal_device():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return torch.device("cuda")
|
from modules import shared
|
||||||
|
|
||||||
|
device_id = shared.cmd_opts.device_id
|
||||||
|
|
||||||
|
if device_id is not None:
|
||||||
|
cuda_device = f"cuda:{device_id}"
|
||||||
|
return torch.device(cuda_device)
|
||||||
|
else:
|
||||||
|
return torch.device("cuda")
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
@ -34,7 +45,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
dtype_vae = torch.float16
|
dtype_vae = torch.float16
|
||||||
|
|
||||||
|
@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||||||
|
|
||||||
if input_dir == '':
|
if input_dir == '':
|
||||||
return outputs, "Please select an input directory.", ''
|
return outputs, "Please select an input directory.", ''
|
||||||
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
|
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
|
||||||
for img in image_list:
|
for img in image_list:
|
||||||
image = Image.open(img)
|
try:
|
||||||
|
image = Image.open(img)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
imageArr.append(image)
|
imageArr.append(image)
|
||||||
imageNameArr.append(img)
|
imageNameArr.append(img)
|
||||||
else:
|
else:
|
||||||
@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
|
|||||||
|
|
||||||
while len(cached_images) > 2:
|
while len(cached_images) > 2:
|
||||||
del cached_images[next(iter(cached_images.keys()))]
|
del cached_images[next(iter(cached_images.keys()))]
|
||||||
|
|
||||||
|
if opts.use_original_name_batch and image_name != None:
|
||||||
|
basename = os.path.splitext(os.path.basename(image_name))[0]
|
||||||
|
else:
|
||||||
|
basename = ''
|
||||||
|
|
||||||
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
|
||||||
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
|
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
|
||||||
forced_filename=image_name if opts.use_original_name_batch else None)
|
|
||||||
|
|
||||||
if opts.enable_pnginfo:
|
if opts.enable_pnginfo:
|
||||||
image.info = existing_pnginfo
|
image.info = existing_pnginfo
|
||||||
|
@ -4,13 +4,22 @@ import gradio as gr
|
|||||||
from modules.shared import script_path
|
from modules.shared import script_path
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
|
||||||
re_param = re.compile(re_param_code)
|
re_param = re.compile(re_param_code)
|
||||||
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
|
||||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||||
type_of_gr_update = type(gr.update())
|
type_of_gr_update = type(gr.update())
|
||||||
|
|
||||||
|
|
||||||
|
def quote(text):
|
||||||
|
if ',' not in str(text):
|
||||||
|
return text
|
||||||
|
|
||||||
|
text = str(text)
|
||||||
|
text = text.replace('\\', '\\\\')
|
||||||
|
text = text.replace('"', '\\"')
|
||||||
|
return f'"{text}"'
|
||||||
|
|
||||||
def parse_generation_parameters(x: str):
|
def parse_generation_parameters(x: str):
|
||||||
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
"""parses generation parameters string, the one you see in text field under the picture in UI:
|
||||||
```
|
```
|
||||||
@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
valtype = type(output.value)
|
valtype = type(output.value)
|
||||||
val = valtype(v)
|
|
||||||
|
if valtype == bool and v == "False":
|
||||||
|
val = False
|
||||||
|
else:
|
||||||
|
val = valtype(v)
|
||||||
|
|
||||||
res.append(gr.update(value=val))
|
res.append(gr.update(value=val))
|
||||||
except Exception:
|
except Exception:
|
||||||
res.append(gr.update())
|
res.append(gr.update())
|
||||||
|
@ -41,12 +41,12 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||||
|
|
||||||
# Add an activation func
|
# Add an activation func
|
||||||
if activation_func == "linear":
|
if activation_func == "linear" or activation_func is None:
|
||||||
pass
|
pass
|
||||||
elif activation_func in self.activation_dict:
|
elif activation_func in self.activation_dict:
|
||||||
linears.append(self.activation_dict[activation_func]())
|
linears.append(self.activation_dict[activation_func]())
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise RuntimeError(
|
||||||
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
self.load_state_dict(state_dict)
|
self.load_state_dict(state_dict)
|
||||||
else:
|
else:
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if isinstance(layer, torch.nn.Linear):
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||||
layer.bias.data.zero_()
|
layer.bias.data.zero_()
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ class HypernetworkModule(torch.nn.Module):
|
|||||||
def trainables(self):
|
def trainables(self):
|
||||||
layer_structure = []
|
layer_structure = []
|
||||||
for layer in self.linear:
|
for layer in self.linear:
|
||||||
if isinstance(layer, torch.nn.Linear):
|
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||||
layer_structure += [layer.weight, layer.bias]
|
layer_structure += [layer.weight, layer.bias]
|
||||||
return layer_structure
|
return layer_structure
|
||||||
|
|
||||||
@ -272,6 +272,9 @@ def stack_conds(conds):
|
|||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||||
|
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||||
|
from modules import images
|
||||||
|
|
||||||
assert hypernetwork_name, 'hypernetwork not selected'
|
assert hypernetwork_name, 'hypernetwork not selected'
|
||||||
|
|
||||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||||
@ -314,6 +317,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
|
|
||||||
last_saved_file = "<none>"
|
last_saved_file = "<none>"
|
||||||
last_saved_image = "<none>"
|
last_saved_image = "<none>"
|
||||||
|
forced_filename = "<none>"
|
||||||
|
|
||||||
ititial_step = hypernetwork.step or 0
|
ititial_step = hypernetwork.step or 0
|
||||||
if ititial_step > steps:
|
if ititial_step > steps:
|
||||||
@ -353,7 +357,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
pbar.set_description(f"loss: {mean_loss:.7f}")
|
pbar.set_description(f"loss: {mean_loss:.7f}")
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
# Before saving, change name to match current checkpoint.
|
||||||
|
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||||
|
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||||
hypernetwork.save(last_saved_file)
|
hypernetwork.save(last_saved_file)
|
||||||
|
|
||||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||||
@ -362,7 +368,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
})
|
})
|
||||||
|
|
||||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||||
|
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
shared.sd_model.cond_stage_model.to(devices.device)
|
shared.sd_model.cond_stage_model.to(devices.device)
|
||||||
@ -398,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
shared.state.current_image = image
|
shared.state.current_image = image
|
||||||
image.save(last_saved_image)
|
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||||
last_saved_image += f", prompt: {preview_text}"
|
last_saved_image += f", prompt: {preview_text}"
|
||||||
|
|
||||||
shared.state.job_no = hypernetwork.step
|
shared.state.job_no = hypernetwork.step
|
||||||
@ -408,7 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
|||||||
Loss: {mean_loss:.7f}<br/>
|
Loss: {mean_loss:.7f}<br/>
|
||||||
Step: {hypernetwork.step}<br/>
|
Step: {hypernetwork.step}<br/>
|
||||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
|
||||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||||
</p>
|
</p>
|
||||||
"""
|
"""
|
||||||
@ -417,6 +424,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
|
|||||||
|
|
||||||
hypernetwork.sd_checkpoint = checkpoint.hash
|
hypernetwork.sd_checkpoint = checkpoint.hash
|
||||||
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
hypernetwork.sd_checkpoint_name = checkpoint.model_name
|
||||||
|
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
|
||||||
|
hypernetwork.name = hypernetwork_name
|
||||||
|
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
|
||||||
hypernetwork.save(filename)
|
hypernetwork.save(filename)
|
||||||
|
|
||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
@ -9,9 +9,13 @@ from modules import devices, sd_hijack, shared
|
|||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
|
||||||
|
# Remove illegal characters from name.
|
||||||
|
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
|
||||||
|
|
||||||
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
if type(layer_structure) == str:
|
if type(layer_structure) == str:
|
||||||
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
|
||||||
|
@ -1,183 +1,424 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import time
|
||||||
|
import hashlib
|
||||||
|
import gradio
|
||||||
|
system_bak_path = "webui_log_and_bak"
|
||||||
|
custom_tab_name = "custom fold"
|
||||||
|
faverate_tab_name = "favorites"
|
||||||
|
tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
|
||||||
|
def is_valid_date(date):
|
||||||
|
try:
|
||||||
|
time.strptime(date, "%Y%m%d")
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
def reduplicative_file_move(src, dst):
|
||||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
def same_name_file(basename, path):
|
||||||
|
name, ext = os.path.splitext(basename)
|
||||||
|
f_list = os.listdir(path)
|
||||||
|
max_num = 0
|
||||||
|
for f in f_list:
|
||||||
|
if len(f) <= len(basename):
|
||||||
|
continue
|
||||||
|
f_ext = f[-len(ext):] if len(ext) > 0 else ""
|
||||||
|
if f[:len(name)] == name and f_ext == ext:
|
||||||
|
if f[len(name)] == "(" and f[-len(ext)-1] == ")":
|
||||||
|
number = f[len(name)+1:-len(ext)-1]
|
||||||
|
if number.isdigit():
|
||||||
|
if int(number) > max_num:
|
||||||
|
max_num = int(number)
|
||||||
|
return f"{name}({max_num + 1}){ext}"
|
||||||
|
name = os.path.basename(src)
|
||||||
|
save_name = os.path.join(dst, name)
|
||||||
|
if not os.path.exists(save_name):
|
||||||
|
shutil.move(src, dst)
|
||||||
|
else:
|
||||||
|
name = same_name_file(name, dst)
|
||||||
|
shutil.move(src, os.path.join(dst, name))
|
||||||
|
|
||||||
|
def traverse_all_files(curr_path, image_list, all_type=False):
|
||||||
try:
|
try:
|
||||||
f_list = os.listdir(curr_path)
|
f_list = os.listdir(curr_path)
|
||||||
except:
|
except:
|
||||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
|
||||||
image_list.append(curr_dir)
|
image_list.append(curr_path)
|
||||||
return image_list
|
return image_list
|
||||||
for file in f_list:
|
for file in f_list:
|
||||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
file = os.path.join(curr_path, file)
|
||||||
file_path = os.path.join(curr_path, file)
|
if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
|
||||||
if file[-4:] == ".txt":
|
|
||||||
pass
|
pass
|
||||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
|
||||||
image_list.append(file)
|
image_list.append(file)
|
||||||
else:
|
else:
|
||||||
image_list = traverse_all_files(output_dir, image_list, file)
|
image_list = traverse_all_files(file, image_list)
|
||||||
return image_list
|
return image_list
|
||||||
|
|
||||||
|
def auto_sorting(dir_name):
|
||||||
|
bak_path = os.path.join(dir_name, system_bak_path)
|
||||||
|
if not os.path.exists(bak_path):
|
||||||
|
os.mkdir(bak_path)
|
||||||
|
log_file = None
|
||||||
|
files_list = []
|
||||||
|
f_list = os.listdir(dir_name)
|
||||||
|
for file in f_list:
|
||||||
|
if file == system_bak_path:
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(dir_name, file)
|
||||||
|
if not is_valid_date(file):
|
||||||
|
if file[-10:].rfind(".") > 0:
|
||||||
|
files_list.append(file_path)
|
||||||
|
else:
|
||||||
|
files_list = traverse_all_files(file_path, files_list, all_type=True)
|
||||||
|
|
||||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
for file in files_list:
|
||||||
page_index = int(page_index)
|
date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
|
||||||
image_list = []
|
file_path = os.path.dirname(file)
|
||||||
if not os.path.exists(dir_name):
|
hash_path = hashlib.md5(file_path.encode()).hexdigest()
|
||||||
pass
|
path = os.path.join(dir_name, date_str, hash_path)
|
||||||
elif os.path.isdir(dir_name):
|
if not os.path.exists(path):
|
||||||
image_list = traverse_all_files(dir_name, image_list)
|
os.makedirs(path)
|
||||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
if log_file is None:
|
||||||
|
log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
|
||||||
|
log_file.write(f"{hash_path},{file_path}\n")
|
||||||
|
reduplicative_file_move(file, path)
|
||||||
|
|
||||||
|
date_list = []
|
||||||
|
f_list = os.listdir(dir_name)
|
||||||
|
for f in f_list:
|
||||||
|
if is_valid_date(f):
|
||||||
|
date_list.append(f)
|
||||||
|
elif f == system_bak_path:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
reduplicative_file_move(os.path.join(dir_name, f), bak_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||||
|
if today not in date_list:
|
||||||
|
date_list.append(today)
|
||||||
|
return sorted(date_list, reverse=True)
|
||||||
|
|
||||||
|
def archive_images(dir_name, date_to):
|
||||||
|
filenames = []
|
||||||
|
batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
|
||||||
|
if batch_size <= 0:
|
||||||
|
batch_size = opts.images_history_num_per_page * 6
|
||||||
|
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||||
|
date_to = today if date_to is None or date_to == "" else date_to
|
||||||
|
date_to_bak = date_to
|
||||||
|
if False: #opts.images_history_reconstruct_directory:
|
||||||
|
date_list = auto_sorting(dir_name)
|
||||||
|
for date in date_list:
|
||||||
|
if date <= date_to:
|
||||||
|
path = os.path.join(dir_name, date)
|
||||||
|
if date == today and not os.path.exists(path):
|
||||||
|
continue
|
||||||
|
filenames = traverse_all_files(path, filenames)
|
||||||
|
if len(filenames) > batch_size:
|
||||||
|
break
|
||||||
|
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
|
||||||
else:
|
else:
|
||||||
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
|
filenames = traverse_all_files(dir_name, filenames)
|
||||||
num = 48 if tabname != "extras" else 12
|
total_num = len(filenames)
|
||||||
max_page_index = len(image_list) // num + 1
|
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
|
||||||
page_index = max_page_index if page_index == -1 else page_index + step
|
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
|
||||||
page_index = 1 if page_index < 1 else page_index
|
filenames = []
|
||||||
page_index = max_page_index if page_index > max_page_index else page_index
|
date_list = {date_to:None}
|
||||||
idx_frm = (page_index - 1) * num
|
date = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||||
image_list = image_list[idx_frm:idx_frm + num]
|
for t, f in tmparray:
|
||||||
image_index = int(image_index)
|
date = time.strftime("%Y%m%d",time.localtime(t))
|
||||||
if image_index < 0 or image_index > len(image_list) - 1:
|
date_list[date] = None
|
||||||
current_file = None
|
if t <= date_stamp:
|
||||||
hidden = None
|
filenames.append((t, f ,date))
|
||||||
else:
|
date_list = sorted(list(date_list.keys()), reverse=True)
|
||||||
current_file = image_list[int(image_index)]
|
sort_array = sorted(filenames, key=lambda x:-x[0])
|
||||||
hidden = os.path.join(dir_name, current_file)
|
if len(sort_array) > batch_size:
|
||||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
date = sort_array[batch_size][2]
|
||||||
|
filenames = [x[1] for x in sort_array]
|
||||||
|
else:
|
||||||
|
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
|
||||||
|
filenames = [x[1] for x in sort_array]
|
||||||
|
filenames = [x[1] for x in sort_array if x[2]>= date]
|
||||||
|
num = len(filenames)
|
||||||
|
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
|
||||||
|
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
|
||||||
|
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
|
||||||
|
load_info = "<div style='color:#999' align='center'>"
|
||||||
|
load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
|
||||||
|
load_info += "</div>"
|
||||||
|
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
|
||||||
|
return (
|
||||||
|
date_to,
|
||||||
|
load_info,
|
||||||
|
filenames,
|
||||||
|
1,
|
||||||
|
image_list,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
visible_num,
|
||||||
|
last_date_from,
|
||||||
|
gradio.update(visible=total_num > num)
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_image(delete_num, name, filenames, image_index, visible_num):
|
||||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
|
||||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
|
||||||
|
|
||||||
|
|
||||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
|
||||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
|
||||||
|
|
||||||
|
|
||||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
|
||||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
|
||||||
|
|
||||||
|
|
||||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
|
||||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
|
||||||
|
|
||||||
|
|
||||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
|
||||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
|
||||||
|
|
||||||
|
|
||||||
def show_image_info(num, image_path, filenames):
|
|
||||||
# print(f"select image {num}")
|
|
||||||
file = filenames[int(num)]
|
|
||||||
return file, num, os.path.join(image_path, file)
|
|
||||||
|
|
||||||
|
|
||||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
|
||||||
if name == "":
|
if name == "":
|
||||||
return filenames, delete_num
|
return filenames, delete_num
|
||||||
else:
|
else:
|
||||||
delete_num = int(delete_num)
|
delete_num = int(delete_num)
|
||||||
|
visible_num = int(visible_num)
|
||||||
|
image_index = int(image_index)
|
||||||
index = list(filenames).index(name)
|
index = list(filenames).index(name)
|
||||||
i = 0
|
i = 0
|
||||||
new_file_list = []
|
new_file_list = []
|
||||||
for name in filenames:
|
for name in filenames:
|
||||||
if i >= index and i < index + delete_num:
|
if i >= index and i < index + delete_num:
|
||||||
path = os.path.join(dir_name, name)
|
if os.path.exists(name):
|
||||||
if os.path.exists(path):
|
if visible_num == image_index:
|
||||||
print(f"Delete file {path}")
|
new_file_list.append(name)
|
||||||
os.remove(path)
|
i += 1
|
||||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
continue
|
||||||
|
print(f"Delete file {name}")
|
||||||
|
os.remove(name)
|
||||||
|
visible_num -= 1
|
||||||
|
txt_file = os.path.splitext(name)[0] + ".txt"
|
||||||
if os.path.exists(txt_file):
|
if os.path.exists(txt_file):
|
||||||
os.remove(txt_file)
|
os.remove(txt_file)
|
||||||
else:
|
else:
|
||||||
print(f"Not exists file {path}")
|
print(f"Not exists file {name}")
|
||||||
else:
|
else:
|
||||||
new_file_list.append(name)
|
new_file_list.append(name)
|
||||||
i += 1
|
i += 1
|
||||||
return new_file_list, 1
|
return new_file_list, 1, visible_num
|
||||||
|
|
||||||
|
def save_image(file_name):
|
||||||
|
if file_name is not None and os.path.exists(file_name):
|
||||||
|
shutil.copy(file_name, opts.outdir_save)
|
||||||
|
|
||||||
|
def get_recent_images(page_index, step, filenames):
|
||||||
|
page_index = int(page_index)
|
||||||
|
num_of_imgs_per_page = int(opts.images_history_num_per_page)
|
||||||
|
max_page_index = len(filenames) // num_of_imgs_per_page + 1
|
||||||
|
page_index = max_page_index if page_index == -1 else page_index + step
|
||||||
|
page_index = 1 if page_index < 1 else page_index
|
||||||
|
page_index = max_page_index if page_index > max_page_index else page_index
|
||||||
|
idx_frm = (page_index - 1) * num_of_imgs_per_page
|
||||||
|
image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
|
||||||
|
length = len(filenames)
|
||||||
|
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
|
||||||
|
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
|
||||||
|
return page_index, image_list, "", "", visible_num
|
||||||
|
|
||||||
|
def loac_batch_click(date_to):
|
||||||
|
if date_to is None:
|
||||||
|
return time.strftime("%Y%m%d",time.localtime(time.time())), []
|
||||||
|
else:
|
||||||
|
return None, []
|
||||||
|
def forward_click(last_date_from, date_to_recorder):
|
||||||
|
if len(date_to_recorder) == 0:
|
||||||
|
return None, []
|
||||||
|
if last_date_from == date_to_recorder[-1]:
|
||||||
|
date_to_recorder = date_to_recorder[:-1]
|
||||||
|
if len(date_to_recorder) == 0:
|
||||||
|
return None, []
|
||||||
|
return date_to_recorder[-1], date_to_recorder[:-1]
|
||||||
|
|
||||||
|
def backward_click(last_date_from, date_to_recorder):
|
||||||
|
if last_date_from is None or last_date_from == "":
|
||||||
|
return time.strftime("%Y%m%d",time.localtime(time.time())), []
|
||||||
|
if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
|
||||||
|
date_to_recorder.append(last_date_from)
|
||||||
|
return last_date_from, date_to_recorder
|
||||||
|
|
||||||
|
|
||||||
|
def first_page_click(page_index, filenames):
|
||||||
|
return get_recent_images(1, 0, filenames)
|
||||||
|
|
||||||
|
def end_page_click(page_index, filenames):
|
||||||
|
return get_recent_images(-1, 0, filenames)
|
||||||
|
|
||||||
|
def prev_page_click(page_index, filenames):
|
||||||
|
return get_recent_images(page_index, -1, filenames)
|
||||||
|
|
||||||
|
def next_page_click(page_index, filenames):
|
||||||
|
return get_recent_images(page_index, 1, filenames)
|
||||||
|
|
||||||
|
def page_index_change(page_index, filenames):
|
||||||
|
return get_recent_images(page_index, 0, filenames)
|
||||||
|
|
||||||
|
def show_image_info(tabname_box, num, page_index, filenames):
|
||||||
|
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
|
||||||
|
tm = "<div style='color:#999' align='right'>" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "</div>"
|
||||||
|
return file, tm, num, file
|
||||||
|
|
||||||
|
def enable_page_buttons():
|
||||||
|
return gradio.update(visible=True)
|
||||||
|
|
||||||
|
def change_dir(img_dir, date_to):
|
||||||
|
warning = None
|
||||||
|
try:
|
||||||
|
if os.path.exists(img_dir):
|
||||||
|
try:
|
||||||
|
f = os.listdir(img_dir)
|
||||||
|
except:
|
||||||
|
warning = f"'{img_dir} is not a directory"
|
||||||
|
else:
|
||||||
|
warning = "The directory is not exist"
|
||||||
|
except:
|
||||||
|
warning = "The format of the directory is incorrect"
|
||||||
|
if warning is None:
|
||||||
|
today = time.strftime("%Y%m%d",time.localtime(time.time()))
|
||||||
|
return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
|
||||||
|
else:
|
||||||
|
return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
|
||||||
|
|
||||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||||
if opts.outdir_samples != "":
|
custom_dir = False
|
||||||
dir_name = opts.outdir_samples
|
if tabname == "txt2img":
|
||||||
elif tabname == "txt2img":
|
|
||||||
dir_name = opts.outdir_txt2img_samples
|
dir_name = opts.outdir_txt2img_samples
|
||||||
elif tabname == "img2img":
|
elif tabname == "img2img":
|
||||||
dir_name = opts.outdir_img2img_samples
|
dir_name = opts.outdir_img2img_samples
|
||||||
elif tabname == "extras":
|
elif tabname == "extras":
|
||||||
dir_name = opts.outdir_extras_samples
|
dir_name = opts.outdir_extras_samples
|
||||||
|
elif tabname == faverate_tab_name:
|
||||||
|
dir_name = opts.outdir_save
|
||||||
else:
|
else:
|
||||||
return
|
custom_dir = True
|
||||||
with gr.Row():
|
dir_name = None
|
||||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
|
||||||
first_page = gr.Button('First Page')
|
if not custom_dir:
|
||||||
prev_page = gr.Button('Prev Page')
|
d = dir_name.split("/")
|
||||||
page_index = gr.Number(value=1, label="Page Index")
|
dir_name = d[0]
|
||||||
next_page = gr.Button('Next Page')
|
for p in d[1:]:
|
||||||
end_page = gr.Button('End Page')
|
dir_name = os.path.join(dir_name, p)
|
||||||
with gr.Row(elem_id=tabname + "_images_history"):
|
if not os.path.exists(dir_name):
|
||||||
with gr.Row():
|
os.makedirs(dir_name)
|
||||||
with gr.Column(scale=2):
|
|
||||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
with gr.Column() as page_panel:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
|
||||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
|
||||||
with gr.Column():
|
with gr.Column(scale=4):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
|
||||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
with gr.Row():
|
||||||
with gr.Row():
|
with gr.Column(visible=False, scale=1) as batch_panel:
|
||||||
with gr.Column():
|
with gr.Row():
|
||||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
forward = gr.Button('Prev batch')
|
||||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
backward = gr.Button('Next batch')
|
||||||
with gr.Row():
|
with gr.Column(scale=3):
|
||||||
|
load_info = gr.HTML(visible=not custom_dir)
|
||||||
|
with gr.Row(visible=False) as warning:
|
||||||
|
warning_box = gr.Textbox("Message", interactive=False)
|
||||||
|
|
||||||
|
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
with gr.Row(visible=True) as turn_page_buttons:
|
||||||
|
#date_to = gr.Dropdown(label="Date to")
|
||||||
|
first_page = gr.Button('First Page')
|
||||||
|
prev_page = gr.Button('Prev Page')
|
||||||
|
page_index = gr.Number(value=1, label="Page Index")
|
||||||
|
next_page = gr.Button('Next Page')
|
||||||
|
end_page = gr.Button('End Page')
|
||||||
|
|
||||||
|
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
|
||||||
|
with gr.Row():
|
||||||
|
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||||
|
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column():
|
||||||
|
img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
|
||||||
|
gr.HTML("<hr>")
|
||||||
|
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
|
||||||
|
img_file_time= gr.HTML()
|
||||||
|
with gr.Row():
|
||||||
|
if tabname != faverate_tab_name:
|
||||||
|
save_btn = gr.Button('Collect')
|
||||||
|
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||||
|
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
|
||||||
|
|
||||||
# hiden items
|
# hiden items
|
||||||
|
with gr.Row(visible=False):
|
||||||
|
renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
|
||||||
|
batch_date_to = gr.Textbox(label="Date to")
|
||||||
|
visible_img_num = gr.Number()
|
||||||
|
date_to_recorder = gr.State([])
|
||||||
|
last_date_from = gr.Textbox()
|
||||||
|
tabname_box = gr.Textbox(tabname)
|
||||||
|
image_index = gr.Textbox(value=-1)
|
||||||
|
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
|
||||||
|
filenames = gr.State()
|
||||||
|
all_images_list = gr.State()
|
||||||
|
hidden = gr.Image(type="pil")
|
||||||
|
info1 = gr.Textbox()
|
||||||
|
info2 = gr.Textbox()
|
||||||
|
|
||||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
|
||||||
tabname_box = gr.Textbox(tabname, visible=False)
|
|
||||||
image_index = gr.Textbox(value=-1, visible=False)
|
|
||||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
|
||||||
filenames = gr.State()
|
|
||||||
hidden = gr.Image(type="pil", visible=False)
|
|
||||||
info1 = gr.Textbox(visible=False)
|
|
||||||
info2 = gr.Textbox(visible=False)
|
|
||||||
|
|
||||||
# turn pages
|
#change batch
|
||||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
|
||||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
|
||||||
|
batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
|
||||||
|
batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
|
||||||
|
batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
|
||||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
|
||||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
|
||||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
|
||||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
|
||||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
|
||||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
#delete
|
||||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
|
||||||
|
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
|
||||||
|
if tabname != faverate_tab_name:
|
||||||
|
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
|
||||||
|
|
||||||
|
#turn page
|
||||||
|
gallery_inputs = [page_index, filenames]
|
||||||
|
gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
|
||||||
|
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
|
||||||
|
|
||||||
|
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
|
||||||
|
|
||||||
# other funcitons
|
# other funcitons
|
||||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
|
||||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
|
||||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||||
|
|
||||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
|
||||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
|
||||||
|
global opts;
|
||||||
|
opts = sys_opts
|
||||||
|
loads_files_num = int(opts.images_history_num_per_page)
|
||||||
|
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
|
||||||
|
if cmp_ops.browse_all_images:
|
||||||
|
tabs_list.append(custom_tab_name)
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
with gr.Tab("txt2img history"):
|
for tab in tabs_list:
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
with gr.Tab(tab):
|
||||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
with gr.Blocks(analytics_enabled=False) :
|
||||||
with gr.Tab("img2img history"):
|
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
|
||||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
|
||||||
with gr.Tab("extras history"):
|
|
||||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
|
||||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
|
||||||
return images_history
|
return images_history
|
||||||
|
@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||||||
inpainting_mask_invert=inpainting_mask_invert,
|
inpainting_mask_invert=inpainting_mask_invert,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
|
p.script_args = args
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
@ -28,9 +28,11 @@ class InterrogateModels:
|
|||||||
clip_preprocess = None
|
clip_preprocess = None
|
||||||
categories = None
|
categories = None
|
||||||
dtype = None
|
dtype = None
|
||||||
|
running_on_cpu = None
|
||||||
|
|
||||||
def __init__(self, content_dir):
|
def __init__(self, content_dir):
|
||||||
self.categories = []
|
self.categories = []
|
||||||
|
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||||
|
|
||||||
if os.path.exists(content_dir):
|
if os.path.exists(content_dir):
|
||||||
for filename in os.listdir(content_dir):
|
for filename in os.listdir(content_dir):
|
||||||
@ -53,7 +55,11 @@ class InterrogateModels:
|
|||||||
def load_clip_model(self):
|
def load_clip_model(self):
|
||||||
import clip
|
import clip
|
||||||
|
|
||||||
model, preprocess = clip.load(clip_model_name)
|
if self.running_on_cpu:
|
||||||
|
model, preprocess = clip.load(clip_model_name, device="cpu")
|
||||||
|
else:
|
||||||
|
model, preprocess = clip.load(clip_model_name)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(devices.device_interrogate)
|
model = model.to(devices.device_interrogate)
|
||||||
|
|
||||||
@ -62,14 +68,14 @@ class InterrogateModels:
|
|||||||
def load(self):
|
def load(self):
|
||||||
if self.blip_model is None:
|
if self.blip_model is None:
|
||||||
self.blip_model = self.load_blip_model()
|
self.blip_model = self.load_blip_model()
|
||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||||
self.blip_model = self.blip_model.half()
|
self.blip_model = self.blip_model.half()
|
||||||
|
|
||||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||||
|
|
||||||
if self.clip_model is None:
|
if self.clip_model is None:
|
||||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||||
self.clip_model = self.clip_model.half()
|
self.clip_model = self.clip_model.half()
|
||||||
|
|
||||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from modules.devices import get_optimal_device
|
from modules import devices
|
||||||
|
|
||||||
module_in_gpu = None
|
module_in_gpu = None
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
device = gpu = get_optimal_device()
|
|
||||||
|
|
||||||
|
|
||||||
def send_everything_to_cpu():
|
def send_everything_to_cpu():
|
||||||
@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
if module_in_gpu is not None:
|
if module_in_gpu is not None:
|
||||||
module_in_gpu.to(cpu)
|
module_in_gpu.to(cpu)
|
||||||
|
|
||||||
module.to(gpu)
|
module.to(devices.device)
|
||||||
module_in_gpu = module
|
module_in_gpu = module
|
||||||
|
|
||||||
# see below for register_forward_pre_hook;
|
# see below for register_forward_pre_hook;
|
||||||
@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
|
||||||
sd_model.to(device)
|
sd_model.to(devices.device)
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
|
||||||
|
|
||||||
# register hooks for those the first two models
|
# register hooks for those the first two models
|
||||||
@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
# so that only one of them is in GPU at a time
|
# so that only one of them is in GPU at a time
|
||||||
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
|
||||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
|
||||||
sd_model.model.to(device)
|
sd_model.model.to(devices.device)
|
||||||
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
|
||||||
|
|
||||||
# install hooks for bits of third model
|
# install hooks for bits of third model
|
||||||
|
@ -12,7 +12,7 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -104,6 +104,12 @@ class StableDiffusionProcessing():
|
|||||||
self.seed_resize_from_h = 0
|
self.seed_resize_from_h = 0
|
||||||
self.seed_resize_from_w = 0
|
self.seed_resize_from_w = 0
|
||||||
|
|
||||||
|
self.scripts = None
|
||||||
|
self.script_args = None
|
||||||
|
self.all_prompts = None
|
||||||
|
self.all_seeds = None
|
||||||
|
self.all_subseeds = None
|
||||||
|
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
@ -304,7 +310,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
||||||
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
|
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
@ -318,7 +324,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
||||||
|
|
||||||
@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
shared.prompt_styles.apply_styles(p)
|
shared.prompt_styles.apply_styles(p)
|
||||||
|
|
||||||
if type(p.prompt) == list:
|
if type(p.prompt) == list:
|
||||||
all_prompts = p.prompt
|
p.all_prompts = p.prompt
|
||||||
else:
|
else:
|
||||||
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(seed) == list:
|
||||||
all_seeds = seed
|
p.all_seeds = seed
|
||||||
else:
|
else:
|
||||||
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
if type(subseed) == list:
|
if type(subseed) == list:
|
||||||
all_subseeds = subseed
|
p.all_subseeds = subseed
|
||||||
else:
|
else:
|
||||||
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
||||||
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
if p.scripts is not None:
|
||||||
|
p.scripts.run_alwayson_scripts(p)
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
|
|
||||||
with torch.no_grad(), p.sd_model.ema_scope():
|
with torch.no_grad(), p.sd_model.ema_scope():
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
p.init(all_prompts, all_seeds, all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
|
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
if (len(prompts) == 0):
|
if (len(prompts) == 0):
|
||||||
break
|
break
|
||||||
@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
@ -540,17 +549,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||||
|
|
||||||
|
def create_dummy_mask(self, x, width=None, height=None):
|
||||||
|
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||||
|
height = height or self.height
|
||||||
|
width = width or self.width
|
||||||
|
|
||||||
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
||||||
|
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
|
||||||
|
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
||||||
|
|
||||||
|
# Add the fake full 1s mask to the first dimension.
|
||||||
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
||||||
|
image_conditioning = image_conditioning.to(x.dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Dummy zero conditioning if we're not using inpainting model.
|
||||||
|
# Still takes up a bit of memory, but no encoder call.
|
||||||
|
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
|
||||||
|
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
return image_conditioning
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
|
||||||
|
|
||||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||||
|
|
||||||
@ -587,7 +616,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@ -613,6 +642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
self.inpainting_mask_invert = inpainting_mask_invert
|
self.inpainting_mask_invert = inpainting_mask_invert
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
|
self.image_conditioning = None
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
||||||
@ -714,10 +744,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
elif self.inpainting_fill == 3:
|
elif self.inpainting_fill == 3:
|
||||||
self.init_latent = self.init_latent * self.mask
|
self.init_latent = self.init_latent * self.mask
|
||||||
|
|
||||||
|
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||||
|
if self.image_mask is not None:
|
||||||
|
conditioning_mask = np.array(self.image_mask.convert("L"))
|
||||||
|
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
||||||
|
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
||||||
|
|
||||||
|
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
||||||
|
conditioning_mask = torch.round(conditioning_mask)
|
||||||
|
else:
|
||||||
|
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
||||||
|
|
||||||
|
# Create another latent image, this time with a masked version of the original input.
|
||||||
|
conditioning_mask = conditioning_mask.to(image.device)
|
||||||
|
conditioning_image = image * (1.0 - conditioning_mask)
|
||||||
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
||||||
|
|
||||||
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||||
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
||||||
|
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
||||||
|
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
||||||
|
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
||||||
|
else:
|
||||||
|
self.image_conditioning = torch.zeros(
|
||||||
|
self.init_latent.shape[0], 5, 1, 1,
|
||||||
|
dtype=self.init_latent.dtype,
|
||||||
|
device=self.init_latent.device
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
|
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
42
modules/script_callbacks.py
Normal file
42
modules/script_callbacks.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
|
||||||
|
callbacks_model_loaded = []
|
||||||
|
callbacks_ui_tabs = []
|
||||||
|
|
||||||
|
|
||||||
|
def clear_callbacks():
|
||||||
|
callbacks_model_loaded.clear()
|
||||||
|
callbacks_ui_tabs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def model_loaded_callback(sd_model):
|
||||||
|
for callback in callbacks_model_loaded:
|
||||||
|
callback(sd_model)
|
||||||
|
|
||||||
|
|
||||||
|
def ui_tabs_callback():
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for callback in callbacks_ui_tabs:
|
||||||
|
res += callback() or []
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def on_model_loaded(callback):
|
||||||
|
"""register a function to be called when the stable diffusion model is created; the model is
|
||||||
|
passed as an argument"""
|
||||||
|
callbacks_model_loaded.append(callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_tabs(callback):
|
||||||
|
"""register a function to be called when the UI is creating new tabs.
|
||||||
|
The function must either return a None, which means no new tabs to be added, or a list, where
|
||||||
|
each element is a tuple:
|
||||||
|
(gradio_component, title, elem_id)
|
||||||
|
|
||||||
|
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
|
||||||
|
title is tab text displayed to user in the UI
|
||||||
|
elem_id is HTML id for the tab
|
||||||
|
"""
|
||||||
|
callbacks_ui_tabs.append(callback)
|
||||||
|
|
@ -1,86 +1,175 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import modules.ui as ui
|
import modules.ui as ui
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.processing import StableDiffusionProcessing
|
from modules.processing import StableDiffusionProcessing
|
||||||
from modules import shared
|
from modules import shared, paths, script_callbacks
|
||||||
|
|
||||||
|
AlwaysVisible = object()
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
filename = None
|
filename = None
|
||||||
args_from = None
|
args_from = None
|
||||||
args_to = None
|
args_to = None
|
||||||
|
alwayson = False
|
||||||
|
|
||||||
|
infotext_fields = None
|
||||||
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
|
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
|
||||||
|
"""
|
||||||
|
|
||||||
# The title of the script. This is what will be displayed in the dropdown menu.
|
|
||||||
def title(self):
|
def title(self):
|
||||||
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# How the script is displayed in the UI. See https://gradio.app/docs/#components
|
|
||||||
# for the different UI components you can use and how to create them.
|
|
||||||
# Most UI components can return a value, such as a boolean for a checkbox.
|
|
||||||
# The returned values are passed to the run method as parameters.
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
|
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||||
|
The return value should be an array of all components that are used in processing.
|
||||||
|
Values of those returned componenbts will be passed to run() and process() functions.
|
||||||
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Determines when the script should be shown in the dropdown menu via the
|
|
||||||
# returned value. As an example:
|
|
||||||
# is_img2img is True if the current tab is img2img, and False if it is txt2img.
|
|
||||||
# Thus, return is_img2img to only show the script on the img2img tab.
|
|
||||||
def show(self, is_img2img):
|
def show(self, is_img2img):
|
||||||
|
"""
|
||||||
|
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
|
||||||
|
|
||||||
|
This function should return:
|
||||||
|
- False if the script should not be shown in UI at all
|
||||||
|
- True if the script should be shown in UI if it's scelected in the scripts drowpdown
|
||||||
|
- script.AlwaysVisible if the script should be shown in UI at all times
|
||||||
|
"""
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# This is where the additional processing is implemented. The parameters include
|
def run(self, p, *args):
|
||||||
# self, the model object "p" (a StableDiffusionProcessing class, see
|
"""
|
||||||
# processing.py), and the parameters returned by the ui method.
|
This function is called if the script has been selected in the script dropdown.
|
||||||
# Custom functions can be defined here, and additional libraries can be imported
|
It must do all processing and return the Processed object with results, same as
|
||||||
# to be used in processing. The return value should be a Processed object, which is
|
one returned by processing.process_images.
|
||||||
# what is returned by the process_images method.
|
|
||||||
def run(self, *args):
|
Usually the processing is done by calling the processing.process_images function.
|
||||||
|
|
||||||
|
args contains all values returned by components from ui()
|
||||||
|
"""
|
||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# The description method is currently unused.
|
def process(self, p, *args):
|
||||||
# To add a description that appears when hovering over the title, amend the "titles"
|
"""
|
||||||
# dict in script.js to include the script title (returned by title) as a key, and
|
This function is called before processing begins for AlwaysVisible scripts.
|
||||||
# your description as the value.
|
scripts. You can modify the processing object (p) here, inject hooks, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
|
"""unused"""
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
|
||||||
|
def basedir():
|
||||||
|
"""returns the base directory for the current script. For scripts in the main scripts directory,
|
||||||
|
this is the main directory (where webui.py resides), and for scripts in extensions directory
|
||||||
|
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
|
||||||
|
"""
|
||||||
|
return current_basedir
|
||||||
|
|
||||||
|
|
||||||
scripts_data = []
|
scripts_data = []
|
||||||
|
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
|
||||||
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
|
||||||
|
|
||||||
|
|
||||||
def load_scripts(basedir):
|
def list_scripts(scriptdirname, extension):
|
||||||
if not os.path.exists(basedir):
|
scripts_list = []
|
||||||
return
|
|
||||||
|
|
||||||
for filename in sorted(os.listdir(basedir)):
|
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
path = os.path.join(basedir, filename)
|
if os.path.exists(basedir):
|
||||||
|
for filename in sorted(os.listdir(basedir)):
|
||||||
|
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||||
|
|
||||||
if os.path.splitext(path)[1].lower() != '.py':
|
extdir = os.path.join(paths.script_path, "extensions")
|
||||||
|
if os.path.exists(extdir):
|
||||||
|
for dirname in sorted(os.listdir(extdir)):
|
||||||
|
dirpath = os.path.join(extdir, dirname)
|
||||||
|
scriptdirpath = os.path.join(dirpath, scriptdirname)
|
||||||
|
|
||||||
|
if not os.path.isdir(scriptdirpath):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for filename in sorted(os.listdir(scriptdirpath)):
|
||||||
|
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
|
||||||
|
|
||||||
|
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||||
|
|
||||||
|
return scripts_list
|
||||||
|
|
||||||
|
|
||||||
|
def list_files_with_name(filename):
|
||||||
|
res = []
|
||||||
|
|
||||||
|
dirs = [paths.script_path]
|
||||||
|
|
||||||
|
extdir = os.path.join(paths.script_path, "extensions")
|
||||||
|
if os.path.exists(extdir):
|
||||||
|
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
|
||||||
|
|
||||||
|
for dirpath in dirs:
|
||||||
|
if not os.path.isdir(dirpath):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not os.path.isfile(path):
|
path = os.path.join(dirpath, filename)
|
||||||
continue
|
if os.path.isfile(filename):
|
||||||
|
res.append(path)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def load_scripts():
|
||||||
|
global current_basedir
|
||||||
|
scripts_data.clear()
|
||||||
|
script_callbacks.clear_callbacks()
|
||||||
|
|
||||||
|
scripts_list = list_scripts("scripts", ".py")
|
||||||
|
|
||||||
|
syspath = sys.path
|
||||||
|
|
||||||
|
for scriptfile in sorted(scripts_list):
|
||||||
try:
|
try:
|
||||||
with open(path, "r", encoding="utf8") as file:
|
if scriptfile.basedir != paths.script_path:
|
||||||
|
sys.path = [scriptfile.basedir] + sys.path
|
||||||
|
current_basedir = scriptfile.basedir
|
||||||
|
|
||||||
|
with open(scriptfile.path, "r", encoding="utf8") as file:
|
||||||
text = file.read()
|
text = file.read()
|
||||||
|
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
compiled = compile(text, path, 'exec')
|
compiled = compile(text, scriptfile.path, 'exec')
|
||||||
module = ModuleType(filename)
|
module = ModuleType(scriptfile.filename)
|
||||||
exec(compiled, module.__dict__)
|
exec(compiled, module.__dict__)
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
scripts_data.append((script_class, path))
|
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error loading script: {filename}", file=sys.stderr)
|
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
sys.path = syspath
|
||||||
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
|
||||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||||
try:
|
try:
|
||||||
@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
|||||||
class ScriptRunner:
|
class ScriptRunner:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.scripts = []
|
self.scripts = []
|
||||||
|
self.selectable_scripts = []
|
||||||
|
self.alwayson_scripts = []
|
||||||
self.titles = []
|
self.titles = []
|
||||||
|
self.infotext_fields = []
|
||||||
|
|
||||||
def setup_ui(self, is_img2img):
|
def setup_ui(self, is_img2img):
|
||||||
for script_class, path in scripts_data:
|
for script_class, path, basedir in scripts_data:
|
||||||
script = script_class()
|
script = script_class()
|
||||||
script.filename = path
|
script.filename = path
|
||||||
|
|
||||||
if not script.show(is_img2img):
|
visibility = script.show(is_img2img)
|
||||||
continue
|
|
||||||
|
|
||||||
self.scripts.append(script)
|
if visibility == AlwaysVisible:
|
||||||
|
self.scripts.append(script)
|
||||||
|
self.alwayson_scripts.append(script)
|
||||||
|
script.alwayson = True
|
||||||
|
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
|
elif visibility:
|
||||||
|
self.scripts.append(script)
|
||||||
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
dropdown.save_to_config = True
|
|
||||||
inputs = [dropdown]
|
|
||||||
|
|
||||||
for script in self.scripts:
|
inputs = [None]
|
||||||
|
inputs_alwayson = [True]
|
||||||
|
|
||||||
|
def create_script_ui(script, inputs, inputs_alwayson):
|
||||||
script.args_from = len(inputs)
|
script.args_from = len(inputs)
|
||||||
script.args_to = len(inputs)
|
script.args_to = len(inputs)
|
||||||
|
|
||||||
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
|
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
|
||||||
|
|
||||||
if controls is None:
|
if controls is None:
|
||||||
continue
|
return
|
||||||
|
|
||||||
for control in controls:
|
for control in controls:
|
||||||
control.custom_script_source = os.path.basename(script.filename)
|
control.custom_script_source = os.path.basename(script.filename)
|
||||||
control.visible = False
|
if not script.alwayson:
|
||||||
|
control.visible = False
|
||||||
|
|
||||||
|
if script.infotext_fields is not None:
|
||||||
|
self.infotext_fields += script.infotext_fields
|
||||||
|
|
||||||
inputs += controls
|
inputs += controls
|
||||||
|
inputs_alwayson += [script.alwayson for _ in controls]
|
||||||
script.args_to = len(inputs)
|
script.args_to = len(inputs)
|
||||||
|
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
with gr.Group():
|
||||||
|
create_script_ui(script, inputs, inputs_alwayson)
|
||||||
|
|
||||||
|
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
|
||||||
|
dropdown.save_to_config = True
|
||||||
|
inputs[0] = dropdown
|
||||||
|
|
||||||
|
for script in self.selectable_scripts:
|
||||||
|
create_script_ui(script, inputs, inputs_alwayson)
|
||||||
|
|
||||||
def select_script(script_index):
|
def select_script(script_index):
|
||||||
if 0 < script_index <= len(self.scripts):
|
if 0 < script_index <= len(self.selectable_scripts):
|
||||||
script = self.scripts[script_index-1]
|
script = self.selectable_scripts[script_index-1]
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
args_to = script.args_to
|
args_to = script.args_to
|
||||||
else:
|
else:
|
||||||
args_from = 0
|
args_from = 0
|
||||||
args_to = 0
|
args_to = 0
|
||||||
|
|
||||||
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
|
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
|
||||||
|
|
||||||
def init_field(title):
|
def init_field(title):
|
||||||
if title == 'None':
|
if title == 'None':
|
||||||
return
|
return
|
||||||
script_index = self.titles.index(title)
|
script_index = self.titles.index(title)
|
||||||
script = self.scripts[script_index]
|
script = self.selectable_scripts[script_index]
|
||||||
for i in range(script.args_from, script.args_to):
|
for i in range(script.args_from, script.args_to):
|
||||||
inputs[i].visible = True
|
inputs[i].visible = True
|
||||||
|
|
||||||
@ -164,7 +277,7 @@ class ScriptRunner:
|
|||||||
if script_index == 0:
|
if script_index == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
script = self.scripts[script_index-1]
|
script = self.selectable_scripts[script_index-1]
|
||||||
|
|
||||||
if script is None:
|
if script is None:
|
||||||
return None
|
return None
|
||||||
@ -176,7 +289,16 @@ class ScriptRunner:
|
|||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
def reload_sources(self):
|
def run_alwayson_scripts(self, p):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.process(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
with open(script.filename, "r", encoding="utf8") as file:
|
with open(script.filename, "r", encoding="utf8") as file:
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
@ -186,9 +308,12 @@ class ScriptRunner:
|
|||||||
|
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
|
|
||||||
compiled = compile(text, filename, 'exec')
|
module = cache.get(filename, None)
|
||||||
module = ModuleType(script.filename)
|
if module is None:
|
||||||
exec(compiled, module.__dict__)
|
compiled = compile(text, filename, 'exec')
|
||||||
|
module = ModuleType(script.filename)
|
||||||
|
exec(compiled, module.__dict__)
|
||||||
|
cache[filename] = module
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for key, script_class in module.__dict__.items():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
@ -197,19 +322,22 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_from = args_from
|
self.scripts[si].args_from = args_from
|
||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
|
||||||
|
|
||||||
def reload_script_body_only():
|
def reload_script_body_only():
|
||||||
scripts_txt2img.reload_sources()
|
cache = {}
|
||||||
scripts_img2img.reload_sources()
|
scripts_txt2img.reload_sources(cache)
|
||||||
|
scripts_img2img.reload_sources(cache)
|
||||||
|
|
||||||
|
|
||||||
def reload_scripts(basedir):
|
def reload_scripts():
|
||||||
global scripts_txt2img, scripts_img2img
|
global scripts_txt2img, scripts_img2img
|
||||||
|
|
||||||
scripts_data.clear()
|
load_scripts()
|
||||||
load_scripts(basedir)
|
|
||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
|||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
remade_tokens = remade_tokens[:last_comma]
|
remade_tokens = remade_tokens[:last_comma]
|
||||||
length = len(remade_tokens)
|
length = len(remade_tokens)
|
||||||
|
|
||||||
rem = int(math.ceil(length / 75)) * 75 - length
|
rem = int(math.ceil(length / 75)) * 75 - length
|
||||||
remade_tokens += [id_end] * rem + reloc_tokens
|
remade_tokens += [id_end] * rem + reloc_tokens
|
||||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||||
|
|
||||||
if embedding is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(weight)
|
multipliers.append(weight)
|
||||||
@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
|
||||||
def process_text_old(self, text):
|
def process_text_old(self, text):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
token_count = len(remade_tokens)
|
token_count = len(remade_tokens)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
hijack_fixes.append(fixes)
|
hijack_fixes.append(fixes)
|
||||||
batch_multipliers.append(multipliers)
|
batch_multipliers.append(multipliers)
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
def forward(self, text):
|
def forward(self, text):
|
||||||
use_old = opts.use_old_emphasis_implementation
|
use_old = opts.use_old_emphasis_implementation
|
||||||
if use_old:
|
if use_old:
|
||||||
@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
if use_old:
|
if use_old:
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||||
|
|
||||||
z = None
|
z = None
|
||||||
i = 0
|
i = 0
|
||||||
while max(map(len, remade_batch_tokens)) != 0:
|
while max(map(len, remade_batch_tokens)) != 0:
|
||||||
@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
if fix[0] == i:
|
if fix[0] == i:
|
||||||
fixes.append(fix[1])
|
fixes.append(fix[1])
|
||||||
self.hijack.fixes.append(fixes)
|
self.hijack.fixes.append(fixes)
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
multipliers = []
|
multipliers = []
|
||||||
for j in range(len(remade_batch_tokens)):
|
for j in range(len(remade_batch_tokens)):
|
||||||
@ -333,19 +333,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
|
|
||||||
z1 = self.process_tokens(tokens, multipliers)
|
z1 = self.process_tokens(tokens, multipliers)
|
||||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||||
|
|
||||||
remade_batch_tokens = rem_tokens
|
remade_batch_tokens = rem_tokens
|
||||||
batch_multipliers = rem_multipliers
|
batch_multipliers = rem_multipliers
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||||
if not opts.use_old_emphasis_implementation:
|
if not opts.use_old_emphasis_implementation:
|
||||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||||
|
|
||||||
@ -385,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, embedding in fixes:
|
for offset, embedding in fixes:
|
||||||
emb = embedding.vec
|
emb = embedding.vec
|
||||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||||
|
|
||||||
vecs.append(tensor)
|
vecs.append(tensor)
|
||||||
|
|
||||||
|
331
modules/sd_hijack_inpainting.py
Normal file
331
modules/sd_hijack_inpainting.py
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from einops import repeat
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
|
import ldm.models.diffusion.ddpm
|
||||||
|
import ldm.models.diffusion.ddim
|
||||||
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
||||||
|
|
||||||
|
# =================================================================================================
|
||||||
|
# Monkey patch DDIMSampler methods from RunwayML repo directly.
|
||||||
|
# Adapted from:
|
||||||
|
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
|
||||||
|
# =================================================================================================
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_ddim(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||||
|
|
||||||
|
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
if isinstance(c, dict):
|
||||||
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
|
c_in = dict()
|
||||||
|
for k in c:
|
||||||
|
if isinstance(c[k], list):
|
||||||
|
c_in[k] = [
|
||||||
|
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||||
|
for i in range(len(c[k]))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||||
|
else:
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
|
||||||
|
# =================================================================================================
|
||||||
|
# Monkey patch PLMSSampler methods.
|
||||||
|
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
|
||||||
|
# Adapted from:
|
||||||
|
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
|
||||||
|
# =================================================================================================
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_plms(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
|
cbs = ctmp.shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
|
||||||
|
if isinstance(c, dict):
|
||||||
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
|
c_in = dict()
|
||||||
|
for k in c:
|
||||||
|
if isinstance(c[k], list):
|
||||||
|
c_in[k] = [
|
||||||
|
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||||
|
for i in range(len(c[k]))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||||
|
else:
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
||||||
|
|
||||||
|
# =================================================================================================
|
||||||
|
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
|
||||||
|
# Adapted from:
|
||||||
|
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
||||||
|
# =================================================================================================
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||||
|
if null_label is not None:
|
||||||
|
xc = null_label
|
||||||
|
if isinstance(xc, ListConfig):
|
||||||
|
xc = list(xc)
|
||||||
|
if isinstance(xc, dict) or isinstance(xc, list):
|
||||||
|
c = self.get_learned_conditioning(xc)
|
||||||
|
else:
|
||||||
|
if hasattr(xc, "to"):
|
||||||
|
xc = xc.to(self.device)
|
||||||
|
c = self.get_learned_conditioning(xc)
|
||||||
|
else:
|
||||||
|
# todo: get null label from cond_stage_model
|
||||||
|
raise NotImplementedError()
|
||||||
|
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
class LatentInpaintDiffusion(LatentDiffusion):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
concat_keys=("mask", "masked_image"),
|
||||||
|
masked_image_key="masked_image",
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.masked_image_key = masked_image_key
|
||||||
|
assert self.masked_image_key in concat_keys
|
||||||
|
self.concat_keys = concat_keys
|
||||||
|
|
||||||
|
|
||||||
|
def should_hijack_inpainting(checkpoint_info):
|
||||||
|
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
def do_inpainting_hijack():
|
||||||
|
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||||
|
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||||
|
|
||||||
|
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||||
|
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||||
|
|
||||||
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
|
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
@ -7,8 +7,9 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader, devices
|
from modules import shared, modelloader, devices, script_callbacks
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
|
|||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
|
||||||
from transformers import logging
|
from transformers import logging, CLIPModel
|
||||||
|
|
||||||
logging.set_verbosity_error()
|
logging.set_verbosity_error()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||||||
return pl_sd
|
return pl_sd
|
||||||
|
|
||||||
|
|
||||||
|
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(model, checkpoint_info):
|
def load_model_weights(model, checkpoint_info):
|
||||||
checkpoint_file = checkpoint_info.filename
|
checkpoint_file = checkpoint_info.filename
|
||||||
sd_model_hash = checkpoint_info.hash
|
sd_model_hash = checkpoint_info.hash
|
||||||
@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
|
|||||||
if os.path.exists(vae_file):
|
if os.path.exists(vae_file):
|
||||||
print(f"Loading VAE weights from: {vae_file}")
|
print(f"Loading VAE weights from: {vae_file}")
|
||||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
|
||||||
model.first_stage_model.load_state_dict(vae_dict)
|
model.first_stage_model.load_state_dict(vae_dict)
|
||||||
|
|
||||||
model.first_stage_model.to(devices.dtype_vae)
|
model.first_stage_model.to(devices.dtype_vae)
|
||||||
@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
|
|||||||
model.sd_checkpoint_info = checkpoint_info
|
model.sd_checkpoint_info = checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
def load_model():
|
def load_model(checkpoint_info=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
if checkpoint_info.config != shared.cmd_opts.config:
|
if checkpoint_info.config != shared.cmd_opts.config:
|
||||||
print(f"Loading config from: {checkpoint_info.config}")
|
print(f"Loading config from: {checkpoint_info.config}")
|
||||||
|
|
||||||
sd_config = OmegaConf.load(checkpoint_info.config)
|
sd_config = OmegaConf.load(checkpoint_info.config)
|
||||||
|
|
||||||
|
if should_hijack_inpainting(checkpoint_info):
|
||||||
|
# Hardcoded config for now...
|
||||||
|
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||||
|
sd_config.model.params.use_ema = False
|
||||||
|
sd_config.model.params.conditioning_key = "hybrid"
|
||||||
|
sd_config.model.params.unet_config.params.in_channels = 9
|
||||||
|
|
||||||
|
# Create a "fake" config with a different name so that we know to unload it when switching models.
|
||||||
|
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
|
||||||
|
|
||||||
|
do_inpainting_hijack()
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
load_model_weights(sd_model, checkpoint_info)
|
load_model_weights(sd_model, checkpoint_info)
|
||||||
|
|
||||||
@ -222,6 +238,9 @@ def load_model():
|
|||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
|
shared.sd_model = sd_model
|
||||||
|
|
||||||
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
print(f"Model loaded.")
|
print(f"Model loaded.")
|
||||||
return sd_model
|
return sd_model
|
||||||
@ -234,9 +253,9 @@ def reload_model_weights(sd_model, info=None):
|
|||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
shared.sd_model = load_model()
|
load_model(checkpoint_info)
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.config = None
|
self.config = None
|
||||||
self.last_latent = None
|
self.last_latent = None
|
||||||
|
|
||||||
|
self.conditioning_key = sd_model.model.conditioning_key
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
|
|||||||
if self.stop_at is not None and self.step > self.stop_at:
|
if self.stop_at is not None and self.step > self.stop_at:
|
||||||
raise InterruptedException
|
raise InterruptedException
|
||||||
|
|
||||||
|
# Have to unwrap the inpainting conditioning here to perform pre-processing
|
||||||
|
image_conditioning = None
|
||||||
|
if isinstance(cond, dict):
|
||||||
|
image_conditioning = cond["c_concat"][0]
|
||||||
|
cond = cond["c_crossattn"][0]
|
||||||
|
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
|
||||||
|
|
||||||
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
|
|||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
|
|
||||||
|
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||||
|
# Note that they need to be lists because it just concatenates them later.
|
||||||
|
if image_conditioning is not None:
|
||||||
|
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||||
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
self.initialize(p)
|
self.initialize(p)
|
||||||
@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler:
|
|||||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||||
|
|
||||||
self.init_latent = x
|
self.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
|
if image_conditioning is not None:
|
||||||
|
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||||
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
|
|
||||||
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
self.initialize(p)
|
self.initialize(p)
|
||||||
|
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
|
self.last_latent = x
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
|
if image_conditioning is not None:
|
||||||
|
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||||
|
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||||
|
|
||||||
# existing code fails with certain step counts, like 9
|
# existing code fails with certain step counts, like 9
|
||||||
try:
|
try:
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise InterruptedException
|
raise InterruptedException
|
||||||
|
|
||||||
@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
|
||||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1]:
|
if tensor.shape[1] == uncond.shape[1]:
|
||||||
cond_in = torch.cat([tensor, uncond])
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if shared.batch_cond_uncond:
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
|
||||||
else:
|
else:
|
||||||
x_out = torch.zeros_like(x_in)
|
x_out = torch.zeros_like(x_in)
|
||||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
a = batch_offset
|
a = batch_offset
|
||||||
b = a + batch_size
|
b = a + batch_size
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||||
else:
|
else:
|
||||||
x_out = torch.zeros_like(x_in)
|
x_out = torch.zeros_like(x_in)
|
||||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||||
a = batch_offset
|
a = batch_offset
|
||||||
b = min(a + batch_size, tensor.shape[0])
|
b = min(a + batch_size, tensor.shape[0])
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
|
||||||
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
|
||||||
|
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
denoised = torch.clone(denoised_uncond)
|
denoised = torch.clone(denoised_uncond)
|
||||||
@ -306,6 +334,8 @@ class KDiffusionSampler:
|
|||||||
self.config = None
|
self.config = None
|
||||||
self.last_latent = None
|
self.last_latent = None
|
||||||
|
|
||||||
|
self.conditioning_key = sd_model.model.conditioning_key
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
step = d['i']
|
step = d['i']
|
||||||
latent = d["denoised"]
|
latent = d["denoised"]
|
||||||
@ -361,7 +391,7 @@ class KDiffusionSampler:
|
|||||||
|
|
||||||
return extra_params_kwargs
|
return extra_params_kwargs
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
@ -388,12 +418,18 @@ class KDiffusionSampler:
|
|||||||
extra_params_kwargs['sigmas'] = sigma_sched
|
extra_params_kwargs['sigmas'] = sigma_sched
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
|
self.last_latent = x
|
||||||
|
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale
|
||||||
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
if p.sampler_noise_scheduler_override:
|
if p.sampler_noise_scheduler_override:
|
||||||
@ -414,7 +450,13 @@ class KDiffusionSampler:
|
|||||||
else:
|
else:
|
||||||
extra_params_kwargs['sigmas'] = sigmas
|
extra_params_kwargs['sigmas'] = sigmas
|
||||||
|
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
self.last_latent = x
|
||||||
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
|
||||||
|
'cond': conditioning,
|
||||||
|
'image_cond': image_conditioning,
|
||||||
|
'uncond': unconditional_conditioning,
|
||||||
|
'cond_scale': p.cfg_scale
|
||||||
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -78,6 +79,8 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
|
|||||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
|
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
|
||||||
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
|
||||||
|
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
|
||||||
|
parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
restricted_opts = [
|
restricted_opts = [
|
||||||
@ -249,7 +252,7 @@ options_templates.update(options_section(('system', "System"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
|
||||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||||
@ -315,6 +318,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
options_templates.update(options_section(('images-history', "Images Browser"), {
|
||||||
|
#"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
|
||||||
|
"images_history_preload": OptionInfo(False, "Preload images at startup"),
|
||||||
|
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
|
||||||
|
"images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
|
||||||
|
"images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
|
||||||
|
|
||||||
|
}))
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
data = None
|
data = None
|
||||||
@ -387,6 +398,8 @@ sd_upscalers = []
|
|||||||
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
|
|
||||||
|
clip_model = None
|
||||||
|
|
||||||
progress_print_out = sys.stdout
|
progress_print_out = sys.stdout
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
|
|||||||
|
|
||||||
self.dataset.append(entry)
|
self.dataset.append(entry)
|
||||||
|
|
||||||
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
assert len(self.dataset) > 0, "No images have been found in the dataset."
|
||||||
self.length = len(self.dataset) * repeats // batch_size
|
self.length = len(self.dataset) * repeats // batch_size
|
||||||
|
|
||||||
self.initial_indexes = np.arange(len(self.dataset))
|
self.initial_indexes = np.arange(len(self.dataset))
|
||||||
@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
|
|||||||
self.shuffle()
|
self.shuffle()
|
||||||
|
|
||||||
def shuffle(self):
|
def shuffle(self):
|
||||||
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
|
||||||
|
|
||||||
def create_text(self, filename_text):
|
def create_text(self, filename_text):
|
||||||
text = random.choice(self.lines)
|
text = random.choice(self.lines)
|
||||||
|
@ -5,6 +5,7 @@ import zlib
|
|||||||
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
|
||||||
from fonts.ttf import Roboto
|
from fonts.ttf import Roboto
|
||||||
import torch
|
import torch
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoder(json.JSONEncoder):
|
class EmbeddingEncoder(json.JSONEncoder):
|
||||||
@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||||||
from math import cos
|
from math import cos
|
||||||
|
|
||||||
image = srcimage.copy()
|
image = srcimage.copy()
|
||||||
|
fontsize = 32
|
||||||
if textfont is None:
|
if textfont is None:
|
||||||
try:
|
try:
|
||||||
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
|
||||||
@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
|
|||||||
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
|
||||||
|
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image)
|
||||||
fontsize = 32
|
|
||||||
font = ImageFont.truetype(textfont, fontsize)
|
font = ImageFont.truetype(textfont, fontsize)
|
||||||
padding = 10
|
padding = 10
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
import math
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import tqdm
|
import tqdm
|
||||||
@ -11,7 +12,7 @@ if cmd_opts.deepdanbooru:
|
|||||||
import modules.deepbooru as deepbooru
|
import modules.deepbooru as deepbooru
|
||||||
|
|
||||||
|
|
||||||
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||||
try:
|
try:
|
||||||
if process_caption:
|
if process_caption:
|
||||||
shared.interrogator.load()
|
shared.interrogator.load()
|
||||||
@ -21,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
|||||||
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
|
||||||
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
|
||||||
|
|
||||||
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
|
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
@ -33,11 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
|
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
|
||||||
width = process_width
|
width = process_width
|
||||||
height = process_height
|
height = process_height
|
||||||
src = os.path.abspath(process_src)
|
src = os.path.abspath(process_src)
|
||||||
dst = os.path.abspath(process_dst)
|
dst = os.path.abspath(process_dst)
|
||||||
|
split_threshold = max(0.0, min(1.0, split_threshold))
|
||||||
|
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
|
||||||
|
|
||||||
assert src != dst, 'same directory specified as source and destination'
|
assert src != dst, 'same directory specified as source and destination'
|
||||||
|
|
||||||
@ -48,7 +51,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||||||
shared.state.textinfo = "Preprocessing..."
|
shared.state.textinfo = "Preprocessing..."
|
||||||
shared.state.job_count = len(files)
|
shared.state.job_count = len(files)
|
||||||
|
|
||||||
def save_pic_with_caption(image, index):
|
def save_pic_with_caption(image, index, existing_caption=None):
|
||||||
caption = ""
|
caption = ""
|
||||||
|
|
||||||
if process_caption:
|
if process_caption:
|
||||||
@ -66,17 +69,49 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||||||
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
basename = f"{index:05}-{subindex[0]}-{filename_part}"
|
||||||
image.save(os.path.join(dst, f"{basename}.png"))
|
image.save(os.path.join(dst, f"{basename}.png"))
|
||||||
|
|
||||||
|
if preprocess_txt_action == 'prepend' and existing_caption:
|
||||||
|
caption = existing_caption + ' ' + caption
|
||||||
|
elif preprocess_txt_action == 'append' and existing_caption:
|
||||||
|
caption = caption + ' ' + existing_caption
|
||||||
|
elif preprocess_txt_action == 'copy' and existing_caption:
|
||||||
|
caption = existing_caption
|
||||||
|
|
||||||
|
caption = caption.strip()
|
||||||
|
|
||||||
if len(caption) > 0:
|
if len(caption) > 0:
|
||||||
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
|
||||||
file.write(caption)
|
file.write(caption)
|
||||||
|
|
||||||
subindex[0] += 1
|
subindex[0] += 1
|
||||||
|
|
||||||
def save_pic(image, index):
|
def save_pic(image, index, existing_caption=None):
|
||||||
save_pic_with_caption(image, index)
|
save_pic_with_caption(image, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
if process_flip:
|
if process_flip:
|
||||||
save_pic_with_caption(ImageOps.mirror(image), index)
|
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
|
||||||
|
|
||||||
|
def split_pic(image, inverse_xy):
|
||||||
|
if inverse_xy:
|
||||||
|
from_w, from_h = image.height, image.width
|
||||||
|
to_w, to_h = height, width
|
||||||
|
else:
|
||||||
|
from_w, from_h = image.width, image.height
|
||||||
|
to_w, to_h = width, height
|
||||||
|
h = from_h * to_w // from_w
|
||||||
|
if inverse_xy:
|
||||||
|
image = image.resize((h, to_w))
|
||||||
|
else:
|
||||||
|
image = image.resize((to_w, h))
|
||||||
|
|
||||||
|
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
|
||||||
|
y_step = (h - to_h) / (split_count - 1)
|
||||||
|
for i in range(split_count):
|
||||||
|
y = int(y_step * i)
|
||||||
|
if inverse_xy:
|
||||||
|
splitted = image.crop((y, 0, y + to_h, to_w))
|
||||||
|
else:
|
||||||
|
splitted = image.crop((0, y, to_w, y + to_h))
|
||||||
|
yield splitted
|
||||||
|
|
||||||
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
||||||
subindex = [0]
|
subindex = [0]
|
||||||
@ -86,31 +121,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
existing_caption = None
|
||||||
|
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
|
||||||
|
if os.path.exists(existing_caption_filename):
|
||||||
|
with open(existing_caption_filename, 'r', encoding="utf8") as file:
|
||||||
|
existing_caption = file.read()
|
||||||
|
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
ratio = img.height / img.width
|
if img.height > img.width:
|
||||||
is_tall = ratio > 1.35
|
ratio = (img.width * height) / (img.height * width)
|
||||||
is_wide = ratio < 1 / 1.35
|
inverse_xy = False
|
||||||
|
else:
|
||||||
|
ratio = (img.height * width) / (img.width * height)
|
||||||
|
inverse_xy = True
|
||||||
|
|
||||||
if process_split and is_tall:
|
if process_split and ratio < 1.0 and ratio <= split_threshold:
|
||||||
img = img.resize((width, height * img.height // img.width))
|
for splitted in split_pic(img, inverse_xy):
|
||||||
|
save_pic(splitted, index, existing_caption=existing_caption)
|
||||||
top = img.crop((0, 0, width, height))
|
|
||||||
save_pic(top, index)
|
|
||||||
|
|
||||||
bot = img.crop((0, img.height - height, width, img.height))
|
|
||||||
save_pic(bot, index)
|
|
||||||
elif process_split and is_wide:
|
|
||||||
img = img.resize((width * img.width // img.height, height))
|
|
||||||
|
|
||||||
left = img.crop((0, 0, width, height))
|
|
||||||
save_pic(left, index)
|
|
||||||
|
|
||||||
right = img.crop((img.width - width, 0, img.width, height))
|
|
||||||
save_pic(right, index)
|
|
||||||
else:
|
else:
|
||||||
img = images.resize_image(1, img, width, height)
|
img = images.resize_image(1, img, width, height)
|
||||||
save_pic(img, index)
|
save_pic(img, index, existing_caption=existing_caption)
|
||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
@ -153,7 +153,7 @@ class EmbeddingDatabase:
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||||
|
|
||||||
@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
|||||||
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
||||||
|
|
||||||
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
||||||
assert not os.path.exists(fn), f"file {fn} already exists"
|
if not overwrite_old:
|
||||||
|
assert not os.path.exists(fn), f"file {fn} already exists"
|
||||||
|
|
||||||
embedding = Embedding(vec, name)
|
embedding = Embedding(vec, name)
|
||||||
embedding.step = 0
|
embedding.step = 0
|
||||||
@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
epoch_num = embedding.step // len(ds)
|
epoch_num = embedding.step // len(ds)
|
||||||
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
|
|||||||
from modules import sd_hijack, shared
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
def create_embedding(name, initialization_text, nvpt):
|
def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
||||||
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
|
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import modules.scripts
|
import modules.scripts
|
||||||
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
|
||||||
|
StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
@ -35,6 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||||||
firstphase_height=firstphase_height if enable_hr else None,
|
firstphase_height=firstphase_height if enable_hr else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p.scripts = modules.scripts.scripts_txt2img
|
||||||
|
p.script_args = args
|
||||||
|
|
||||||
if cmd_opts.enable_console_prompts:
|
if cmd_opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
@ -53,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||||||
processed.images = []
|
processed.images = []
|
||||||
|
|
||||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||||
|
|
||||||
|
127
modules/ui.py
127
modules/ui.py
@ -22,9 +22,14 @@ import piexif
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
from modules import localization, sd_hijack, sd_models
|
import gradio as gr
|
||||||
|
import gradio.utils
|
||||||
|
import gradio.routes
|
||||||
|
|
||||||
|
from modules import sd_hijack, sd_models, localization, script_callbacks
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts, opts, restricted_opts
|
|
||||||
|
from modules.shared import opts, cmd_opts, restricted_opts
|
||||||
|
|
||||||
if cmd_opts.deepdanbooru:
|
if cmd_opts.deepdanbooru:
|
||||||
from modules.deepbooru import get_deepbooru_tags
|
from modules.deepbooru import get_deepbooru_tags
|
||||||
@ -43,6 +48,11 @@ from modules import prompt_parser
|
|||||||
from modules.images import save_image
|
from modules.images import save_image
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
|
import modules.textual_inversion.ui
|
||||||
|
import modules.hypernetworks.ui
|
||||||
|
|
||||||
|
import modules.images_history as img_his
|
||||||
|
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
@ -593,27 +603,29 @@ def apply_setting(key, value):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
||||||
|
def refresh():
|
||||||
|
refresh_method()
|
||||||
|
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
||||||
|
|
||||||
|
for k, v in args.items():
|
||||||
|
setattr(refresh_component, k, v)
|
||||||
|
|
||||||
|
return gr.update(**(args or {}))
|
||||||
|
|
||||||
|
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
|
||||||
|
refresh_button.click(
|
||||||
|
fn=refresh,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[refresh_component]
|
||||||
|
)
|
||||||
|
return refresh_button
|
||||||
|
|
||||||
|
|
||||||
def create_ui(wrap_gradio_gpu_call):
|
def create_ui(wrap_gradio_gpu_call):
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
|
|
||||||
def refresh():
|
|
||||||
refresh_method()
|
|
||||||
args = refreshed_args() if callable(refreshed_args) else refreshed_args
|
|
||||||
|
|
||||||
for k, v in args.items():
|
|
||||||
setattr(refresh_component, k, v)
|
|
||||||
|
|
||||||
return gr.update(**(args or {}))
|
|
||||||
|
|
||||||
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
|
|
||||||
refresh_button.click(
|
|
||||||
fn = refresh,
|
|
||||||
inputs = [],
|
|
||||||
outputs = [refresh_component]
|
|
||||||
)
|
|
||||||
return refresh_button
|
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
@ -711,6 +723,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
firstphase_width,
|
firstphase_width,
|
||||||
firstphase_height,
|
firstphase_height,
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
|
|
||||||
outputs=[
|
outputs=[
|
||||||
txt2img_gallery,
|
txt2img_gallery,
|
||||||
generation_info,
|
generation_info,
|
||||||
@ -787,6 +800,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||||
(firstphase_width, "First pass size-1"),
|
(firstphase_width, "First pass size-1"),
|
||||||
(firstphase_height, "First pass size-2"),
|
(firstphase_height, "First pass size-2"),
|
||||||
|
*modules.scripts.scripts_txt2img.infotext_fields
|
||||||
]
|
]
|
||||||
|
|
||||||
txt2img_preview_params = [
|
txt2img_preview_params = [
|
||||||
@ -854,8 +868,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
|
||||||
@ -1052,6 +1066,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
(seed_resize_from_w, "Seed resize from-1"),
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
(seed_resize_from_h, "Seed resize from-2"),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
|
*modules.scripts.scripts_img2img.infotext_fields
|
||||||
]
|
]
|
||||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
@ -1174,12 +1189,12 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
)
|
)
|
||||||
#images history
|
#images history
|
||||||
images_history_switch_dict = {
|
images_history_switch_dict = {
|
||||||
"fn":modules.generation_parameters_copypaste.connect_paste,
|
"fn": modules.generation_parameters_copypaste.connect_paste,
|
||||||
"t2i":txt2img_paste_fields,
|
"t2i": txt2img_paste_fields,
|
||||||
"i2i":img2img_paste_fields
|
"i2i": img2img_paste_fields
|
||||||
}
|
}
|
||||||
|
|
||||||
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
|
images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
|
||||||
|
|
||||||
with gr.Blocks() as modelmerger_interface:
|
with gr.Blocks() as modelmerger_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
@ -1212,6 +1227,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_embedding_name = gr.Textbox(label="Name")
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1227,6 +1243,8 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
|
||||||
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
|
||||||
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
|
||||||
|
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
|
||||||
|
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
@ -1240,13 +1258,18 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
process_dst = gr.Textbox(label='Destination directory')
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
|
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
process_flip = gr.Checkbox(label='Create flipped copies')
|
process_flip = gr.Checkbox(label='Create flipped copies')
|
||||||
process_split = gr.Checkbox(label='Split oversized images into two')
|
process_split = gr.Checkbox(label='Split oversized images')
|
||||||
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
process_caption = gr.Checkbox(label='Use BLIP for caption')
|
||||||
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
|
||||||
|
|
||||||
|
with gr.Row(visible=False) as process_split_extra_row:
|
||||||
|
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
|
||||||
|
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=3):
|
with gr.Column(scale=3):
|
||||||
gr.HTML(value="")
|
gr.HTML(value="")
|
||||||
@ -1254,15 +1277,24 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
|
process_split.change(
|
||||||
|
fn=lambda show: gr_show(show),
|
||||||
|
inputs=[process_split],
|
||||||
|
outputs=[process_split_extra_row],
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
|
||||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
|
||||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
with gr.Row():
|
||||||
|
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
|
||||||
|
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
|
||||||
|
|
||||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
@ -1296,6 +1328,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
new_embedding_name,
|
new_embedding_name,
|
||||||
initialization_text,
|
initialization_text,
|
||||||
nvpt,
|
nvpt,
|
||||||
|
overwrite_old_embedding,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
@ -1309,6 +1342,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
inputs=[
|
inputs=[
|
||||||
new_hypernetwork_name,
|
new_hypernetwork_name,
|
||||||
new_hypernetwork_sizes,
|
new_hypernetwork_sizes,
|
||||||
|
overwrite_old_hypernetwork,
|
||||||
new_hypernetwork_layer_structure,
|
new_hypernetwork_layer_structure,
|
||||||
new_hypernetwork_activation_func,
|
new_hypernetwork_activation_func,
|
||||||
new_hypernetwork_add_layer_norm,
|
new_hypernetwork_add_layer_norm,
|
||||||
@ -1329,10 +1363,13 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
process_dst,
|
process_dst,
|
||||||
process_width,
|
process_width,
|
||||||
process_height,
|
process_height,
|
||||||
|
preprocess_txt_action,
|
||||||
process_flip,
|
process_flip,
|
||||||
process_split,
|
process_split,
|
||||||
process_caption,
|
process_caption,
|
||||||
process_caption_deepbooru
|
process_caption_deepbooru,
|
||||||
|
process_split_threshold,
|
||||||
|
process_overlap_ratio,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
ti_output,
|
ti_output,
|
||||||
@ -1345,7 +1382,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
inputs=[
|
inputs=[
|
||||||
train_embedding_name,
|
train_embedding_name,
|
||||||
learn_rate,
|
embedding_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
@ -1370,7 +1407,7 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
_js="start_training_textual_inversion",
|
_js="start_training_textual_inversion",
|
||||||
inputs=[
|
inputs=[
|
||||||
train_hypernetwork_name,
|
train_hypernetwork_name,
|
||||||
learn_rate,
|
hypernetwork_learn_rate,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_directory,
|
dataset_directory,
|
||||||
log_directory,
|
log_directory,
|
||||||
@ -1491,10 +1528,10 @@ Requested path was: {f}
|
|||||||
if not opts.same_type(value, opts.data_labels[key].default):
|
if not opts.same_type(value, opts.data_labels[key].default):
|
||||||
return gr.update(visible=True), opts.dumpjson()
|
return gr.update(visible=True), opts.dumpjson()
|
||||||
|
|
||||||
|
oldval = opts.data.get(key, None)
|
||||||
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
|
||||||
return gr.update(value=oldval), opts.dumpjson()
|
return gr.update(value=oldval), opts.dumpjson()
|
||||||
|
|
||||||
oldval = opts.data.get(key, None)
|
|
||||||
opts.data[key] = value
|
opts.data[key] = value
|
||||||
|
|
||||||
if oldval != value:
|
if oldval != value:
|
||||||
@ -1600,19 +1637,24 @@ Requested path was: {f}
|
|||||||
(img2img_interface, "img2img", "img2img"),
|
(img2img_interface, "img2img", "img2img"),
|
||||||
(extras_interface, "Extras", "extras"),
|
(extras_interface, "Extras", "extras"),
|
||||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||||
(images_history, "History", "images_history"),
|
(images_history, "Image Browser", "images_history"),
|
||||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||||
(train_interface, "Train", "ti"),
|
(train_interface, "Train", "ti"),
|
||||||
(settings_interface, "Settings", "settings"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
|
interfaces += script_callbacks.ui_tabs_callback()
|
||||||
css = file.read()
|
|
||||||
|
interfaces += [(settings_interface, "Settings", "settings")]
|
||||||
|
|
||||||
|
css = ""
|
||||||
|
|
||||||
|
for cssfile in modules.scripts.list_files_with_name("style.css"):
|
||||||
|
with open(cssfile, "r", encoding="utf8") as file:
|
||||||
|
css += file.read() + "\n"
|
||||||
|
|
||||||
if os.path.exists(os.path.join(script_path, "user.css")):
|
if os.path.exists(os.path.join(script_path, "user.css")):
|
||||||
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
|
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
|
||||||
usercss = file.read()
|
css += file.read() + "\n"
|
||||||
css += usercss
|
|
||||||
|
|
||||||
if not cmd_opts.no_progressbar_hiding:
|
if not cmd_opts.no_progressbar_hiding:
|
||||||
css += css_hide_progressbar
|
css += css_hide_progressbar
|
||||||
@ -1835,9 +1877,9 @@ def load_javascript(raw_response):
|
|||||||
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
|
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
|
||||||
javascript = f'<script>{jsfile.read()}</script>'
|
javascript = f'<script>{jsfile.read()}</script>'
|
||||||
|
|
||||||
jsdir = os.path.join(script_path, "javascript")
|
scripts_list = modules.scripts.list_scripts("javascript", ".js")
|
||||||
for filename in sorted(os.listdir(jsdir)):
|
for basedir, filename, path in scripts_list:
|
||||||
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
|
with open(path, "r", encoding="utf8") as jsfile:
|
||||||
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
|
||||||
|
|
||||||
if cmd_opts.theme is not None:
|
if cmd_opts.theme is not None:
|
||||||
@ -1855,6 +1897,5 @@ def load_javascript(raw_response):
|
|||||||
gradio.routes.templates.TemplateResponse = template_response
|
gradio.routes.templates.TemplateResponse = template_response
|
||||||
|
|
||||||
|
|
||||||
reload_javascript = partial(load_javascript,
|
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
|
||||||
gradio.routes.templates.TemplateResponse)
|
|
||||||
reload_javascript()
|
reload_javascript()
|
||||||
|
@ -172,54 +172,54 @@ class Script(scripts.Script):
|
|||||||
if down > 0:
|
if down > 0:
|
||||||
down = target_h - init_img.height - up
|
down = target_h - init_img.height - up
|
||||||
|
|
||||||
init_image = p.init_images[0]
|
def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
|
||||||
|
|
||||||
state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
|
|
||||||
|
|
||||||
def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
|
|
||||||
is_horiz = is_left or is_right
|
is_horiz = is_left or is_right
|
||||||
is_vert = is_top or is_bottom
|
is_vert = is_top or is_bottom
|
||||||
pixels_horiz = expand_pixels if is_horiz else 0
|
pixels_horiz = expand_pixels if is_horiz else 0
|
||||||
pixels_vert = expand_pixels if is_vert else 0
|
pixels_vert = expand_pixels if is_vert else 0
|
||||||
|
|
||||||
res_w = init.width + pixels_horiz
|
images_to_process = []
|
||||||
res_h = init.height + pixels_vert
|
output_images = []
|
||||||
process_res_w = math.ceil(res_w / 64) * 64
|
for n in range(count):
|
||||||
process_res_h = math.ceil(res_h / 64) * 64
|
res_w = init[n].width + pixels_horiz
|
||||||
|
res_h = init[n].height + pixels_vert
|
||||||
|
process_res_w = math.ceil(res_w / 64) * 64
|
||||||
|
process_res_h = math.ceil(res_h / 64) * 64
|
||||||
|
|
||||||
img = Image.new("RGB", (process_res_w, process_res_h))
|
img = Image.new("RGB", (process_res_w, process_res_h))
|
||||||
img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
|
img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
|
||||||
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
|
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
|
||||||
draw = ImageDraw.Draw(mask)
|
draw = ImageDraw.Draw(mask)
|
||||||
draw.rectangle((
|
draw.rectangle((
|
||||||
expand_pixels + mask_blur if is_left else 0,
|
expand_pixels + mask_blur if is_left else 0,
|
||||||
expand_pixels + mask_blur if is_top else 0,
|
expand_pixels + mask_blur if is_top else 0,
|
||||||
mask.width - expand_pixels - mask_blur if is_right else res_w,
|
mask.width - expand_pixels - mask_blur if is_right else res_w,
|
||||||
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
|
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
|
||||||
), fill="black")
|
), fill="black")
|
||||||
|
|
||||||
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
np_image = (np.asarray(img) / 255.0).astype(np.float64)
|
||||||
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
|
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
|
||||||
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
|
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
|
||||||
out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
|
||||||
|
|
||||||
target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
|
target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
|
||||||
target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
|
target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
|
||||||
|
p.width = target_width if is_horiz else img.width
|
||||||
|
p.height = target_height if is_vert else img.height
|
||||||
|
|
||||||
crop_region = (
|
crop_region = (
|
||||||
0 if is_left else out.width - target_width,
|
0 if is_left else output_images[n].width - target_width,
|
||||||
0 if is_top else out.height - target_height,
|
0 if is_top else output_images[n].height - target_height,
|
||||||
target_width if is_left else out.width,
|
target_width if is_left else output_images[n].width,
|
||||||
target_height if is_top else out.height,
|
target_height if is_top else output_images[n].height,
|
||||||
)
|
)
|
||||||
|
mask = mask.crop(crop_region)
|
||||||
|
p.image_mask = mask
|
||||||
|
|
||||||
image_to_process = out.crop(crop_region)
|
image_to_process = output_images[n].crop(crop_region)
|
||||||
mask = mask.crop(crop_region)
|
images_to_process.append(image_to_process)
|
||||||
|
|
||||||
p.width = target_width if is_horiz else img.width
|
p.init_images = images_to_process
|
||||||
p.height = target_height if is_vert else img.height
|
|
||||||
p.init_images = [image_to_process]
|
|
||||||
p.image_mask = mask
|
|
||||||
|
|
||||||
latent_mask = Image.new("RGB", (p.width, p.height), "white")
|
latent_mask = Image.new("RGB", (p.width, p.height), "white")
|
||||||
draw = ImageDraw.Draw(latent_mask)
|
draw = ImageDraw.Draw(latent_mask)
|
||||||
@ -232,31 +232,52 @@ class Script(scripts.Script):
|
|||||||
p.latent_mask = latent_mask
|
p.latent_mask = latent_mask
|
||||||
|
|
||||||
proc = process_images(p)
|
proc = process_images(p)
|
||||||
proc_img = proc.images[0]
|
|
||||||
|
|
||||||
if initial_seed_and_info[0] is None:
|
if initial_seed_and_info[0] is None:
|
||||||
initial_seed_and_info[0] = proc.seed
|
initial_seed_and_info[0] = proc.seed
|
||||||
initial_seed_and_info[1] = proc.info
|
initial_seed_and_info[1] = proc.info
|
||||||
|
|
||||||
out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
|
for n in range(count):
|
||||||
out = out.crop((0, 0, res_w, res_h))
|
output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
|
||||||
return out
|
output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
|
||||||
|
|
||||||
img = init_image
|
return output_images
|
||||||
|
|
||||||
if left > 0:
|
batch_count = p.n_iter
|
||||||
img = expand(img, left, is_left=True)
|
batch_size = p.batch_size
|
||||||
if right > 0:
|
p.n_iter = 1
|
||||||
img = expand(img, right, is_right=True)
|
state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
|
||||||
if up > 0:
|
all_processed_images = []
|
||||||
img = expand(img, up, is_top=True)
|
|
||||||
if down > 0:
|
|
||||||
img = expand(img, down, is_bottom=True)
|
|
||||||
|
|
||||||
res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
|
for i in range(batch_count):
|
||||||
|
imgs = [init_img] * batch_size
|
||||||
|
state.job = f"Batch {i + 1} out of {batch_count}"
|
||||||
|
|
||||||
|
if left > 0:
|
||||||
|
imgs = expand(imgs, batch_size, left, is_left=True)
|
||||||
|
if right > 0:
|
||||||
|
imgs = expand(imgs, batch_size, right, is_right=True)
|
||||||
|
if up > 0:
|
||||||
|
imgs = expand(imgs, batch_size, up, is_top=True)
|
||||||
|
if down > 0:
|
||||||
|
imgs = expand(imgs, batch_size, down, is_bottom=True)
|
||||||
|
|
||||||
|
all_processed_images += imgs
|
||||||
|
|
||||||
|
all_images = all_processed_images
|
||||||
|
|
||||||
|
combined_grid_image = images.image_grid(all_processed_images)
|
||||||
|
unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
|
||||||
|
if opts.return_grid and not unwanted_grid_because_of_img_count:
|
||||||
|
all_images = [combined_grid_image] + all_processed_images
|
||||||
|
|
||||||
|
res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
|
||||||
|
|
||||||
if opts.samples_save:
|
if opts.samples_save:
|
||||||
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
|
for img in all_processed_images:
|
||||||
|
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
|
||||||
|
|
||||||
|
if opts.grid_save and not unwanted_grid_because_of_img_count:
|
||||||
|
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
|
|||||||
if info is None:
|
if info is None:
|
||||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||||
|
p.sd_model = shared.sd_model
|
||||||
|
|
||||||
|
|
||||||
def confirm_checkpoints(p, xs):
|
def confirm_checkpoints(p, xs):
|
||||||
|
12
webui.py
12
webui.py
@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
def initialize():
|
||||||
modelloader.cleanup_models()
|
modelloader.cleanup_models()
|
||||||
modules.sd_models.setup_model()
|
modules.sd_models.setup_model()
|
||||||
@ -79,9 +80,9 @@ def initialize():
|
|||||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
modelloader.load_upscalers()
|
modelloader.load_upscalers()
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts()
|
||||||
|
|
||||||
shared.sd_model = modules.sd_models.load_model()
|
modules.sd_models.load_model()
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
@ -118,7 +119,8 @@ def api_only():
|
|||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
def webui(launch_api=False):
|
def webui():
|
||||||
|
launch_api = cmd_opts.api
|
||||||
initialize()
|
initialize()
|
||||||
|
|
||||||
while 1:
|
while 1:
|
||||||
@ -144,7 +146,7 @@ def webui(launch_api=False):
|
|||||||
sd_samplers.set_samplers()
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
print('Reloading Custom Scripts')
|
print('Reloading Custom Scripts')
|
||||||
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.reload_scripts()
|
||||||
print('Reloading modules: modules.ui')
|
print('Reloading modules: modules.ui')
|
||||||
importlib.reload(modules.ui)
|
importlib.reload(modules.ui)
|
||||||
print('Refreshing Model List')
|
print('Refreshing Model List')
|
||||||
@ -158,4 +160,4 @@ if __name__ == "__main__":
|
|||||||
if cmd_opts.nowebui:
|
if cmd_opts.nowebui:
|
||||||
api_only()
|
api_only()
|
||||||
else:
|
else:
|
||||||
webui(cmd_opts.api)
|
webui()
|
||||||
|
Loading…
Reference in New Issue
Block a user