mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-21 13:50:12 +08:00
Merge branch 'master' into test_resolve_conflicts
This commit is contained in:
commit
7b7561f6e4
36
.github/workflows/on_pull_request.yaml
vendored
Normal file
36
.github/workflows/on_pull_request.yaml
vendored
Normal file
@ -0,0 +1,36 @@
|
||||
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
|
||||
name: Run Linting/Formatting on Pull Requests
|
||||
|
||||
on:
|
||||
- push
|
||||
- pull_request
|
||||
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
|
||||
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
|
||||
# pull_request:
|
||||
# branches:
|
||||
# - master
|
||||
# branches-ignore:
|
||||
# - development
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: 3.10.6
|
||||
- name: Install PyLint
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint
|
||||
# This lets PyLint check to see if it can resolve imports
|
||||
- name: Install dependencies
|
||||
run : |
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
||||
python launch.py
|
||||
- name: Analysing the code with pylint
|
||||
run: |
|
||||
pylint $(git ls-files '*.py')
|
3
.pylintrc
Normal file
3
.pylintrc
Normal file
@ -0,0 +1,3 @@
|
||||
# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
|
||||
[MESSAGES CONTROL]
|
||||
disable=C,R,W,E,I
|
5
javascript/dragdrop.js
vendored
5
javascript/dragdrop.js
vendored
@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
|
||||
window.document.addEventListener('dragover', e => {
|
||||
const target = e.composedPath()[0];
|
||||
const imgWrap = target.closest('[data-testid="image"]');
|
||||
if ( !imgWrap ) {
|
||||
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
|
||||
return;
|
||||
}
|
||||
e.stopPropagation();
|
||||
@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => {
|
||||
|
||||
window.document.addEventListener('drop', e => {
|
||||
const target = e.composedPath()[0];
|
||||
if (target.placeholder.indexOf("Prompt") == -1) {
|
||||
return;
|
||||
}
|
||||
const imgWrap = target.closest('[data-testid="image"]');
|
||||
if ( !imgWrap ) {
|
||||
return;
|
||||
|
@ -2,6 +2,8 @@ addEventListener('keydown', (event) => {
|
||||
let target = event.originalTarget || event.composedPath()[0];
|
||||
if (!target.hasAttribute("placeholder")) return;
|
||||
if (!target.placeholder.toLowerCase().includes("prompt")) return;
|
||||
if (! (event.metaKey || event.ctrlKey)) return;
|
||||
|
||||
|
||||
let plus = "ArrowUp"
|
||||
let minus = "ArrowDown"
|
||||
|
@ -16,6 +16,8 @@ titles = {
|
||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
|
||||
"\u{1f4c2}": "Open images output directory",
|
||||
"\u{1f4be}": "Save style",
|
||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||
|
||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||
@ -87,8 +89,8 @@ titles = {
|
||||
|
||||
"Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
|
||||
|
||||
"Weighted Sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * (1 - M)",
|
||||
"Weighted sum": "Result = A * (1 - M) + B * M",
|
||||
"Add difference": "Result = A + (B - C) * M",
|
||||
}
|
||||
|
||||
|
||||
|
19
javascript/imageParams.js
Normal file
19
javascript/imageParams.js
Normal file
@ -0,0 +1,19 @@
|
||||
window.onload = (function(){
|
||||
window.addEventListener('drop', e => {
|
||||
const target = e.composedPath()[0];
|
||||
const idx = selected_gallery_index();
|
||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||
|
||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||
|
||||
e.stopPropagation();
|
||||
e.preventDefault();
|
||||
const imgParent = gradioApp().getElementById(prompt_target);
|
||||
const files = e.dataTransfer.files;
|
||||
const fileInput = imgParent.querySelector('input[type="file"]');
|
||||
if ( fileInput ) {
|
||||
fileInput.files = files;
|
||||
fileInput.dispatchEvent(new Event('change'));
|
||||
}
|
||||
});
|
||||
});
|
206
javascript/images_history.js
Normal file
206
javascript/images_history.js
Normal file
@ -0,0 +1,206 @@
|
||||
var images_history_click_image = function(){
|
||||
if (!this.classList.contains("transform")){
|
||||
var gallery = images_history_get_parent_by_class(this, "images_history_cantainor");
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var i = 0;
|
||||
var hidden_list = [];
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
hidden_list.push(i);
|
||||
}
|
||||
i += 1;
|
||||
})
|
||||
if (hidden_list.length > 0){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
}
|
||||
images_history_set_image_info(this);
|
||||
}
|
||||
|
||||
var images_history_click_tab = function(){
|
||||
var tabs_box = gradioApp().getElementById("images_history_tab");
|
||||
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
|
||||
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
|
||||
tabs_box.classList.add(this.getAttribute("tabname"))
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_disabled_del(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.setAttribute('disabled','disabled');
|
||||
});
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_class(item, class_name){
|
||||
var parent = item.parentElement;
|
||||
while(!parent.classList.contains(class_name)){
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_get_parent_by_tagname(item, tagname){
|
||||
var parent = item.parentElement;
|
||||
tagname = tagname.toUpperCase()
|
||||
while(parent.tagName != tagname){
|
||||
console.log(parent.tagName, tagname)
|
||||
parent = parent.parentElement;
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
function images_history_hide_buttons(hidden_list, gallery){
|
||||
var buttons = gallery.querySelectorAll(".gallery-item");
|
||||
var num = 0;
|
||||
buttons.forEach(function(e){
|
||||
if (e.style.display == "none"){
|
||||
num += 1;
|
||||
}
|
||||
});
|
||||
if (num == hidden_list.length){
|
||||
setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
|
||||
}
|
||||
for( i in hidden_list){
|
||||
buttons[hidden_list[i]].style.display = "none";
|
||||
}
|
||||
}
|
||||
|
||||
function images_history_set_image_info(button){
|
||||
var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
|
||||
var index = -1;
|
||||
var i = 0;
|
||||
buttons.forEach(function(e){
|
||||
if(e == button){
|
||||
index = i;
|
||||
}
|
||||
if(e.style.display != "none"){
|
||||
i += 1;
|
||||
}
|
||||
});
|
||||
var gallery = images_history_get_parent_by_class(button, "images_history_cantainor");
|
||||
var set_btn = gallery.querySelector(".images_history_set_index");
|
||||
var curr_idx = set_btn.getAttribute("img_index", index);
|
||||
if (curr_idx != index) {
|
||||
set_btn.setAttribute("img_index", index);
|
||||
images_history_disabled_del();
|
||||
}
|
||||
set_btn.click();
|
||||
|
||||
}
|
||||
|
||||
function images_history_get_current_img(tabname, image_path, files){
|
||||
return [
|
||||
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
|
||||
image_path,
|
||||
files
|
||||
];
|
||||
}
|
||||
|
||||
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
|
||||
image_index = parseInt(image_index);
|
||||
var tab = gradioApp().getElementById(tabname + '_images_history');
|
||||
var set_btn = tab.querySelector(".images_history_set_index");
|
||||
var buttons = [];
|
||||
tab.querySelectorAll(".gallery-item").forEach(function(e){
|
||||
if (e.style.display != 'none'){
|
||||
buttons.push(e);
|
||||
}
|
||||
});
|
||||
var img_num = buttons.length / 2;
|
||||
if (img_num <= del_num){
|
||||
setTimeout(function(tabname){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, 30, tabname);
|
||||
} else {
|
||||
var next_img
|
||||
for (var i = 0; i < del_num; i++){
|
||||
if (image_index + i < image_index + img_num){
|
||||
buttons[image_index + i].style.display = 'none';
|
||||
buttons[image_index + img_num + 1].style.display = 'none';
|
||||
next_img = image_index + i + 1
|
||||
}
|
||||
}
|
||||
var bnt;
|
||||
if (next_img >= img_num){
|
||||
btn = buttons[image_index - del_num];
|
||||
} else {
|
||||
btn = buttons[next_img];
|
||||
}
|
||||
setTimeout(function(btn){btn.click()}, 30, btn);
|
||||
}
|
||||
images_history_disabled_del();
|
||||
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
|
||||
}
|
||||
|
||||
function images_history_turnpage(img_path, page_index, image_index, tabname){
|
||||
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
|
||||
buttons.forEach(function(elem) {
|
||||
elem.style.display = 'block';
|
||||
})
|
||||
return [img_path, page_index, image_index, tabname];
|
||||
}
|
||||
|
||||
function images_history_enable_del_buttons(){
|
||||
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
|
||||
btn.removeAttribute('disabled');
|
||||
})
|
||||
}
|
||||
|
||||
function images_history_init(){
|
||||
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
|
||||
if (load_txt2img_button){
|
||||
for (var i in images_history_tab_list ){
|
||||
tab = images_history_tab_list[i];
|
||||
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
|
||||
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
|
||||
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
|
||||
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
|
||||
|
||||
}
|
||||
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
|
||||
tabs_box.setAttribute("id", "images_history_tab");
|
||||
var tab_btns = tabs_box.querySelectorAll("button");
|
||||
for (var i in images_history_tab_list){
|
||||
var tabname = images_history_tab_list[i]
|
||||
tab_btns[i].setAttribute("tabname", tabname);
|
||||
|
||||
// this refreshes history upon tab switch
|
||||
// until the history is known to work well, which is not the case now, we do not do this at startup
|
||||
//tab_btns[i].addEventListener('click', images_history_click_tab);
|
||||
}
|
||||
tabs_box.classList.add(images_history_tab_list[0]);
|
||||
|
||||
// same as above, at page load
|
||||
//load_txt2img_button.click();
|
||||
} else {
|
||||
setTimeout(images_history_init, 500);
|
||||
}
|
||||
}
|
||||
|
||||
var images_history_tab_list = ["txt2img", "img2img", "extras"];
|
||||
setTimeout(images_history_init, 500);
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var mutationObserver = new MutationObserver(function(m){
|
||||
for (var i in images_history_tab_list ){
|
||||
let tabname = images_history_tab_list[i]
|
||||
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
|
||||
buttons.forEach(function(bnt){
|
||||
bnt.addEventListener('click', images_history_click_image, true);
|
||||
});
|
||||
|
||||
// same as load_txt2img_button.click() above
|
||||
/*
|
||||
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
|
||||
if (cls_btn){
|
||||
cls_btn.addEventListener('click', function(){
|
||||
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
|
||||
}, false);
|
||||
}*/
|
||||
|
||||
}
|
||||
});
|
||||
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
|
||||
|
||||
});
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
// code related to showing and updating progressbar shown as the image is being made
|
||||
global_progressbars = {}
|
||||
galleries = {}
|
||||
galleryObservers = {}
|
||||
|
||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
|
||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||
@ -31,21 +33,54 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
|
||||
preview.style.width = gallery.clientWidth + "px"
|
||||
preview.style.height = gallery.clientHeight + "px"
|
||||
|
||||
//only watch gallery if there is a generation process going on
|
||||
check_gallery(id_gallery);
|
||||
|
||||
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
|
||||
if(!progressDiv){
|
||||
if (skip) {
|
||||
skip.style.display = "none"
|
||||
}
|
||||
interrupt.style.display = "none"
|
||||
|
||||
//disconnect observer once generation finished, so user can close selected image if they want
|
||||
if (galleryObservers[id_gallery]) {
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
galleries[id_gallery] = null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
|
||||
});
|
||||
mutationObserver.observe( progressbar, { childList:true, subtree:true })
|
||||
}
|
||||
}
|
||||
|
||||
function check_gallery(id_gallery){
|
||||
let gallery = gradioApp().getElementById(id_gallery)
|
||||
// if gallery has no change, no need to setting up observer again.
|
||||
if (gallery && galleries[id_gallery] !== gallery){
|
||||
galleries[id_gallery] = gallery;
|
||||
if(galleryObservers[id_gallery]){
|
||||
galleryObservers[id_gallery].disconnect();
|
||||
}
|
||||
let prevSelectedIndex = selected_gallery_index();
|
||||
galleryObservers[id_gallery] = new MutationObserver(function (){
|
||||
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
|
||||
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
|
||||
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
|
||||
//automatically re-open previously selected index (if exists)
|
||||
galleryButtons[prevSelectedIndex].click();
|
||||
showGalleryImage();
|
||||
}
|
||||
})
|
||||
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
|
||||
}
|
||||
}
|
||||
|
||||
onUiUpdate(function(){
|
||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||
|
@ -141,7 +141,7 @@ function submit_img2img(){
|
||||
|
||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||
name_ = prompt('Style name:')
|
||||
return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text]
|
||||
return [name_, prompt_text, negative_prompt_text]
|
||||
}
|
||||
|
||||
|
||||
@ -187,12 +187,10 @@ onUiUpdate(function(){
|
||||
if (!txt2img_textarea) {
|
||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
||||
}
|
||||
if (!img2img_textarea) {
|
||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
||||
}
|
||||
})
|
||||
|
||||
@ -220,14 +218,6 @@ function update_token_counter(button_id) {
|
||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||
}
|
||||
|
||||
function submit_prompt(event, generate_button_id) {
|
||||
if (event.altKey && event.keyCode === 13) {
|
||||
event.preventDefault();
|
||||
gradioApp().getElementById(generate_button_id).click();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
function restart_reload(){
|
||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||
setTimeout(function(){location.reload()},2000)
|
||||
|
@ -9,6 +9,7 @@ import platform
|
||||
dir_repos = "repositories"
|
||||
python = sys.executable
|
||||
git = os.environ.get('GIT', "git")
|
||||
index_url = os.environ.get('INDEX_URL', "")
|
||||
|
||||
|
||||
def extract_arg(args, name):
|
||||
@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None):
|
||||
|
||||
|
||||
def run_pip(args, desc=None):
|
||||
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
||||
|
||||
|
||||
def check_run_python(code):
|
||||
@ -76,7 +78,7 @@ def git_clone(url, dir, name, commithash=None):
|
||||
return
|
||||
|
||||
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||
return
|
||||
|
||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||
|
@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
|
||||
|
||||
tags = dd.project.load_tags_from_project(model_path)
|
||||
model = dd.project.load_model_from_project(
|
||||
model_path, compile_model=True
|
||||
model_path, compile_model=False
|
||||
)
|
||||
return model, tags
|
||||
|
||||
|
@ -34,7 +34,7 @@ def enable_tf32():
|
||||
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
|
@ -159,24 +159,12 @@ def run_pnginfo(image):
|
||||
return '', geninfo, info
|
||||
|
||||
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name):
|
||||
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
|
||||
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
|
||||
def weighted_sum(theta0, theta1, theta2, alpha):
|
||||
return ((1 - alpha) * theta0) + (alpha * theta1)
|
||||
|
||||
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
|
||||
def sigmoid(theta0, theta1, theta2, alpha):
|
||||
alpha = alpha * alpha * (3 - (2 * alpha))
|
||||
return theta0 + ((theta1 - theta0) * alpha)
|
||||
|
||||
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
|
||||
def inv_sigmoid(theta0, theta1, theta2, alpha):
|
||||
import math
|
||||
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
|
||||
return theta0 + ((theta1 - theta0) * alpha)
|
||||
|
||||
def add_difference(theta0, theta1, theta2, alpha):
|
||||
return theta0 + (theta1 - theta2) * (1.0 - alpha)
|
||||
return theta0 + (theta1 - theta2) * alpha
|
||||
|
||||
primary_model_info = sd_models.checkpoints_list[primary_model_name]
|
||||
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
|
||||
@ -198,9 +186,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
theta_2 = None
|
||||
|
||||
theta_funcs = {
|
||||
"Weighted Sum": weighted_sum,
|
||||
"Sigmoid": sigmoid,
|
||||
"Inverse Sigmoid": inv_sigmoid,
|
||||
"Weighted sum": weighted_sum,
|
||||
"Add difference": add_difference,
|
||||
}
|
||||
theta_func = theta_funcs[interp_method]
|
||||
@ -209,7 +195,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
|
||||
for key in tqdm.tqdm(theta_0.keys()):
|
||||
if 'model' in key and key in theta_1:
|
||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
|
||||
t2 = (theta_2 or {}).get(key)
|
||||
if t2 is None:
|
||||
t2 = torch.zeros_like(theta_0[key])
|
||||
|
||||
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
|
||||
|
||||
if save_as_half:
|
||||
theta_0[key] = theta_0[key].half()
|
||||
|
||||
@ -222,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
|
||||
|
||||
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||
|
||||
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||
output_modelname = os.path.join(ckpt_dir, filename)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import os
|
||||
import sys
|
||||
import traceback
|
||||
import tqdm
|
||||
import csv
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,6 +15,7 @@ import torch
|
||||
from torch import einsum
|
||||
from einops import rearrange, repeat
|
||||
import modules.textual_inversion.dataset
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
|
||||
|
||||
@ -180,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
|
||||
def stack_conds(conds):
|
||||
if len(conds) == 1:
|
||||
return torch.stack(conds)
|
||||
|
||||
# same as in reconstruct_multicond_batch
|
||||
token_count = max([x.shape[0] for x in conds])
|
||||
for i in range(len(conds)):
|
||||
if conds[i].shape[0] != token_count:
|
||||
last_vector = conds[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
|
||||
conds[i] = torch.vstack([conds[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(conds)
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
@ -209,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
|
||||
|
||||
if unload:
|
||||
shared.sd_model.cond_stage_model.to(devices.cpu)
|
||||
@ -233,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
|
||||
for i, entry in pbar:
|
||||
for i, entries in pbar:
|
||||
hypernetwork.step = i + ititial_step
|
||||
|
||||
scheduler.apply(optimizer, hypernetwork.step)
|
||||
@ -244,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
cond = entry.cond.to(devices.device)
|
||||
x = entry.latent.to(devices.device)
|
||||
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
|
||||
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
|
||||
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
del x
|
||||
del cond
|
||||
del c
|
||||
|
||||
losses[hypernetwork.step % losses.shape[0]] = loss.item()
|
||||
|
||||
@ -262,23 +279,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
|
||||
hypernetwork.save(last_saved_file)
|
||||
|
||||
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
|
||||
"loss": f"{losses.mean():.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
|
||||
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
shared.sd_model.first_stage_model.to(devices.device)
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
prompt=preview_text,
|
||||
steps=20,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_index = preview_sampler_index
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0] if len(processed.images)>0 else None
|
||||
|
||||
@ -297,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {hypernetwork.step}<br/>
|
||||
Last prompt: {html.escape(entry.cond_text)}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
from collections import namedtuple
|
||||
@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
|
||||
rows = opts.n_rows
|
||||
elif opts.n_rows == 0:
|
||||
rows = batch_size
|
||||
elif opts.grid_prevent_empty_spots:
|
||||
rows = math.floor(math.sqrt(len(imgs)))
|
||||
while len(imgs) % rows != 0:
|
||||
rows -= 1
|
||||
else:
|
||||
rows = math.sqrt(len(imgs))
|
||||
rows = round(rows)
|
||||
@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
txt_fullfn = None
|
||||
|
||||
return fullfn, txt_fullfn
|
||||
|
||||
|
||||
def image_data(data):
|
||||
try:
|
||||
image = Image.open(io.BytesIO(data))
|
||||
textinfo = image.text["parameters"]
|
||||
return textinfo, None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
text = data.decode('utf8')
|
||||
assert len(text) < 10000
|
||||
return text, None
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return '', None
|
||||
|
181
modules/images_history.py
Normal file
181
modules/images_history.py
Normal file
@ -0,0 +1,181 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
def traverse_all_files(output_dir, image_list, curr_dir=None):
|
||||
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
|
||||
try:
|
||||
f_list = os.listdir(curr_path)
|
||||
except:
|
||||
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
|
||||
image_list.append(curr_dir)
|
||||
return image_list
|
||||
for file in f_list:
|
||||
file = file if curr_dir is None else os.path.join(curr_dir, file)
|
||||
file_path = os.path.join(curr_path, file)
|
||||
if file[-4:] == ".txt":
|
||||
pass
|
||||
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
|
||||
image_list.append(file)
|
||||
else:
|
||||
image_list = traverse_all_files(output_dir, image_list, file)
|
||||
return image_list
|
||||
|
||||
|
||||
def get_recent_images(dir_name, page_index, step, image_index, tabname):
|
||||
page_index = int(page_index)
|
||||
f_list = os.listdir(dir_name)
|
||||
image_list = []
|
||||
image_list = traverse_all_files(dir_name, image_list)
|
||||
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
|
||||
num = 48 if tabname != "extras" else 12
|
||||
max_page_index = len(image_list) // num + 1
|
||||
page_index = max_page_index if page_index == -1 else page_index + step
|
||||
page_index = 1 if page_index < 1 else page_index
|
||||
page_index = max_page_index if page_index > max_page_index else page_index
|
||||
idx_frm = (page_index - 1) * num
|
||||
image_list = image_list[idx_frm:idx_frm + num]
|
||||
image_index = int(image_index)
|
||||
if image_index < 0 or image_index > len(image_list) - 1:
|
||||
current_file = None
|
||||
hidden = None
|
||||
else:
|
||||
current_file = image_list[int(image_index)]
|
||||
hidden = os.path.join(dir_name, current_file)
|
||||
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
|
||||
|
||||
|
||||
def first_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, 1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def end_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, -1, 0, image_index, tabname)
|
||||
|
||||
|
||||
def prev_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
|
||||
|
||||
|
||||
def next_page_click(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
|
||||
|
||||
|
||||
def page_index_change(dir_name, page_index, image_index, tabname):
|
||||
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
|
||||
|
||||
|
||||
def show_image_info(num, image_path, filenames):
|
||||
# print(f"select image {num}")
|
||||
file = filenames[int(num)]
|
||||
return file, num, os.path.join(image_path, file)
|
||||
|
||||
|
||||
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
|
||||
if name == "":
|
||||
return filenames, delete_num
|
||||
else:
|
||||
delete_num = int(delete_num)
|
||||
index = list(filenames).index(name)
|
||||
i = 0
|
||||
new_file_list = []
|
||||
for name in filenames:
|
||||
if i >= index and i < index + delete_num:
|
||||
path = os.path.join(dir_name, name)
|
||||
if os.path.exists(path):
|
||||
print(f"Delete file {path}")
|
||||
os.remove(path)
|
||||
txt_file = os.path.splitext(path)[0] + ".txt"
|
||||
if os.path.exists(txt_file):
|
||||
os.remove(txt_file)
|
||||
else:
|
||||
print(f"Not exists file {path}")
|
||||
else:
|
||||
new_file_list.append(name)
|
||||
i += 1
|
||||
return new_file_list, 1
|
||||
|
||||
|
||||
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
|
||||
if opts.outdir_samples != "":
|
||||
dir_name = opts.outdir_samples
|
||||
elif tabname == "txt2img":
|
||||
dir_name = opts.outdir_txt2img_samples
|
||||
elif tabname == "img2img":
|
||||
dir_name = opts.outdir_img2img_samples
|
||||
elif tabname == "extras":
|
||||
dir_name = opts.outdir_extras_samples
|
||||
d = dir_name.split("/")
|
||||
dir_name = "/" if dir_name.startswith("/") else d[0]
|
||||
for p in d[1:]:
|
||||
dir_name = os.path.join(dir_name, p)
|
||||
with gr.Row():
|
||||
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
|
||||
first_page = gr.Button('First Page')
|
||||
prev_page = gr.Button('Prev Page')
|
||||
page_index = gr.Number(value=1, label="Page Index")
|
||||
next_page = gr.Button('Next Page')
|
||||
end_page = gr.Button('End Page')
|
||||
with gr.Row(elem_id=tabname + "_images_history"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
|
||||
with gr.Row():
|
||||
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
|
||||
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
|
||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
|
||||
img_file_name = gr.Textbox(label="File Name", interactive=False)
|
||||
with gr.Row():
|
||||
# hiden items
|
||||
|
||||
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
|
||||
tabname_box = gr.Textbox(tabname, visible=False)
|
||||
image_index = gr.Textbox(value=-1, visible=False)
|
||||
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
|
||||
filenames = gr.State()
|
||||
hidden = gr.Image(type="pil", visible=False)
|
||||
info1 = gr.Textbox(visible=False)
|
||||
info2 = gr.Textbox(visible=False)
|
||||
|
||||
# turn pages
|
||||
gallery_inputs = [img_path, page_index, image_index, tabname_box]
|
||||
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
|
||||
|
||||
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
|
||||
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
|
||||
|
||||
# other funcitons
|
||||
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
|
||||
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
|
||||
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
|
||||
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
|
||||
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
|
||||
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
|
||||
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
|
||||
|
||||
|
||||
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history:
|
||||
with gr.Tabs() as tabs:
|
||||
with gr.Tab("txt2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
|
||||
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("img2img history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
|
||||
with gr.Tab("extras history"):
|
||||
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
|
||||
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
|
||||
return images_history
|
@ -55,7 +55,7 @@ class InterrogateModels:
|
||||
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
model.eval()
|
||||
model = model.to(shared.device)
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
return model, preprocess
|
||||
|
||||
@ -65,14 +65,14 @@ class InterrogateModels:
|
||||
if not shared.cmd_opts.no_half:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(shared.device)
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
self.clip_model = self.clip_model.half()
|
||||
|
||||
self.clip_model = self.clip_model.to(shared.device)
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
||||
self.dtype = next(self.clip_model.parameters()).dtype
|
||||
|
||||
@ -99,11 +99,11 @@ class InterrogateModels:
|
||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device)
|
||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
||||
similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
@ -116,7 +116,7 @@ class InterrogateModels:
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||
])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
with torch.no_grad():
|
||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
||||
@ -140,7 +140,7 @@ class InterrogateModels:
|
||||
|
||||
res = caption
|
||||
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
|
||||
|
||||
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
with torch.no_grad(), precision_scope("cuda"):
|
||||
|
@ -145,9 +145,8 @@ class Processed:
|
||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
||||
self.subseed = int(
|
||||
self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
||||
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
||||
|
||||
self.all_prompts = all_prompts or [self.prompt]
|
||||
self.all_seeds = all_seeds or [self.seed]
|
||||
@ -541,16 +540,15 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
|
||||
|
||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
sampler = None
|
||||
firstphase_width = 0
|
||||
firstphase_height = 0
|
||||
firstphase_width_truncated = 0
|
||||
firstphase_height_truncated = 0
|
||||
|
||||
def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs):
|
||||
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.enable_hr = enable_hr
|
||||
self.scale_latent = scale_latent
|
||||
self.denoising_strength = denoising_strength
|
||||
self.firstphase_width = firstphase_width
|
||||
self.firstphase_height = firstphase_height
|
||||
self.truncate_x = 0
|
||||
self.truncate_y = 0
|
||||
|
||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||
if self.enable_hr:
|
||||
@ -559,14 +557,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
else:
|
||||
state.job_count = state.job_count * 2
|
||||
|
||||
if self.firstphase_width == 0 or self.firstphase_height == 0:
|
||||
desired_pixel_count = 512 * 512
|
||||
actual_pixel_count = self.width * self.height
|
||||
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
||||
|
||||
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
|
||||
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
|
||||
self.firstphase_width_truncated = int(scale * self.width)
|
||||
self.firstphase_height_truncated = int(scale * self.height)
|
||||
firstphase_width_truncated = int(scale * self.width)
|
||||
firstphase_height_truncated = int(scale * self.height)
|
||||
|
||||
else:
|
||||
|
||||
width_ratio = self.width / self.firstphase_width
|
||||
height_ratio = self.height / self.firstphase_height
|
||||
|
||||
if width_ratio > height_ratio:
|
||||
firstphase_width_truncated = self.firstphase_width
|
||||
firstphase_height_truncated = self.firstphase_width * self.height / self.width
|
||||
else:
|
||||
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
||||
firstphase_height_truncated = self.firstphase_height
|
||||
|
||||
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
||||
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
||||
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
||||
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||
@ -585,22 +600,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
||||
|
||||
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f
|
||||
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
|
||||
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
||||
|
||||
samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2,
|
||||
truncate_x // 2:samples.shape[3] - truncate_x // 2]
|
||||
|
||||
if self.scale_latent:
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
|
||||
mode="bilinear")
|
||||
if opts.use_scale_latent_for_hires_fix:
|
||||
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
||||
else:
|
||||
decoded_samples = decode_first_stage(self.sd_model, samples)
|
||||
|
||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
|
||||
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width),
|
||||
mode="bilinear")
|
||||
else:
|
||||
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
batch_images = []
|
||||
|
@ -96,11 +96,18 @@ def load(filename, *args, **kwargs):
|
||||
if not shared.cmd_opts.disable_safe_unpickle:
|
||||
check_pt(filename)
|
||||
|
||||
except pickle.UnpicklingError:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return unsafe_torch_load(filename, *args, **kwargs)
|
||||
|
@ -1,4 +1,4 @@
|
||||
import glob
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
|
||||
checkpoints_list = {}
|
||||
checkpoints_loaded = collections.OrderedDict()
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
@ -132,14 +133,14 @@ def load_model_weights(model, checkpoint_info):
|
||||
checkpoint_file = checkpoint_info.filename
|
||||
sd_model_hash = checkpoint_info.hash
|
||||
|
||||
if checkpoint_info not in checkpoints_loaded:
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
@ -158,13 +159,20 @@ def load_model_weights(model, checkpoint_info):
|
||||
|
||||
if os.path.exists(vae_file):
|
||||
print(f"Loading VAE weights from: {vae_file}")
|
||||
vae_ckpt = torch.load(vae_file, map_location="cpu")
|
||||
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
|
||||
model.first_stage_model.load_state_dict(vae_dict)
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
else:
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_file
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
@ -202,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
|
||||
return
|
||||
|
||||
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
|
||||
checkpoints_loaded.clear()
|
||||
shared.sd_model = load_model()
|
||||
return shared.sd_model
|
||||
|
||||
|
@ -36,6 +36,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
|
||||
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
|
||||
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
|
||||
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
|
||||
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
|
||||
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
@ -56,7 +57,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@ -78,10 +79,11 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
|
||||
|
||||
device = devices.device
|
||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
|
||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||
@ -184,6 +186,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
|
||||
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
|
||||
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
|
||||
|
||||
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
|
||||
@ -224,6 +227,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||
@ -242,11 +246,13 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
|
||||
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
|
||||
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
|
||||
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
|
||||
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|
||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||
@ -260,7 +266,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||
@ -288,6 +293,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
|
@ -24,11 +24,12 @@ class DatasetEntry:
|
||||
|
||||
|
||||
class PersonalizedBase(Dataset):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
|
||||
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
|
||||
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
|
||||
|
||||
self.placeholder_token = placeholder_token
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||
@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
|
||||
|
||||
if include_cond:
|
||||
entry.cond_text = self.create_text(filename_text)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
|
||||
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
|
||||
|
||||
self.dataset.append(entry)
|
||||
|
||||
self.length = len(self.dataset) * repeats
|
||||
assert len(self.dataset) > 1, "No images have been found in the dataset."
|
||||
self.length = len(self.dataset) * repeats // batch_size
|
||||
|
||||
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
||||
self.initial_indexes = np.arange(len(self.dataset))
|
||||
self.indexes = None
|
||||
self.shuffle()
|
||||
|
||||
@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
if i % len(self.dataset) == 0:
|
||||
res = []
|
||||
|
||||
for j in range(self.batch_size):
|
||||
position = i * self.batch_size + j
|
||||
if position % len(self.indexes) == 0:
|
||||
self.shuffle()
|
||||
|
||||
index = self.indexes[i % len(self.indexes)]
|
||||
index = self.indexes[position % len(self.indexes)]
|
||||
entry = self.dataset[index]
|
||||
|
||||
if entry.cond is None:
|
||||
entry.cond_text = self.create_text(entry.filename_text)
|
||||
|
||||
return entry
|
||||
res.append(entry)
|
||||
|
||||
return res
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
import tqdm
|
||||
import html
|
||||
import datetime
|
||||
import csv
|
||||
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
@ -172,15 +173,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
def batched(dataset, total, n=1):
|
||||
for ndx in range(0, total, n):
|
||||
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
|
||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
|
||||
if step % shared.opts.training_write_csv_every != 0:
|
||||
return
|
||||
|
||||
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
|
||||
|
||||
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
|
||||
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
|
||||
|
||||
if write_csv_header:
|
||||
csv_writer.writeheader()
|
||||
|
||||
epoch = step // epoch_len
|
||||
epoch_step = step - epoch * epoch_len
|
||||
|
||||
csv_writer.writerow({
|
||||
"step": step + 1,
|
||||
"epoch": epoch + 1,
|
||||
"epoch_step": epoch_step + 1,
|
||||
**values,
|
||||
})
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps,
|
||||
create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding,
|
||||
preview_image_prompt, batch_size=1,
|
||||
gradient_accumulation=1):
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
assert embedding_name, 'embedding not selected'
|
||||
|
||||
shared.state.textinfo = "Initializing textual inversion training..."
|
||||
@ -212,11 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
|
||||
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
||||
with torch.autocast("cuda"):
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width,
|
||||
height=training_height,
|
||||
repeats=shared.opts.training_image_repeats_per_epoch,
|
||||
placeholder_token=embedding_name, model=shared.sd_model,
|
||||
device=devices.device, template_file=template_file)
|
||||
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
|
||||
|
||||
hijack = sd_hijack.model_hijack
|
||||
|
||||
@ -235,8 +250,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
|
||||
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
|
||||
|
||||
pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step)
|
||||
for i, entry in pbar:
|
||||
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
||||
for i, entries in pbar:
|
||||
embedding.step = i + ititial_step
|
||||
|
||||
scheduler.apply(optimizer, embedding.step)
|
||||
@ -247,11 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
break
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
c = cond_model([e.cond_text for e in entry])
|
||||
|
||||
x = torch.stack([e.latent for e in entry]).to(devices.device)
|
||||
c = cond_model([entry.cond_text for entry in entries])
|
||||
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
|
||||
loss = shared.sd_model(x, c)[0]
|
||||
|
||||
del x
|
||||
|
||||
losses[embedding.step % losses.shape[0]] = loss.item()
|
||||
@ -271,21 +284,37 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
||||
embedding.save(last_saved_file)
|
||||
|
||||
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
|
||||
"loss": f"{losses.mean():.7f}",
|
||||
"learn_rate": scheduler.learn_rate
|
||||
})
|
||||
|
||||
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
||||
|
||||
preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt
|
||||
|
||||
p = processing.StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
prompt=preview_text,
|
||||
steps=20,
|
||||
height=training_height,
|
||||
width=training_width,
|
||||
do_not_save_grid=True,
|
||||
do_not_save_samples=True,
|
||||
)
|
||||
|
||||
if preview_from_txt2img:
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_index = preview_sampler_index
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
p.height = preview_height
|
||||
else:
|
||||
p.prompt = entries[0].cond_text
|
||||
p.steps = 20
|
||||
p.width = training_width
|
||||
p.height = training_height
|
||||
|
||||
preview_text = p.prompt
|
||||
|
||||
processed = processing.process_images(p)
|
||||
image = processed.images[0]
|
||||
|
||||
@ -320,7 +349,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
|
||||
<p>
|
||||
Loss: {losses.mean():.7f}<br/>
|
||||
Step: {embedding.step}<br/>
|
||||
Last prompt: {html.escape(entry[-1].cond_text)}<br/>
|
||||
Last prompt: {html.escape(entries[0].cond_text)}<br/>
|
||||
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
||||
Last saved image: {html.escape(last_saved_image)}<br/>
|
||||
</p>
|
||||
|
@ -6,18 +6,13 @@ import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
|
||||
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
|
||||
restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
|
||||
subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
|
||||
height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float,
|
||||
aesthetic_lr=0,
|
||||
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0,
|
||||
aesthetic_weight=0, aesthetic_steps=0,
|
||||
aesthetic_imgs=None,
|
||||
aesthetic_slerp=False,
|
||||
aesthetic_imgs_text="",
|
||||
aesthetic_slerp_angle=0.15,
|
||||
aesthetic_text_negative=False,
|
||||
*args):
|
||||
aesthetic_text_negative=False, *args):
|
||||
p = StableDiffusionProcessingTxt2Img(
|
||||
sd_model=shared.sd_model,
|
||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||
@ -41,8 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
||||
restore_faces=restore_faces,
|
||||
tiling=tiling,
|
||||
enable_hr=enable_hr,
|
||||
scale_latent=scale_latent if enable_hr else None,
|
||||
denoising_strength=denoising_strength if enable_hr else None,
|
||||
firstphase_width=firstphase_width if enable_hr else None,
|
||||
firstphase_height=firstphase_height if enable_hr else None,
|
||||
)
|
||||
|
||||
if cmd_opts.enable_console_prompts:
|
||||
|
219
modules/ui.py
219
modules/ui.py
@ -7,6 +7,7 @@ import mimetypes
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
import platform
|
||||
@ -22,7 +23,7 @@ import gradio as gr
|
||||
import gradio.utils
|
||||
import gradio.routes
|
||||
|
||||
from modules import sd_hijack
|
||||
from modules import sd_hijack, sd_models
|
||||
from modules.paths import script_path
|
||||
from modules.shared import opts, cmd_opts,aesthetic_embeddings
|
||||
|
||||
@ -41,7 +42,10 @@ from modules import prompt_parser
|
||||
from modules.images import save_image
|
||||
import modules.textual_inversion.ui
|
||||
import modules.hypernetworks.ui
|
||||
|
||||
import modules.aesthetic_clip
|
||||
import modules.images_history as img_his
|
||||
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||
mimetypes.init()
|
||||
@ -81,6 +85,8 @@ art_symbol = '\U0001f3a8' # 🎨
|
||||
paste_symbol = '\u2199\ufe0f' # ↙
|
||||
folder_symbol = '\U0001f4c2' # 📂
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
save_style_symbol = '\U0001f4be' # 💾
|
||||
apply_style_symbol = '\U0001f4cb' # 📋
|
||||
|
||||
|
||||
def plaintext_to_html(text):
|
||||
@ -89,6 +95,14 @@ def plaintext_to_html(text):
|
||||
|
||||
|
||||
def image_from_url_text(filedata):
|
||||
if type(filedata) == dict and filedata["is_file"]:
|
||||
filename = filedata["name"]
|
||||
tempdir = os.path.normpath(tempfile.gettempdir())
|
||||
normfn = os.path.normpath(filename)
|
||||
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
|
||||
|
||||
return Image.open(filename)
|
||||
|
||||
if type(filedata) == list:
|
||||
if len(filedata) == 0:
|
||||
return None
|
||||
@ -177,6 +191,23 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def save_pil_to_file(pil_image, dir=None):
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
for key, value in pil_image.info.items():
|
||||
if isinstance(key, str) and isinstance(value, str):
|
||||
metadata.add_text(key, value)
|
||||
use_metadata = True
|
||||
|
||||
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
|
||||
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
|
||||
return file_obj
|
||||
|
||||
|
||||
# override save to file function so that it also writes PNG info
|
||||
gr.processing_utils.save_pil_to_file = save_pil_to_file
|
||||
|
||||
|
||||
def wrap_gradio_call(func, extra_outputs=None):
|
||||
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||
@ -305,7 +336,7 @@ def visit(x, func, path=""):
|
||||
|
||||
def add_style(name: str, prompt: str, negative_prompt: str):
|
||||
if name is None:
|
||||
return [gr_show(), gr_show()]
|
||||
return [gr_show() for x in range(4)]
|
||||
|
||||
style = modules.styles.PromptStyle(name, prompt, negative_prompt)
|
||||
shared.prompt_styles.styles[style.name] = style
|
||||
@ -430,30 +461,38 @@ def create_toprow(is_img2img):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
|
||||
with gr.Row(elem_id="toprow"):
|
||||
with gr.Column(scale=4):
|
||||
with gr.Column(scale=6):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
|
||||
prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
|
||||
placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
|
||||
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, elem_id="roll_col"):
|
||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||
save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
|
||||
prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
|
||||
|
||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
|
||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=8):
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
|
||||
with gr.Column(scale=1, elem_id="roll_col"):
|
||||
sh = gr.Button(elem_id="sh", visible=True)
|
||||
|
||||
with gr.Column(scale=1, elem_id="style_neg_col"):
|
||||
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||
if cmd_opts.deepdanbooru:
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
@ -473,20 +512,14 @@ def create_toprow(is_img2img):
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with gr.Row(scale=1):
|
||||
if is_img2img:
|
||||
interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
if cmd_opts.deepdanbooru:
|
||||
deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
else:
|
||||
deepbooru = None
|
||||
else:
|
||||
interrogate = None
|
||||
deepbooru = None
|
||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
||||
save_style = gr.Button('Create style', elem_id="style_create")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="style_pos_col"):
|
||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||
|
||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||
with gr.Column(scale=1, elem_id="style_neg_col"):
|
||||
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
|
||||
|
||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||
|
||||
|
||||
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
@ -510,13 +543,40 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||
)
|
||||
|
||||
|
||||
def apply_setting(key, value):
|
||||
if value is None:
|
||||
return gr.update()
|
||||
|
||||
if key == "sd_model_checkpoint":
|
||||
ckpt_info = sd_models.get_closet_checkpoint_match(value)
|
||||
|
||||
if ckpt_info is not None:
|
||||
value = ckpt_info.title
|
||||
else:
|
||||
return gr.update()
|
||||
|
||||
comp_args = opts.data_labels[key].component_args
|
||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||
return
|
||||
|
||||
valtype = type(opts.data_labels[key].default)
|
||||
oldval = opts.data[key]
|
||||
opts.data[key] = valtype(value) if valtype != type(None) else value
|
||||
if oldval != value and opts.data_labels[key].onchange is not None:
|
||||
opts.data_labels[key].onchange()
|
||||
|
||||
opts.save(shared.config_filename)
|
||||
return value
|
||||
|
||||
|
||||
def create_ui(wrap_gradio_gpu_call):
|
||||
import modules.img2img
|
||||
import modules.txt2img
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||
dummy_component = gr.Label(visible=False)
|
||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
|
||||
|
||||
with gr.Row(elem_id='txt2img_progress_row'):
|
||||
with gr.Column(scale=1):
|
||||
@ -554,10 +614,11 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
enable_hr = gr.Checkbox(label='Highres. fix', value=False)
|
||||
|
||||
with gr.Row(visible=False) as hr_options:
|
||||
scale_latent = gr.Checkbox(label='Scale latent', value=False)
|
||||
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0)
|
||||
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0)
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1)
|
||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
|
||||
|
||||
@ -616,8 +677,9 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
height,
|
||||
width,
|
||||
enable_hr,
|
||||
scale_latent,
|
||||
denoising_strength,
|
||||
firstphase_width,
|
||||
firstphase_height,
|
||||
aesthetic_lr,
|
||||
aesthetic_weight,
|
||||
aesthetic_steps,
|
||||
@ -627,6 +689,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
aesthetic_slerp_angle,
|
||||
aesthetic_text_negative
|
||||
] + custom_inputs,
|
||||
|
||||
outputs=[
|
||||
txt2img_gallery,
|
||||
generation_info,
|
||||
@ -638,6 +701,17 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
txt2img_prompt.submit(**txt2img_args)
|
||||
submit.click(**txt2img_args)
|
||||
|
||||
txt_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
txt_prompt_img
|
||||
],
|
||||
outputs=[
|
||||
txt2img_prompt,
|
||||
txt_prompt_img
|
||||
]
|
||||
)
|
||||
|
||||
enable_hr.change(
|
||||
fn=lambda x: gr_show(x),
|
||||
inputs=[enable_hr],
|
||||
@ -690,14 +764,29 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
(denoising_strength, "Denoising strength"),
|
||||
(enable_hr, lambda d: "Denoising strength" in d),
|
||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||
(firstphase_width, "First pass size-1"),
|
||||
(firstphase_height, "First pass size-2"),
|
||||
]
|
||||
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
||||
|
||||
txt2img_preview_params = [
|
||||
txt2img_prompt,
|
||||
txt2img_negative_prompt,
|
||||
steps,
|
||||
sampler_index,
|
||||
cfg_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
]
|
||||
|
||||
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||
|
||||
with gr.Row(elem_id='img2img_progress_row'):
|
||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
|
||||
|
||||
with gr.Column(scale=1):
|
||||
pass
|
||||
|
||||
@ -711,10 +800,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
|
||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
||||
with gr.TabItem('img2img', id='img2img'):
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint'):
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
|
||||
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
|
||||
@ -792,6 +881,17 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||
|
||||
img2img_prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[
|
||||
img2img_prompt_img
|
||||
],
|
||||
outputs=[
|
||||
img2img_prompt,
|
||||
img2img_prompt_img
|
||||
]
|
||||
)
|
||||
|
||||
mask_mode.change(
|
||||
lambda mode, img: {
|
||||
init_img_with_mask: gr_show(mode == 0),
|
||||
@ -932,7 +1032,6 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
(seed_resize_from_h, "Seed resize from-2"),
|
||||
(denoising_strength, "Denoising strength"),
|
||||
]
|
||||
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
||||
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||
@ -980,6 +1079,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||
|
||||
|
||||
submit.click(
|
||||
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||
_js="get_extras_tab_index",
|
||||
@ -1039,6 +1139,14 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
inputs=[image],
|
||||
outputs=[html, generation_info, html2],
|
||||
)
|
||||
#images history
|
||||
images_history_switch_dict = {
|
||||
"fn":modules.generation_parameters_copypaste.connect_paste,
|
||||
"t2i":txt2img_paste_fields,
|
||||
"i2i":img2img_paste_fields
|
||||
}
|
||||
|
||||
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
|
||||
|
||||
with gr.Blocks() as modelmerger_interface:
|
||||
with gr.Row().style(equal_height=False):
|
||||
@ -1050,8 +1158,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
|
||||
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
|
||||
custom_name = gr.Textbox(label="Custom Name (Optional)")
|
||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3)
|
||||
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method")
|
||||
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
|
||||
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
|
||||
save_as_half = gr.Checkbox(value=False, label="Save as float16")
|
||||
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
|
||||
|
||||
@ -1125,6 +1233,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
|
||||
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
|
||||
batch_size = gr.Number(label='Batch size', value=1, precision=0)
|
||||
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||
@ -1137,7 +1246,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
|
||||
preview_image_prompt = gr.Textbox(label='Preview prompt', value="")
|
||||
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
|
||||
|
||||
with gr.Row():
|
||||
interrupt_training = gr.Button(value="Interrupt")
|
||||
@ -1220,6 +1329,7 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
inputs=[
|
||||
train_embedding_name,
|
||||
learn_rate,
|
||||
batch_size,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
training_width,
|
||||
@ -1229,9 +1339,8 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
save_image_with_stored_embedding,
|
||||
preview_image_prompt,
|
||||
batch_size,
|
||||
gradient_accumulation
|
||||
preview_from_txt2img,
|
||||
*txt2img_preview_params,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
@ -1245,13 +1354,15 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
inputs=[
|
||||
train_hypernetwork_name,
|
||||
learn_rate,
|
||||
batch_size,
|
||||
dataset_directory,
|
||||
log_directory,
|
||||
steps,
|
||||
create_image_every,
|
||||
save_embedding_every,
|
||||
template_file,
|
||||
preview_image_prompt,
|
||||
preview_from_txt2img,
|
||||
*txt2img_preview_params,
|
||||
],
|
||||
outputs=[
|
||||
ti_output,
|
||||
@ -1463,6 +1574,7 @@ Requested path was: {f}
|
||||
(img2img_interface, "img2img", "img2img"),
|
||||
(extras_interface, "Extras", "extras"),
|
||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||
(images_history, "History", "images_history"),
|
||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||
(train_interface, "Train", "ti"),
|
||||
(settings_interface, "Settings", "settings"),
|
||||
@ -1603,8 +1715,22 @@ Requested path was: {f}
|
||||
outputs=[extras_image],
|
||||
)
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields, generation_info, 'switch_to_txt2img')
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields, generation_info, 'switch_to_img2img_img2img')
|
||||
settings_map = {
|
||||
'sd_hypernetwork': 'Hypernet',
|
||||
'CLIP_stop_at_last_layers': 'Clip skip',
|
||||
'sd_model_checkpoint': 'Model hash',
|
||||
}
|
||||
|
||||
settings_paste_fields = [
|
||||
(component_dict[k], lambda d, k=k, v=v: apply_setting(k, d.get(v, None)))
|
||||
for k, v in settings_map.items()
|
||||
]
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(txt2img_paste, txt2img_paste_fields + settings_paste_fields, txt2img_prompt)
|
||||
modules.generation_parameters_copypaste.connect_paste(img2img_paste, img2img_paste_fields + settings_paste_fields, img2img_prompt)
|
||||
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_txt2img, txt2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_txt2img')
|
||||
modules.generation_parameters_copypaste.connect_paste(pnginfo_send_to_img2img, img2img_paste_fields + settings_paste_fields, generation_info, 'switch_to_img2img_img2img')
|
||||
|
||||
ui_config_file = cmd_opts.ui_config_file
|
||||
ui_settings = {}
|
||||
@ -1686,3 +1812,4 @@ if 'gradio_routes_templates_response' not in globals():
|
||||
|
||||
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
||||
gradio.routes.templates.TemplateResponse = template_response
|
||||
|
||||
|
@ -4,7 +4,7 @@ fairscale==0.4.4
|
||||
fonts
|
||||
font-roboto
|
||||
gfpgan
|
||||
gradio==3.4.1
|
||||
gradio==3.5
|
||||
invisible-watermark
|
||||
numpy
|
||||
omegaconf
|
||||
|
@ -2,7 +2,7 @@ transformers==4.19.2
|
||||
diffusers==0.3.0
|
||||
basicsr==1.4.2
|
||||
gfpgan==1.3.8
|
||||
gradio==3.4.1
|
||||
gradio==3.5
|
||||
numpy==1.23.3
|
||||
Pillow==9.2.0
|
||||
realesrgan==0.3.0
|
||||
|
@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||
document.addEventListener('keydown', function(e) {
|
||||
var handled = false;
|
||||
if (e.key !== undefined) {
|
||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true;
|
||||
if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||
} else if (e.keyCode !== undefined) {
|
||||
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true;
|
||||
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||
}
|
||||
if (handled) {
|
||||
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||
|
@ -1,7 +1,9 @@
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import shlex
|
||||
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
@ -10,6 +12,75 @@ from modules.processing import Processed, process_images
|
||||
from PIL import Image
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
|
||||
|
||||
def process_string_tag(tag):
|
||||
return tag
|
||||
|
||||
|
||||
def process_int_tag(tag):
|
||||
return int(tag)
|
||||
|
||||
|
||||
def process_float_tag(tag):
|
||||
return float(tag)
|
||||
|
||||
|
||||
def process_boolean_tag(tag):
|
||||
return True if (tag == "true") else False
|
||||
|
||||
|
||||
prompt_tags = {
|
||||
"sd_model": None,
|
||||
"outpath_samples": process_string_tag,
|
||||
"outpath_grids": process_string_tag,
|
||||
"prompt_for_display": process_string_tag,
|
||||
"prompt": process_string_tag,
|
||||
"negative_prompt": process_string_tag,
|
||||
"styles": process_string_tag,
|
||||
"seed": process_int_tag,
|
||||
"subseed_strength": process_float_tag,
|
||||
"subseed": process_int_tag,
|
||||
"seed_resize_from_h": process_int_tag,
|
||||
"seed_resize_from_w": process_int_tag,
|
||||
"sampler_index": process_int_tag,
|
||||
"batch_size": process_int_tag,
|
||||
"n_iter": process_int_tag,
|
||||
"steps": process_int_tag,
|
||||
"cfg_scale": process_float_tag,
|
||||
"width": process_int_tag,
|
||||
"height": process_int_tag,
|
||||
"restore_faces": process_boolean_tag,
|
||||
"tiling": process_boolean_tag,
|
||||
"do_not_save_samples": process_boolean_tag,
|
||||
"do_not_save_grid": process_boolean_tag
|
||||
}
|
||||
|
||||
|
||||
def cmdargs(line):
|
||||
args = shlex.split(line)
|
||||
pos = 0
|
||||
res = {}
|
||||
|
||||
while pos < len(args):
|
||||
arg = args[pos]
|
||||
|
||||
assert arg.startswith("--"), f'must start with "--": {arg}'
|
||||
tag = arg[2:]
|
||||
|
||||
func = prompt_tags.get(tag, None)
|
||||
assert func, f'unknown commandline option: {arg}'
|
||||
|
||||
assert pos+1 < len(args), f'missing argument for command line option {arg}'
|
||||
|
||||
val = args[pos+1]
|
||||
|
||||
res[tag] = func(val)
|
||||
|
||||
pos += 2
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def title(self):
|
||||
return "Prompts from file or textbox"
|
||||
@ -32,26 +103,48 @@ class Script(scripts.Script):
|
||||
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
|
||||
|
||||
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
|
||||
if (checkbox_txt):
|
||||
if checkbox_txt:
|
||||
lines = [x.strip() for x in prompt_txt.splitlines()]
|
||||
else:
|
||||
lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
|
||||
lines = [x for x in lines if len(x) > 0]
|
||||
|
||||
img_count = len(lines) * p.n_iter
|
||||
batch_count = math.ceil(img_count / p.batch_size)
|
||||
loop_count = math.ceil(batch_count / p.n_iter)
|
||||
print(f"Will process {img_count} images in {batch_count} batches.")
|
||||
|
||||
p.do_not_save_grid = True
|
||||
|
||||
state.job_count = batch_count
|
||||
job_count = 0
|
||||
jobs = []
|
||||
|
||||
for line in lines:
|
||||
if "--" in line:
|
||||
try:
|
||||
args = cmdargs(line)
|
||||
except Exception:
|
||||
print(f"Error parsing line [line] as commandline:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
args = {"prompt": line}
|
||||
else:
|
||||
args = {"prompt": line}
|
||||
|
||||
n_iter = args.get("n_iter", 1)
|
||||
if n_iter != 1:
|
||||
job_count += n_iter
|
||||
else:
|
||||
job_count += 1
|
||||
|
||||
jobs.append(args)
|
||||
|
||||
print(f"Will process {len(lines)} lines in {job_count} jobs.")
|
||||
state.job_count = job_count
|
||||
|
||||
images = []
|
||||
for loop_no in range(loop_count):
|
||||
state.job = f"{loop_no + 1} out of {loop_count}"
|
||||
p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter
|
||||
proc = process_images(p)
|
||||
for n, args in enumerate(jobs):
|
||||
state.job = f"{state.job_no + 1} out of {state.job_count}"
|
||||
|
||||
copy_p = copy.copy(p)
|
||||
for k, v in args.items():
|
||||
setattr(copy_p, k, v)
|
||||
|
||||
proc = process_images(copy_p)
|
||||
images += proc.images
|
||||
|
||||
return Processed(p, images, p.seed, "")
|
||||
|
@ -12,7 +12,7 @@ import gradio as gr
|
||||
|
||||
from modules import images
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.processing import process_images, Processed, get_correct_sampler
|
||||
from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.sd_samplers
|
||||
@ -176,7 +176,7 @@ axis_options = [
|
||||
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
|
||||
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
|
||||
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
|
||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
|
||||
]
|
||||
|
||||
|
||||
@ -338,7 +338,7 @@ class Script(scripts.Script):
|
||||
ys = process_axis(y_opt, y_values)
|
||||
|
||||
def fix_axis_seeds(axis_opt, axis_list):
|
||||
if axis_opt.label == 'Seed':
|
||||
if axis_opt.label in ['Seed','Var. seed']:
|
||||
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
|
||||
else:
|
||||
return axis_list
|
||||
@ -354,6 +354,9 @@ class Script(scripts.Script):
|
||||
else:
|
||||
total_steps = p.steps * len(xs) * len(ys)
|
||||
|
||||
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
|
||||
total_steps *= 2
|
||||
|
||||
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
|
||||
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
|
||||
|
||||
|
40
style.css
40
style.css
@ -115,7 +115,7 @@
|
||||
padding: 0.4em 0;
|
||||
}
|
||||
|
||||
#roll, #paste{
|
||||
#roll, #paste, #style_create, #style_apply{
|
||||
min-width: 2em;
|
||||
min-height: 2em;
|
||||
max-width: 2em;
|
||||
@ -126,14 +126,14 @@
|
||||
margin: 0.1em 0;
|
||||
}
|
||||
|
||||
#style_apply, #style_create, #interrogate{
|
||||
margin: 0.75em 0.25em 0.25em 0.25em;
|
||||
min-width: 5em;
|
||||
#interrogate_col{
|
||||
min-width: 0 !important;
|
||||
max-width: 8em !important;
|
||||
}
|
||||
|
||||
#style_apply, #style_create, #deepbooru{
|
||||
margin: 0.75em 0.25em 0.25em 0.25em;
|
||||
min-width: 5em;
|
||||
#interrogate, #deepbooru{
|
||||
margin: 0em 0.25em 0.9em 0.25em;
|
||||
min-width: 8em;
|
||||
max-width: 8em;
|
||||
}
|
||||
|
||||
#style_pos_col, #style_neg_col{
|
||||
@ -167,18 +167,6 @@ button{
|
||||
align-self: stretch !important;
|
||||
}
|
||||
|
||||
#prompt, #negative_prompt{
|
||||
border: none !important;
|
||||
}
|
||||
#prompt textarea, #negative_prompt textarea{
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
|
||||
#img2maskimg .h-60{
|
||||
height: 30rem;
|
||||
}
|
||||
|
||||
.overflow-hidden, .gr-panel{
|
||||
overflow: visible !important;
|
||||
}
|
||||
@ -451,10 +439,6 @@ input[type="range"]{
|
||||
--tw-bg-opacity: 0 !important;
|
||||
}
|
||||
|
||||
#img2img_image div.h-60{
|
||||
height: 480px;
|
||||
}
|
||||
|
||||
#context-menu{
|
||||
z-index:9999;
|
||||
position:absolute;
|
||||
@ -529,3 +513,11 @@ canvas[key="mask"] {
|
||||
.row.gr-compact{
|
||||
overflow: visible;
|
||||
}
|
||||
|
||||
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
|
||||
img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
|
||||
{
|
||||
height: 480px !important;
|
||||
max-height: 480px !important;
|
||||
min-height: 480px !important;
|
||||
}
|
Loading…
Reference in New Issue
Block a user