mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 12:25:06 +08:00
Merge branch 'master' into gallery-styling
This commit is contained in:
commit
ab4ddbf333
1
.gitignore
vendored
1
.gitignore
vendored
@ -25,3 +25,4 @@ __pycache__
|
|||||||
/.idea
|
/.idea
|
||||||
notification.mp3
|
notification.mp3
|
||||||
/SwinIR
|
/SwinIR
|
||||||
|
/textual_inversion
|
||||||
|
53
README.md
53
README.md
@ -11,44 +11,56 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- One click install and run script (but you still must install python and git)
|
- One click install and run script (but you still must install python and git)
|
||||||
- Outpainting
|
- Outpainting
|
||||||
- Inpainting
|
- Inpainting
|
||||||
- Prompt matrix
|
- Prompt Matrix
|
||||||
- Stable Diffusion upscale
|
- Stable Diffusion Upscale
|
||||||
- Attention
|
- Attention, specify parts of text that the model should pay more attention to
|
||||||
- Loopback
|
- a man in a ((tuxedo)) - will pay more attention to tuxedo
|
||||||
- X/Y plot
|
- a man in a (tuxedo:1.21) - alternative syntax
|
||||||
|
- Loopback, run img2img processing multiple times
|
||||||
|
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
|
||||||
- Textual Inversion
|
- Textual Inversion
|
||||||
|
- have as many embeddings as you want and use any names you like for them
|
||||||
|
- use multiple embeddings with different numbers of vectors per token
|
||||||
|
- works with half precision floating point numbers
|
||||||
- Extras tab with:
|
- Extras tab with:
|
||||||
- GFPGAN, neural network that fixes faces
|
- GFPGAN, neural network that fixes faces
|
||||||
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
- CodeFormer, face restoration tool as an alternative to GFPGAN
|
||||||
- RealESRGAN, neural network upscaler
|
- RealESRGAN, neural network upscaler
|
||||||
- ESRGAN, neural network with a lot of third party models
|
- ESRGAN, neural network upscaler with a lot of third party models
|
||||||
- SwinIR, neural network upscaler
|
- SwinIR, neural network upscaler
|
||||||
- LDSR, Latent diffusion super resolution upscaling
|
- LDSR, Latent diffusion super resolution upscaling
|
||||||
- Resizing aspect ratio options
|
- Resizing aspect ratio options
|
||||||
- Sampling method selection
|
- Sampling method selection
|
||||||
- Interrupt processing at any time
|
- Interrupt processing at any time
|
||||||
- 4GB video card support
|
- 4GB video card support (also reports of 2GB working)
|
||||||
- Correct seeds for batches
|
- Correct seeds for batches
|
||||||
- Prompt length validation
|
- Prompt length validation
|
||||||
- Generation parameters added as text to PNG
|
- get length of prompt in tokens as you type
|
||||||
- Tab to view an existing picture's generation parameters
|
- get a warning after generation if some text was truncated
|
||||||
|
- Generation parameters
|
||||||
|
- parameters you used to generate images are saved with that image
|
||||||
|
- in PNG chunks for PNG, in EXIF for JPEG
|
||||||
|
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
|
||||||
|
- can be disabled in settings
|
||||||
- Settings page
|
- Settings page
|
||||||
- Running custom code from UI
|
- Running arbitrary python code from UI (must run with --allow-code to enable)
|
||||||
- Mouseover hints for most UI elements
|
- Mouseover hints for most UI elements
|
||||||
- Possible to change defaults/min/max/step values for UI elements via text config
|
- Possible to change defaults/min/max/step values for UI elements via text config
|
||||||
- Random artist button
|
- Random artist button
|
||||||
- Tiling support: UI checkbox to create images that can be tiled like textures
|
- Tiling support, a checkbox to create images that can be tiled like textures
|
||||||
- Progress bar and live image generation preview
|
- Progress bar and live image generation preview
|
||||||
- Negative prompt
|
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
|
||||||
- Styles
|
- Styles, a way to save parts of a prompt and easily apply them via dropdown later
|
||||||
- Variations
|
- Variations, a way to generate the same image but with tiny differences
|
||||||
- Seed resizing
|
- Seed resizing, a way to generate the same image but at a slightly different resolution
|
||||||
- CLIP interrogator
|
- CLIP interrogator, a button that tries to guess prompt from an image
|
||||||
- Prompt Editing
|
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
|
||||||
- Batch Processing
|
- Batch Processing, process a group of files using img2img
|
||||||
- Img2img Alternative
|
- Img2img Alternative
|
||||||
- Highres Fix
|
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
|
||||||
- LDSR Upscaling
|
- Reloading checkpoints on the fly
|
||||||
|
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
|
||||||
|
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
@ -101,6 +113,7 @@ The documentation was moved from this README over to the project's [wiki](https:
|
|||||||
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
- LDSR - https://github.com/Hafiidz/latent-diffusion
|
||||||
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
|
||||||
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
|
||||||
|
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
|
||||||
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
|
||||||
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
|
||||||
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
|
||||||
|
@ -15,7 +15,7 @@ titles = {
|
|||||||
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
|
||||||
"\u{1f3a8}": "Add a random artist to the prompt.",
|
"\u{1f3a8}": "Add a random artist to the prompt.",
|
||||||
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
|
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
|
||||||
"\uD83D\uDCC2": "Open images output directory",
|
"\u{1f4c2}": "Open images output directory",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
@ -47,6 +47,7 @@ titles = {
|
|||||||
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
|
||||||
|
|
||||||
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
|
||||||
|
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
|
||||||
|
|
||||||
"Tiling": "Produce an image that can be tiled.",
|
"Tiling": "Produce an image that can be tiled.",
|
||||||
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
|
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
|
||||||
|
@ -4,6 +4,21 @@ global_progressbars = {}
|
|||||||
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
|
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){
|
||||||
var progressbar = gradioApp().getElementById(id_progressbar)
|
var progressbar = gradioApp().getElementById(id_progressbar)
|
||||||
var interrupt = gradioApp().getElementById(id_interrupt)
|
var interrupt = gradioApp().getElementById(id_interrupt)
|
||||||
|
|
||||||
|
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
|
||||||
|
if(progressbar.innerText){
|
||||||
|
let newtitle = 'Stable Diffusion - ' + progressbar.innerText
|
||||||
|
if(document.title != newtitle){
|
||||||
|
document.title = newtitle;
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
let newtitle = 'Stable Diffusion'
|
||||||
|
if(document.title != newtitle){
|
||||||
|
document.title = newtitle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
|
||||||
global_progressbars[id_progressbar] = progressbar
|
global_progressbars[id_progressbar] = progressbar
|
||||||
|
|
||||||
@ -30,6 +45,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
|
|||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
|
||||||
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
|
||||||
|
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
|
||||||
})
|
})
|
||||||
|
|
||||||
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
|
||||||
|
8
javascript/textualInversion.js
Normal file
8
javascript/textualInversion.js
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
|
||||||
|
function start_training_textual_inversion(){
|
||||||
|
requestProgress('ti')
|
||||||
|
gradioApp().querySelector('#ti_error').innerHTML=''
|
||||||
|
|
||||||
|
return args_to_array(arguments)
|
||||||
|
}
|
@ -186,10 +186,12 @@ onUiUpdate(function(){
|
|||||||
if (!txt2img_textarea) {
|
if (!txt2img_textarea) {
|
||||||
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
|
||||||
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
|
||||||
|
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
|
||||||
}
|
}
|
||||||
if (!img2img_textarea) {
|
if (!img2img_textarea) {
|
||||||
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
|
||||||
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
|
||||||
|
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -197,8 +199,35 @@ let txt2img_textarea, img2img_textarea = undefined;
|
|||||||
let wait_time = 800
|
let wait_time = 800
|
||||||
let token_timeout;
|
let token_timeout;
|
||||||
|
|
||||||
|
function update_txt2img_tokens(...args) {
|
||||||
|
update_token_counter("txt2img_token_button")
|
||||||
|
if (args.length == 2)
|
||||||
|
return args[0]
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
function update_img2img_tokens(...args) {
|
||||||
|
update_token_counter("img2img_token_button")
|
||||||
|
if (args.length == 2)
|
||||||
|
return args[0]
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
function update_token_counter(button_id) {
|
function update_token_counter(button_id) {
|
||||||
if (token_timeout)
|
if (token_timeout)
|
||||||
clearTimeout(token_timeout);
|
clearTimeout(token_timeout);
|
||||||
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function submit_prompt(event, generate_button_id) {
|
||||||
|
if (event.altKey && event.keyCode === 13) {
|
||||||
|
event.preventDefault();
|
||||||
|
gradioApp().getElementById(generate_button_id).click();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function restart_reload(){
|
||||||
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
|
setTimeout(function(){location.reload()},2000)
|
||||||
|
}
|
||||||
|
15
launch.py
15
launch.py
@ -15,10 +15,11 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
|||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
@ -85,6 +86,15 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
# TODO clone into temporary dir and move if successful
|
# TODO clone into temporary dir and move if successful
|
||||||
|
|
||||||
if os.path.exists(dir):
|
if os.path.exists(dir):
|
||||||
|
if commithash is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
||||||
|
if current_hash == commithash:
|
||||||
|
return
|
||||||
|
|
||||||
|
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
||||||
|
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
||||||
@ -111,6 +121,9 @@ if not skip_torch_cuda_test:
|
|||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
|
||||||
|
if not is_installed("clip"):
|
||||||
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
|
@ -8,7 +8,7 @@ import torch
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import shared, modelloader
|
from modules import devices, modelloader
|
||||||
from modules.bsrgan_model_arch import RRDBNet
|
from modules.bsrgan_model_arch import RRDBNet
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
|||||||
model = self.load_model(selected_file)
|
model = self.load_model(selected_file)
|
||||||
if model is None:
|
if model is None:
|
||||||
return img
|
return img
|
||||||
model.to(shared.device)
|
model.to(devices.device_bsrgan)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
img = img[:, :, ::-1]
|
img = img[:, :, ::-1]
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img).float()
|
||||||
img = img.unsqueeze(0).to(shared.device)
|
img = img.unsqueeze(0).to(devices.device_bsrgan)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
output = model(img)
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
@ -67,10 +67,9 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
|||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
if not os.path.exists(filename) or filename is None:
|
||||||
print("Unable to load %s from %s" % (self.model_dir, filename))
|
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
|
||||||
return None
|
return None
|
||||||
print("Loading %s from %s" % (self.model_dir, filename))
|
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
|
||||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
|
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for k, v in model.named_parameters():
|
for k, v in model.named_parameters():
|
||||||
|
@ -76,7 +76,6 @@ class RRDBNet(nn.Module):
|
|||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||||
self.sf = sf
|
self.sf = sf
|
||||||
print([in_nc, out_nc, nf, nb, gc, sf])
|
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||||
|
@ -69,10 +69,14 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
self.net = net
|
self.net = net
|
||||||
self.face_helper = face_helper
|
self.face_helper = face_helper
|
||||||
self.net.to(devices.device_codeformer)
|
|
||||||
|
|
||||||
return net, face_helper
|
return net, face_helper
|
||||||
|
|
||||||
|
def send_model_to(self, device):
|
||||||
|
self.net.to(device)
|
||||||
|
self.face_helper.face_det.to(device)
|
||||||
|
self.face_helper.face_parse.to(device)
|
||||||
|
|
||||||
def restore(self, np_image, w=None):
|
def restore(self, np_image, w=None):
|
||||||
np_image = np_image[:, :, ::-1]
|
np_image = np_image[:, :, ::-1]
|
||||||
|
|
||||||
@ -82,6 +86,8 @@ def setup_model(dirname):
|
|||||||
if self.net is None or self.face_helper is None:
|
if self.net is None or self.face_helper is None:
|
||||||
return np_image
|
return np_image
|
||||||
|
|
||||||
|
self.send_model_to(devices.device_codeformer)
|
||||||
|
|
||||||
self.face_helper.clean_all()
|
self.face_helper.clean_all()
|
||||||
self.face_helper.read_image(np_image)
|
self.face_helper.read_image(np_image)
|
||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||||
@ -113,8 +119,10 @@ def setup_model(dirname):
|
|||||||
if original_resolution != restored_img.shape[0:2]:
|
if original_resolution != restored_img.shape[0:2]:
|
||||||
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
self.face_helper.clean_all()
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
if shared.opts.face_restoration_unload:
|
||||||
self.net.to(devices.cpu)
|
self.send_model_to(devices.cpu)
|
||||||
|
|
||||||
return restored_img
|
return restored_img
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
import contextlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
|
||||||
from modules import errors
|
from modules import errors
|
||||||
|
|
||||||
|
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
|
||||||
has_mps = getattr(torch, 'has_mps', False)
|
has_mps = getattr(torch, 'has_mps', False)
|
||||||
|
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
@ -32,10 +34,8 @@ def enable_tf32():
|
|||||||
|
|
||||||
errors.run(enable_tf32, "Enabling TF32")
|
errors.run(enable_tf32, "Enabling TF32")
|
||||||
|
|
||||||
|
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
|
||||||
device = get_optimal_device()
|
dtype = torch.float16
|
||||||
device_codeformer = cpu if has_mps else device
|
|
||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
|
||||||
@ -58,3 +58,11 @@ def randn_without_seed(shape):
|
|||||||
|
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def autocast():
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
return torch.autocast("cuda")
|
||||||
|
@ -6,8 +6,7 @@ from PIL import Image
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgam_model_arch as arch
|
import modules.esrgam_model_arch as arch
|
||||||
from modules import shared, modelloader, images
|
from modules import shared, modelloader, images, devices
|
||||||
from modules.devices import has_mps
|
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
@ -73,8 +72,8 @@ def fix_model_layers(crt_model, pretrained_net):
|
|||||||
class UpscalerESRGAN(Upscaler):
|
class UpscalerESRGAN(Upscaler):
|
||||||
def __init__(self, dirname):
|
def __init__(self, dirname):
|
||||||
self.name = "ESRGAN"
|
self.name = "ESRGAN"
|
||||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
|
||||||
self.model_name = "ESRGAN 4x"
|
self.model_name = "ESRGAN_4x"
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
self.user_path = dirname
|
self.user_path = dirname
|
||||||
self.model_path = os.path.join(models_path, self.name)
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
model = self.load_model(selected_model)
|
model = self.load_model(selected_model)
|
||||||
if model is None:
|
if model is None:
|
||||||
return img
|
return img
|
||||||
model.to(shared.device)
|
model.to(devices.device_esrgan)
|
||||||
img = esrgan_upscale(model, img)
|
img = esrgan_upscale(model, img)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
|
||||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||||
|
|
||||||
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
pretrained_net = fix_model_layers(crt_model, pretrained_net)
|
||||||
@ -127,7 +126,7 @@ def upscale_without_tiling(model, img):
|
|||||||
img = img[:, :, ::-1]
|
img = img[:, :, ::-1]
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
img = torch.from_numpy(img).float()
|
img = torch.from_numpy(img).float()
|
||||||
img = img.unsqueeze(0).to(shared.device)
|
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
output = model(img)
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||||||
|
|
||||||
outputs.append(image)
|
outputs.append(image)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, plaintext_to_html(info), ''
|
return outputs, plaintext_to_html(info), ''
|
||||||
|
|
||||||
|
|
||||||
@ -191,9 +193,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
|
|||||||
if save_as_half:
|
if save_as_half:
|
||||||
theta_0[key] = theta_0[key].half()
|
theta_0[key] = theta_0[key].half()
|
||||||
|
|
||||||
|
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
|
||||||
|
|
||||||
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
|
||||||
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
filename = filename if custom_name == '' else (custom_name + '.ckpt')
|
||||||
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
|
output_modelname = os.path.join(ckpt_dir, filename)
|
||||||
|
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
torch.save(primary_model, output_modelname)
|
torch.save(primary_model, output_modelname)
|
||||||
|
@ -21,7 +21,7 @@ def gfpgann():
|
|||||||
global loaded_gfpgan_model
|
global loaded_gfpgan_model
|
||||||
global model_path
|
global model_path
|
||||||
if loaded_gfpgan_model is not None:
|
if loaded_gfpgan_model is not None:
|
||||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||||
return loaded_gfpgan_model
|
return loaded_gfpgan_model
|
||||||
|
|
||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
@ -37,22 +37,32 @@ def gfpgann():
|
|||||||
print("Unable to load gfpgan model!")
|
print("Unable to load gfpgan model!")
|
||||||
return None
|
return None
|
||||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||||
model.gfpgan.to(shared.device)
|
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to(model, device):
|
||||||
|
model.gfpgan.to(device)
|
||||||
|
model.face_helper.face_det.to(device)
|
||||||
|
model.face_helper.face_parse.to(device)
|
||||||
|
|
||||||
|
|
||||||
def gfpgan_fix_faces(np_image):
|
def gfpgan_fix_faces(np_image):
|
||||||
model = gfpgann()
|
model = gfpgann()
|
||||||
if model is None:
|
if model is None:
|
||||||
return np_image
|
return np_image
|
||||||
|
|
||||||
|
send_model_to(model, devices.device_gfpgan)
|
||||||
|
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
np_image_bgr = np_image[:, :, ::-1]
|
||||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||||
|
|
||||||
|
model.face_helper.clean_all()
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
if shared.opts.face_restoration_unload:
|
||||||
model.gfpgan.to(devices.cpu)
|
send_model_to(model, devices.cpu)
|
||||||
|
|
||||||
return np_image
|
return np_image
|
||||||
|
|
||||||
@ -97,11 +107,7 @@ def setup_model(dirname):
|
|||||||
return "GFPGAN"
|
return "GFPGAN"
|
||||||
|
|
||||||
def restore(self, np_image):
|
def restore(self, np_image):
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
return gfpgan_fix_faces(np_image)
|
||||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
|
||||||
|
|
||||||
return np_image
|
|
||||||
|
|
||||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -213,17 +213,19 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
|
||||||
return im.resize((w, h), resample=LANCZOS)
|
return im.resize((w, h), resample=LANCZOS)
|
||||||
|
|
||||||
|
scale = max(w / im.width, h / im.height)
|
||||||
|
|
||||||
|
if scale > 1.0:
|
||||||
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
|
||||||
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
|
||||||
|
|
||||||
upscaler = upscalers[0]
|
upscaler = upscalers[0]
|
||||||
scale = max(w / im.width, h / im.height)
|
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
||||||
upscaled = upscaler.scaler.upscale(im, scale, upscaler.data_path)
|
|
||||||
|
|
||||||
if upscaled.width != w or upscaled.height != h:
|
if im.width != w or im.height != h:
|
||||||
upscaled = im.resize((w, h), resample=LANCZOS)
|
im = im.resize((w, h), resample=LANCZOS)
|
||||||
|
|
||||||
return upscaled
|
return im
|
||||||
|
|
||||||
if resize_mode == 0:
|
if resize_mode == 0:
|
||||||
res = resize(im, width, height)
|
res = resize(im, width, height)
|
||||||
@ -285,6 +287,25 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
x = x.replace("[seed]", str(seed))
|
x = x.replace("[seed]", str(seed))
|
||||||
|
|
||||||
|
if p is not None:
|
||||||
|
x = x.replace("[steps]", str(p.steps))
|
||||||
|
x = x.replace("[cfg]", str(p.cfg_scale))
|
||||||
|
x = x.replace("[width]", str(p.width))
|
||||||
|
x = x.replace("[height]", str(p.height))
|
||||||
|
|
||||||
|
#currently disabled if using the save button, will work otherwise
|
||||||
|
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
|
||||||
|
if hasattr(p, "styles"):
|
||||||
|
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False))
|
||||||
|
|
||||||
|
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||||
|
|
||||||
|
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||||
|
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||||
|
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
||||||
|
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
|
||||||
|
|
||||||
|
# Apply [prompt] at last. Because it may contain any replacement word.^M
|
||||||
if prompt is not None:
|
if prompt is not None:
|
||||||
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
x = x.replace("[prompt]", sanitize_filename_part(prompt))
|
||||||
if "[prompt_no_styles]" in x:
|
if "[prompt_no_styles]" in x:
|
||||||
@ -304,19 +325,6 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||||||
words = ["empty"]
|
words = ["empty"]
|
||||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||||
|
|
||||||
if p is not None:
|
|
||||||
x = x.replace("[steps]", str(p.steps))
|
|
||||||
x = x.replace("[cfg]", str(p.cfg_scale))
|
|
||||||
x = x.replace("[width]", str(p.width))
|
|
||||||
x = x.replace("[height]", str(p.height))
|
|
||||||
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
|
|
||||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
|
||||||
|
|
||||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
|
||||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
|
||||||
x = x.replace("[datetime]", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
|
|
||||||
x = x.replace("[job_timestamp]", shared.state.job_timestamp)
|
|
||||||
|
|
||||||
if cmd_opts.hide_ui_dir_config:
|
if cmd_opts.hide_ui_dir_config:
|
||||||
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)
|
||||||
|
|
||||||
@ -345,7 +353,7 @@ def get_next_sequence_number(path, basename):
|
|||||||
return result + 1
|
return result + 1
|
||||||
|
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||||
if short_filename or prompt is None or seed is None:
|
if short_filename or prompt is None or seed is None:
|
||||||
file_decoration = ""
|
file_decoration = ""
|
||||||
elif opts.save_to_dirs:
|
elif opts.save_to_dirs:
|
||||||
@ -369,10 +377,11 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
else:
|
else:
|
||||||
pnginfo = None
|
pnginfo = None
|
||||||
|
|
||||||
|
if save_to_dirs is None:
|
||||||
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
|
||||||
|
|
||||||
if save_to_dirs:
|
if save_to_dirs:
|
||||||
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
|
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt).strip('\\ /')
|
||||||
path = os.path.join(path, dirname)
|
path = os.path.join(path, dirname)
|
||||||
|
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
@ -423,4 +432,4 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
|
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
|
||||||
file.write(info + "\n")
|
file.write(info + "\n")
|
||||||
|
|
||||||
|
return fullfn
|
||||||
|
@ -23,8 +23,10 @@ def process_batch(p, input_dir, output_dir, args):
|
|||||||
|
|
||||||
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
|
||||||
|
|
||||||
|
save_normally = output_dir == ''
|
||||||
|
|
||||||
p.do_not_save_grid = True
|
p.do_not_save_grid = True
|
||||||
p.do_not_save_samples = True
|
p.do_not_save_samples = not save_normally
|
||||||
|
|
||||||
state.job_count = len(images) * p.n_iter
|
state.job_count = len(images) * p.n_iter
|
||||||
|
|
||||||
@ -48,6 +50,7 @@ def process_batch(p, input_dir, output_dir, args):
|
|||||||
left, right = os.path.splitext(filename)
|
left, right = os.path.splitext(filename)
|
||||||
filename = f"{left}-{n}{right}"
|
filename = f"{left}-{n}{right}"
|
||||||
|
|
||||||
|
if not save_normally:
|
||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
@ -103,6 +106,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||||||
inpaint_full_res_padding=inpaint_full_res_padding,
|
inpaint_full_res_padding=inpaint_full_res_padding,
|
||||||
inpainting_mask_invert=inpainting_mask_invert,
|
inpainting_mask_invert=inpainting_mask_invert,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
p.extra_generation_params["Mask blur"] = mask_blur
|
p.extra_generation_params["Mask blur"] = mask_blur
|
||||||
@ -124,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
|
|||||||
if opts.samples_log_stdout:
|
if opts.samples_log_stdout:
|
||||||
print(generation_info_js)
|
print(generation_info_js)
|
||||||
|
|
||||||
|
if opts.do_not_show_images:
|
||||||
|
processed.images = []
|
||||||
|
|
||||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||||
|
@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
|
|||||||
|
|
||||||
re_topn = re.compile(r"\.top(\d+)\.")
|
re_topn = re.compile(r"\.top(\d+)\.")
|
||||||
|
|
||||||
|
|
||||||
class InterrogateModels:
|
class InterrogateModels:
|
||||||
blip_model = None
|
blip_model = None
|
||||||
clip_model = None
|
clip_model = None
|
||||||
|
@ -22,9 +22,21 @@ class UpscalerLDSR(Upscaler):
|
|||||||
self.scalers = [scaler_data]
|
self.scalers = [scaler_data]
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
|
# Remove incorrect project.yaml file if too big
|
||||||
|
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||||
|
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||||
|
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||||
|
if os.path.exists(yaml_path):
|
||||||
|
statinfo = os.stat(yaml_path)
|
||||||
|
if statinfo.st_size >= 10485760:
|
||||||
|
print("Removing invalid LDSR YAML file.")
|
||||||
|
os.remove(yaml_path)
|
||||||
|
if os.path.exists(old_model_path):
|
||||||
|
print("Renaming model from model.pth to model.ckpt")
|
||||||
|
os.rename(old_model_path, new_model_path)
|
||||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||||
file_name="model.pth", progress=True)
|
file_name="model.ckpt", progress=True)
|
||||||
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
||||||
file_name="project.yaml", progress=True)
|
file_name="project.yaml", progress=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -41,5 +53,4 @@ class UpscalerLDSR(Upscaler):
|
|||||||
print("NO LDSR!")
|
print("NO LDSR!")
|
||||||
return img
|
return img
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
pre_scale = shared.opts.ldsr_pre_down
|
|
||||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
|
@ -98,9 +98,7 @@ class LDSR:
|
|||||||
im_og = image
|
im_og = image
|
||||||
width_og, height_og = im_og.size
|
width_og, height_og = im_og.size
|
||||||
# If we can adjust the max upscale size, then the 4 below should be our variable
|
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||||
print("Foo")
|
|
||||||
down_sample_rate = target_scale / 4
|
down_sample_rate = target_scale / 4
|
||||||
print(f"Downsample rate is {down_sample_rate}")
|
|
||||||
wd = width_og * down_sample_rate
|
wd = width_og * down_sample_rate
|
||||||
hd = height_og * down_sample_rate
|
hd = height_og * down_sample_rate
|
||||||
width_downsampled_pre = int(wd)
|
width_downsampled_pre = int(wd)
|
||||||
@ -111,7 +109,7 @@ class LDSR:
|
|||||||
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
else:
|
else:
|
||||||
print(f"Down sample rate is 1 from {target_scale} / 4")
|
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||||
|
|
||||||
sample = logs["sample"]
|
sample = logs["sample"]
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
|
import glob
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.upscaler import Upscaler
|
from modules.upscaler import Upscaler
|
||||||
from modules.paths import script_path, models_path
|
from modules.paths import script_path, models_path
|
||||||
@ -41,8 +41,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
|
|
||||||
for place in places:
|
for place in places:
|
||||||
if os.path.exists(place):
|
if os.path.exists(place):
|
||||||
for file in os.listdir(place):
|
for file in glob.iglob(place + '**/**', recursive=True):
|
||||||
full_path = os.path.join(place, file)
|
full_path = file
|
||||||
if os.path.isdir(full_path):
|
if os.path.isdir(full_path):
|
||||||
continue
|
continue
|
||||||
if len(ext_filter) != 0:
|
if len(ext_filter) != 0:
|
||||||
@ -120,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||||||
|
|
||||||
|
|
||||||
def load_upscalers():
|
def load_upscalers():
|
||||||
|
sd = shared.script_path
|
||||||
|
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||||
|
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||||
|
modules_dir = os.path.join(sd, "modules")
|
||||||
|
for file in os.listdir(modules_dir):
|
||||||
|
if "_model.py" in file:
|
||||||
|
model_name = file.replace("_model.py", "")
|
||||||
|
full_model = f"modules.{model_name}_model"
|
||||||
|
try:
|
||||||
|
importlib.import_module(full_model)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
datas = []
|
datas = []
|
||||||
|
c_o = vars(shared.cmd_opts)
|
||||||
for cls in Upscaler.__subclasses__():
|
for cls in Upscaler.__subclasses__():
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
module_name = cls.__module__
|
module_name = cls.__module__
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
class_ = getattr(module, name)
|
class_ = getattr(module, name)
|
||||||
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||||
opt_string = None
|
opt_string = None
|
||||||
try:
|
try:
|
||||||
opt_string = shared.opts.__getattr__(cmd_name)
|
if cmd_name in c_o:
|
||||||
|
opt_string = c_o[cmd_name]
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
scaler = class_(opt_string)
|
scaler = class_(opt_string)
|
||||||
|
@ -20,7 +20,6 @@ path_dirs = [
|
|||||||
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
|
||||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||||
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
|
|
||||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -12,9 +11,8 @@ import cv2
|
|||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
@ -56,7 +54,7 @@ class StableDiffusionProcessing:
|
|||||||
self.prompt: str = prompt
|
self.prompt: str = prompt
|
||||||
self.prompt_for_display: str = None
|
self.prompt_for_display: str = None
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
self.negative_prompt: str = (negative_prompt or "")
|
||||||
self.styles: str = styles
|
self.styles: list = styles or []
|
||||||
self.seed: int = seed
|
self.seed: int = seed
|
||||||
self.subseed: int = subseed
|
self.subseed: int = subseed
|
||||||
self.subseed_strength: float = subseed_strength
|
self.subseed_strength: float = subseed_strength
|
||||||
@ -79,7 +77,7 @@ class StableDiffusionProcessing:
|
|||||||
self.paste_to = None
|
self.paste_to = None
|
||||||
self.color_corrections = None
|
self.color_corrections = None
|
||||||
self.denoising_strength: float = 0
|
self.denoising_strength: float = 0
|
||||||
|
self.sampler_noise_scheduler_override = None
|
||||||
self.ddim_discretize = opts.ddim_discretize
|
self.ddim_discretize = opts.ddim_discretize
|
||||||
self.s_churn = opts.s_churn
|
self.s_churn = opts.s_churn
|
||||||
self.s_tmin = opts.s_tmin
|
self.s_tmin = opts.s_tmin
|
||||||
@ -111,7 +109,7 @@ class Processed:
|
|||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_index = p.sampler_index
|
self.sampler_index = p.sampler_index
|
||||||
self.sampler = samplers[p.sampler_index].name
|
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
||||||
self.cfg_scale = p.cfg_scale
|
self.cfg_scale = p.cfg_scale
|
||||||
self.steps = p.steps
|
self.steps = p.steps
|
||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
@ -130,7 +128,7 @@ class Processed:
|
|||||||
self.s_tmin = p.s_tmin
|
self.s_tmin = p.s_tmin
|
||||||
self.s_tmax = p.s_tmax
|
self.s_tmax = p.s_tmax
|
||||||
self.s_noise = p.s_noise
|
self.s_noise = p.s_noise
|
||||||
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||||
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
|
||||||
@ -249,9 +247,16 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_fixed_seed(seed):
|
||||||
|
if seed is None or seed == '' or seed == -1:
|
||||||
|
return int(random.randrange(4294967294))
|
||||||
|
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
def fix_seed(p):
|
def fix_seed(p):
|
||||||
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == '' or p.seed == -1 else p.seed
|
p.seed = get_fixed_seed(p.seed)
|
||||||
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == '' or p.subseed == -1 else p.subseed
|
p.subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
||||||
@ -259,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": samplers[p.sampler_index].name,
|
"Sampler": sd_samplers.samplers[p.sampler_index].name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
@ -271,7 +276,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
|
|||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
@ -293,9 +298,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
fix_seed(p)
|
seed = get_fixed_seed(p.seed)
|
||||||
|
subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
if p.outpath_samples is not None:
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||||
|
|
||||||
|
if p.outpath_grids is not None:
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
@ -309,27 +318,27 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
||||||
|
|
||||||
if type(p.seed) == list:
|
if type(seed) == list:
|
||||||
all_seeds = p.seed
|
all_seeds = seed
|
||||||
else:
|
else:
|
||||||
all_seeds = [int(p.seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
||||||
|
|
||||||
if type(p.subseed) == list:
|
if type(subseed) == list:
|
||||||
all_subseeds = p.subseed
|
all_subseeds = subseed
|
||||||
else:
|
else:
|
||||||
all_subseeds = [int(p.subseed) + x for x in range(len(all_prompts))]
|
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.embeddings_dir):
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
infotexts = []
|
infotexts = []
|
||||||
output_images = []
|
output_images = []
|
||||||
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
|
||||||
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
|
with torch.no_grad():
|
||||||
with torch.no_grad(), precision_scope("cuda"), ema_scope():
|
with devices.autocast():
|
||||||
p.init(all_prompts, all_seeds, all_subseeds)
|
p.init(all_prompts, all_seeds, all_subseeds)
|
||||||
|
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
@ -348,8 +357,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
||||||
#c = p.sd_model.get_learned_conditioning(prompts)
|
#c = p.sd_model.get_learned_conditioning(prompts)
|
||||||
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
|
with devices.autocast():
|
||||||
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
|
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
||||||
|
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
@ -358,16 +368,27 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
||||||
|
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
|
|
||||||
# if we are interruped, sample returns just noise
|
# if we are interruped, sample returns just noise
|
||||||
# use the image collected previously in sampler loop
|
# use the image collected previously in sampler loop
|
||||||
samples_ddim = shared.state.current_latent
|
samples_ddim = shared.state.current_latent
|
||||||
|
|
||||||
|
samples_ddim = samples_ddim.to(devices.dtype)
|
||||||
|
|
||||||
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
del samples_ddim
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
if opts.filter_nsfw:
|
if opts.filter_nsfw:
|
||||||
import modules.safety as safety
|
import modules.safety as safety
|
||||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
||||||
@ -383,6 +404,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
|
|
||||||
@ -408,9 +430,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.samples_save and not p.do_not_save_samples:
|
if opts.samples_save and not p.do_not_save_samples:
|
||||||
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
||||||
|
|
||||||
infotexts.append(infotext(n, i))
|
text = infotext(n, i)
|
||||||
|
infotexts.append(text)
|
||||||
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
|
del x_samples_ddim
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
state.nextjob()
|
state.nextjob()
|
||||||
|
|
||||||
p.color_corrections = None
|
p.color_corrections = None
|
||||||
@ -421,7 +449,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
grid = images.image_grid(output_images, p.batch_size)
|
grid = images.image_grid(output_images, p.batch_size)
|
||||||
|
|
||||||
if opts.return_grid:
|
if opts.return_grid:
|
||||||
infotexts.insert(0, infotext())
|
text = infotext()
|
||||||
|
infotexts.insert(0, text)
|
||||||
|
grid.info["parameters"] = text
|
||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
index_of_first_image = 1
|
index_of_first_image = 1
|
||||||
|
|
||||||
@ -462,7 +492,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
self.firstphase_height_truncated = int(scale * self.height)
|
self.firstphase_height_truncated = int(scale * self.height)
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
||||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
@ -505,7 +535,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
shared.state.nextjob()
|
shared.state.nextjob()
|
||||||
|
|
||||||
self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
||||||
|
|
||||||
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
||||||
|
|
||||||
# GC now before running the next img2img to prevent running out of memory
|
# GC now before running the next img2img to prevent running out of memory
|
||||||
@ -540,7 +571,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
self.nmask = None
|
self.nmask = None
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
if self.image_mask is not None:
|
if self.image_mask is not None:
|
||||||
@ -647,4 +678,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
samples = samples * self.nmask + self.init_latent * self.mask
|
samples = samples * self.nmask + self.init_latent * self.mask
|
||||||
|
|
||||||
|
del x
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
@ -1,19 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import torch
|
from typing import List
|
||||||
|
import lark
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
re_prompt = re.compile(r'''
|
|
||||||
(.*?)
|
|
||||||
\[
|
|
||||||
([^]:]+):
|
|
||||||
(?:([^]:]*):)?
|
|
||||||
([0-9]*\.?[0-9]+)
|
|
||||||
]
|
|
||||||
|
|
|
||||||
(.+)
|
|
||||||
''', re.X)
|
|
||||||
|
|
||||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
||||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||||
@ -23,71 +11,112 @@ re_prompt = re.compile(r'''
|
|||||||
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
# [75, 'fantasy landscape with a lake and an oak in background masterful']
|
||||||
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
|
||||||
|
|
||||||
|
schedule_parser = lark.Lark(r"""
|
||||||
|
!start: (prompt | /[][():]/+)*
|
||||||
|
prompt: (emphasized | scheduled | plain | WHITESPACE)*
|
||||||
|
!emphasized: "(" prompt ")"
|
||||||
|
| "(" prompt ":" prompt ")"
|
||||||
|
| "[" prompt "]"
|
||||||
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
||||||
|
WHITESPACE: /\s+/
|
||||||
|
plain: /([^\\\[\]():]|\\.)+/
|
||||||
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
|
""")
|
||||||
|
|
||||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
res = []
|
"""
|
||||||
cache = {}
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||||
|
>>> g("test")
|
||||||
|
[[10, 'test']]
|
||||||
|
>>> g("a [b:3]")
|
||||||
|
[[3, 'a '], [10, 'a b']]
|
||||||
|
>>> g("a [b: 3]")
|
||||||
|
[[3, 'a '], [10, 'a b']]
|
||||||
|
>>> g("a [[[b]]:2]")
|
||||||
|
[[2, 'a '], [10, 'a [[b]]']]
|
||||||
|
>>> g("[(a:2):3]")
|
||||||
|
[[3, ''], [10, '(a:2)']]
|
||||||
|
>>> g("a [b : c : 1] d")
|
||||||
|
[[1, 'a b d'], [10, 'a c d']]
|
||||||
|
>>> g("a[b:[c:d:2]:1]e")
|
||||||
|
[[1, 'abe'], [2, 'ace'], [10, 'ade']]
|
||||||
|
>>> g("a [unbalanced")
|
||||||
|
[[10, 'a [unbalanced']]
|
||||||
|
>>> g("a [b:.5] c")
|
||||||
|
[[5, 'a c'], [10, 'a b c']]
|
||||||
|
>>> g("a [{b|d{:.5] c") # not handling this right now
|
||||||
|
[[5, 'a c'], [10, 'a {b|d{ c']]
|
||||||
|
>>> g("((a][:b:c [d:3]")
|
||||||
|
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||||
|
"""
|
||||||
|
|
||||||
for prompt in prompts:
|
def collect_steps(steps, tree):
|
||||||
prompt_schedule: list[list[str | int]] = [[steps, ""]]
|
l = [steps]
|
||||||
|
class CollectSteps(lark.Visitor):
|
||||||
|
def scheduled(self, tree):
|
||||||
|
tree.children[-1] = float(tree.children[-1])
|
||||||
|
if tree.children[-1] < 1:
|
||||||
|
tree.children[-1] *= steps
|
||||||
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
|
l.append(tree.children[-1])
|
||||||
|
CollectSteps().visit(tree)
|
||||||
|
return sorted(set(l))
|
||||||
|
|
||||||
cached = cache.get(prompt, None)
|
def at_step(step, tree):
|
||||||
if cached is not None:
|
class AtStep(lark.Transformer):
|
||||||
res.append(cached)
|
def scheduled(self, args):
|
||||||
continue
|
before, after, _, when = args
|
||||||
|
yield before or () if step <= when else after
|
||||||
|
def start(self, args):
|
||||||
|
def flatten(x):
|
||||||
|
if type(x) == str:
|
||||||
|
yield x
|
||||||
|
else:
|
||||||
|
for gen in x:
|
||||||
|
yield from flatten(gen)
|
||||||
|
return ''.join(flatten(args))
|
||||||
|
def plain(self, args):
|
||||||
|
yield args[0].value
|
||||||
|
def __default__(self, data, children, meta):
|
||||||
|
for child in children:
|
||||||
|
yield from child
|
||||||
|
return AtStep().transform(tree)
|
||||||
|
|
||||||
for m in re_prompt.finditer(prompt):
|
def get_schedule(prompt):
|
||||||
plaintext = m.group(1) if m.group(5) is None else m.group(5)
|
try:
|
||||||
concept_from = m.group(2)
|
tree = schedule_parser.parse(prompt)
|
||||||
concept_to = m.group(3)
|
except lark.exceptions.LarkError as e:
|
||||||
if concept_to is None:
|
if 0:
|
||||||
concept_to = concept_from
|
import traceback
|
||||||
concept_from = ""
|
traceback.print_exc()
|
||||||
swap_position = float(m.group(4)) if m.group(4) is not None else None
|
return [[steps, prompt]]
|
||||||
|
return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
|
||||||
|
|
||||||
if swap_position is not None:
|
promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
|
||||||
if swap_position < 1:
|
return [promptdict[prompt] for prompt in prompts]
|
||||||
swap_position = swap_position * steps
|
|
||||||
swap_position = int(min(swap_position, steps))
|
|
||||||
|
|
||||||
swap_index = None
|
|
||||||
found_exact_index = False
|
|
||||||
for i in range(len(prompt_schedule)):
|
|
||||||
end_step = prompt_schedule[i][0]
|
|
||||||
prompt_schedule[i][1] += plaintext
|
|
||||||
|
|
||||||
if swap_position is not None and swap_index is None:
|
|
||||||
if swap_position == end_step:
|
|
||||||
swap_index = i
|
|
||||||
found_exact_index = True
|
|
||||||
|
|
||||||
if swap_position < end_step:
|
|
||||||
swap_index = i
|
|
||||||
|
|
||||||
if swap_index is not None:
|
|
||||||
if not found_exact_index:
|
|
||||||
prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
|
|
||||||
|
|
||||||
for i in range(len(prompt_schedule)):
|
|
||||||
end_step = prompt_schedule[i][0]
|
|
||||||
must_replace = swap_position < end_step
|
|
||||||
|
|
||||||
prompt_schedule[i][1] += concept_to if must_replace else concept_from
|
|
||||||
|
|
||||||
res.append(prompt_schedule)
|
|
||||||
cache[prompt] = prompt_schedule
|
|
||||||
#for t in prompt_schedule:
|
|
||||||
# print(t)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||||
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
|
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(prompts, steps):
|
def get_learned_conditioning(model, prompts, steps):
|
||||||
|
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
|
||||||
|
and the sampling step at which this condition is to be replaced by the next one.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
|
||||||
|
|
||||||
|
Output:
|
||||||
|
[
|
||||||
|
[
|
||||||
|
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
|
||||||
|
],
|
||||||
|
[
|
||||||
|
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
|
||||||
|
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
|
||||||
|
]
|
||||||
|
]
|
||||||
|
"""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||||
@ -101,7 +130,7 @@ def get_learned_conditioning(prompts, steps):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
texts = [x[1] for x in prompt_schedule]
|
texts = [x[1] for x in prompt_schedule]
|
||||||
conds = shared.sd_model.get_learned_conditioning(texts)
|
conds = model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
cond_schedule = []
|
cond_schedule = []
|
||||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
||||||
@ -110,22 +139,109 @@ def get_learned_conditioning(prompts, steps):
|
|||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
res.append(cond_schedule)
|
res.append(cond_schedule)
|
||||||
|
|
||||||
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
|
return res
|
||||||
|
|
||||||
|
|
||||||
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
re_AND = re.compile(r"\bAND\b")
|
||||||
res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
|
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
|
||||||
for i, cond_schedule in enumerate(c.schedules):
|
|
||||||
|
def get_multicond_prompt_list(prompts):
|
||||||
|
res_indexes = []
|
||||||
|
|
||||||
|
prompt_flat_list = []
|
||||||
|
prompt_indexes = {}
|
||||||
|
|
||||||
|
for prompt in prompts:
|
||||||
|
subprompts = re_AND.split(prompt)
|
||||||
|
|
||||||
|
indexes = []
|
||||||
|
for subprompt in subprompts:
|
||||||
|
match = re_weight.search(subprompt)
|
||||||
|
|
||||||
|
text, weight = match.groups() if match is not None else (subprompt, 1.0)
|
||||||
|
|
||||||
|
weight = float(weight) if weight is not None else 1.0
|
||||||
|
|
||||||
|
index = prompt_indexes.get(text, None)
|
||||||
|
if index is None:
|
||||||
|
index = len(prompt_flat_list)
|
||||||
|
prompt_flat_list.append(text)
|
||||||
|
prompt_indexes[text] = index
|
||||||
|
|
||||||
|
indexes.append((index, weight))
|
||||||
|
|
||||||
|
res_indexes.append(indexes)
|
||||||
|
|
||||||
|
return res_indexes, prompt_flat_list, prompt_indexes
|
||||||
|
|
||||||
|
|
||||||
|
class ComposableScheduledPromptConditioning:
|
||||||
|
def __init__(self, schedules, weight=1.0):
|
||||||
|
self.schedules: List[ScheduledPromptConditioning] = schedules
|
||||||
|
self.weight: float = weight
|
||||||
|
|
||||||
|
|
||||||
|
class MulticondLearnedConditioning:
|
||||||
|
def __init__(self, shape, batch):
|
||||||
|
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||||
|
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||||
|
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||||
|
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||||
|
|
||||||
|
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
||||||
|
"""
|
||||||
|
|
||||||
|
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||||
|
|
||||||
|
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for indexes in res_indexes:
|
||||||
|
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||||
|
|
||||||
|
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
||||||
|
param = c[0][0].cond
|
||||||
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for curret_index, (end_at, cond) in enumerate(cond_schedule):
|
for current, (end_at, cond) in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= end_at:
|
||||||
target_index = curret_index
|
target_index = current
|
||||||
break
|
break
|
||||||
res[i] = cond_schedule[target_index].cond
|
res[i] = cond_schedule[target_index].cond
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||||
|
param = c.batch[0][0].schedules[0].cond
|
||||||
|
|
||||||
|
tensors = []
|
||||||
|
conds_list = []
|
||||||
|
|
||||||
|
for batch_no, composable_prompts in enumerate(c.batch):
|
||||||
|
conds_for_batch = []
|
||||||
|
|
||||||
|
for cond_index, composable_prompt in enumerate(composable_prompts):
|
||||||
|
target_index = 0
|
||||||
|
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
||||||
|
if current_step <= end_at:
|
||||||
|
target_index = current
|
||||||
|
break
|
||||||
|
|
||||||
|
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
||||||
|
tensors.append(composable_prompt.schedules[target_index].cond)
|
||||||
|
|
||||||
|
conds_list.append(conds_for_batch)
|
||||||
|
|
||||||
|
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||||
|
|
||||||
|
|
||||||
re_attention = re.compile(r"""
|
re_attention = re.compile(r"""
|
||||||
\\\(|
|
\\\(|
|
||||||
\\\)|
|
\\\)|
|
||||||
@ -157,14 +273,18 @@ def parse_prompt_attention(text):
|
|||||||
\\ - literal character '\'
|
\\ - literal character '\'
|
||||||
anything else - just text
|
anything else - just text
|
||||||
|
|
||||||
Example:
|
>>> parse_prompt_attention('normal text')
|
||||||
|
[['normal text', 1.0]]
|
||||||
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
|
>>> parse_prompt_attention('an (important) word')
|
||||||
|
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||||
produces:
|
>>> parse_prompt_attention('(unbalanced')
|
||||||
|
[['unbalanced', 1.1]]
|
||||||
[
|
>>> parse_prompt_attention('\(literal\]')
|
||||||
['a ', 1.0],
|
[['(literal]', 1.0]]
|
||||||
|
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||||
|
[['unnecessaryparens', 1.1]]
|
||||||
|
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||||
|
[['a ', 1.0],
|
||||||
['house', 1.5730000000000004],
|
['house', 1.5730000000000004],
|
||||||
[' ', 1.1],
|
[' ', 1.1],
|
||||||
['on', 1.0],
|
['on', 1.0],
|
||||||
@ -172,8 +292,7 @@ def parse_prompt_attention(text):
|
|||||||
['hill', 0.55],
|
['hill', 0.55],
|
||||||
[', sun, ', 1.1],
|
[', sun, ', 1.1],
|
||||||
['sky', 1.4641000000000006],
|
['sky', 1.4641000000000006],
|
||||||
['.', 1.1]
|
['.', 1.1]]
|
||||||
]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
@ -215,4 +334,19 @@ def parse_prompt_attention(text):
|
|||||||
if len(res) == 0:
|
if len(res) == 0:
|
||||||
res = [["", 1.0]]
|
res = [["", 1.0]]
|
||||||
|
|
||||||
|
# merge runs of identical weights
|
||||||
|
i = 0
|
||||||
|
while i + 1 < len(res):
|
||||||
|
if res[i][1] == res[i + 1][1]:
|
||||||
|
res[i][0] += res[i + 1][0]
|
||||||
|
res.pop(i + 1)
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import doctest
|
||||||
|
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
|
||||||
|
else:
|
||||||
|
import torch # doctest faster
|
||||||
|
@ -162,6 +162,40 @@ class ScriptRunner:
|
|||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
def reload_sources(self):
|
||||||
|
for si, script in list(enumerate(self.scripts)):
|
||||||
|
with open(script.filename, "r", encoding="utf8") as file:
|
||||||
|
args_from = script.args_from
|
||||||
|
args_to = script.args_to
|
||||||
|
filename = script.filename
|
||||||
|
text = file.read()
|
||||||
|
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
|
compiled = compile(text, filename, 'exec')
|
||||||
|
module = ModuleType(script.filename)
|
||||||
|
exec(compiled, module.__dict__)
|
||||||
|
|
||||||
|
for key, script_class in module.__dict__.items():
|
||||||
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
|
self.scripts[si] = script_class()
|
||||||
|
self.scripts[si].filename = filename
|
||||||
|
self.scripts[si].args_from = args_from
|
||||||
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
scripts_txt2img = ScriptRunner()
|
||||||
|
scripts_img2img = ScriptRunner()
|
||||||
|
|
||||||
|
def reload_script_body_only():
|
||||||
|
scripts_txt2img.reload_sources()
|
||||||
|
scripts_img2img.reload_sources()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_scripts(basedir):
|
||||||
|
global scripts_txt2img, scripts_img2img
|
||||||
|
|
||||||
|
scripts_data.clear()
|
||||||
|
load_scripts(basedir)
|
||||||
|
|
||||||
scripts_txt2img = ScriptRunner()
|
scripts_txt2img = ScriptRunner()
|
||||||
scripts_img2img = ScriptRunner()
|
scripts_img2img = ScriptRunner()
|
||||||
|
90
modules/scunet_model.py
Normal file
90
modules/scunet_model.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
import modules.upscaler
|
||||||
|
from modules import devices, modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
from modules.scunet_model_arch import SCUNet as net
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "ScuNET"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.model_name = "ScuNET GAN"
|
||||||
|
self.model_name2 = "ScuNET PSNR"
|
||||||
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
|
||||||
|
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
model_paths = self.find_models(ext_filter=[".pth"])
|
||||||
|
scalers = []
|
||||||
|
add_model2 = True
|
||||||
|
for file in model_paths:
|
||||||
|
if "http" in file:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(file)
|
||||||
|
if name == self.model_name2 or file == self.model_url2:
|
||||||
|
add_model2 = False
|
||||||
|
try:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
if add_model2:
|
||||||
|
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
|
||||||
|
scalers.append(scaler_data2)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_file):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
model = self.load_model(selected_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
|
||||||
|
device = devices.device_scunet
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
img = img.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(img)
|
||||||
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output = 255. * np.moveaxis(output, 0, 2)
|
||||||
|
output = output.astype(np.uint8)
|
||||||
|
output = output[:, :, ::-1]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return PIL.Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
|
device = devices.device_scunet
|
||||||
|
if "http" in path:
|
||||||
|
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||||
|
progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
||||||
|
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
|
model.eval()
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
v.requires_grad = False
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
265
modules/scunet_model_arch.py
Normal file
265
modules/scunet_model_arch.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from einops.layers.torch import Rearrange
|
||||||
|
from timm.models.layers import trunc_normal_, DropPath
|
||||||
|
|
||||||
|
|
||||||
|
class WMSA(nn.Module):
|
||||||
|
""" Self-attention module in Swin Transformer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
|
||||||
|
super(WMSA, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.scale = self.head_dim ** -0.5
|
||||||
|
self.n_heads = input_dim // head_dim
|
||||||
|
self.window_size = window_size
|
||||||
|
self.type = type
|
||||||
|
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
|
||||||
|
|
||||||
|
self.relative_position_params = nn.Parameter(
|
||||||
|
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
|
||||||
|
|
||||||
|
self.linear = nn.Linear(self.input_dim, self.output_dim)
|
||||||
|
|
||||||
|
trunc_normal_(self.relative_position_params, std=.02)
|
||||||
|
self.relative_position_params = torch.nn.Parameter(
|
||||||
|
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
|
||||||
|
2).transpose(
|
||||||
|
0, 1))
|
||||||
|
|
||||||
|
def generate_mask(self, h, w, p, shift):
|
||||||
|
""" generating the mask of SW-MSA
|
||||||
|
Args:
|
||||||
|
shift: shift parameters in CyclicShift.
|
||||||
|
Returns:
|
||||||
|
attn_mask: should be (1 1 w p p),
|
||||||
|
"""
|
||||||
|
# supporting square.
|
||||||
|
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
|
||||||
|
if self.type == 'W':
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
s = p - shift
|
||||||
|
attn_mask[-1, :, :s, :, s:, :] = True
|
||||||
|
attn_mask[-1, :, s:, :, :s, :] = True
|
||||||
|
attn_mask[:, -1, :, :s, :, s:] = True
|
||||||
|
attn_mask[:, -1, :, s:, :, :s] = True
|
||||||
|
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
""" Forward pass of Window Multi-head Self-attention module.
|
||||||
|
Args:
|
||||||
|
x: input tensor with shape of [b h w c];
|
||||||
|
attn_mask: attention mask, fill -inf where the value is True;
|
||||||
|
Returns:
|
||||||
|
output: tensor shape [b h w c]
|
||||||
|
"""
|
||||||
|
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
|
h_windows = x.size(1)
|
||||||
|
w_windows = x.size(2)
|
||||||
|
# square validation
|
||||||
|
# assert h_windows == w_windows
|
||||||
|
|
||||||
|
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
|
||||||
|
qkv = self.embedding_layer(x)
|
||||||
|
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
|
||||||
|
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
|
||||||
|
# Adding learnable relative embedding
|
||||||
|
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
|
||||||
|
# Using Attn Mask to distinguish different subwindows.
|
||||||
|
if self.type != 'W':
|
||||||
|
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
|
||||||
|
sim = sim.masked_fill_(attn_mask, float("-inf"))
|
||||||
|
|
||||||
|
probs = nn.functional.softmax(sim, dim=-1)
|
||||||
|
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
|
||||||
|
output = rearrange(output, 'h b w p c -> b w p (h c)')
|
||||||
|
output = self.linear(output)
|
||||||
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
|
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
||||||
|
dims=(1, 2))
|
||||||
|
return output
|
||||||
|
|
||||||
|
def relative_embedding(self):
|
||||||
|
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
|
||||||
|
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
|
||||||
|
# negative is allowed
|
||||||
|
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
|
||||||
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||||
|
""" SwinTransformer Block
|
||||||
|
"""
|
||||||
|
super(Block, self).__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
assert type in ['W', 'SW']
|
||||||
|
self.type = type
|
||||||
|
if input_resolution <= window_size:
|
||||||
|
self.type = 'W'
|
||||||
|
|
||||||
|
self.ln1 = nn.LayerNorm(input_dim)
|
||||||
|
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
|
||||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
self.ln2 = nn.LayerNorm(input_dim)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(input_dim, 4 * input_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(4 * input_dim, output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x + self.drop_path(self.msa(self.ln1(x)))
|
||||||
|
x = x + self.drop_path(self.mlp(self.ln2(x)))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvTransBlock(nn.Module):
|
||||||
|
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
|
||||||
|
""" SwinTransformer and Conv Block
|
||||||
|
"""
|
||||||
|
super(ConvTransBlock, self).__init__()
|
||||||
|
self.conv_dim = conv_dim
|
||||||
|
self.trans_dim = trans_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.window_size = window_size
|
||||||
|
self.drop_path = drop_path
|
||||||
|
self.type = type
|
||||||
|
self.input_resolution = input_resolution
|
||||||
|
|
||||||
|
assert self.type in ['W', 'SW']
|
||||||
|
if self.input_resolution <= self.window_size:
|
||||||
|
self.type = 'W'
|
||||||
|
|
||||||
|
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
|
||||||
|
self.type, self.input_resolution)
|
||||||
|
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||||
|
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
|
||||||
|
|
||||||
|
self.conv_block = nn.Sequential(
|
||||||
|
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
|
||||||
|
conv_x = self.conv_block(conv_x) + conv_x
|
||||||
|
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
|
||||||
|
trans_x = self.trans_block(trans_x)
|
||||||
|
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
|
||||||
|
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
|
||||||
|
x = x + res
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SCUNet(nn.Module):
    """U-shaped network built from ``ConvTransBlock`` stages.

    Layout: a 3x3 head convolution, three encoder stages each ending in a
    stride-2 convolution, a body stage, three decoder stages each starting
    with a stride-2 transposed convolution, and a 3x3 tail convolution.
    Encoder outputs are added to the decoder inputs at matching depth.
    """

    def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
        """Build all stages.

        Args:
            in_nc: number of input (and output) image channels.
            config: number of ``ConvTransBlock`` modules in each of the seven
                stages; defaults to two per stage.
            dim: channel count after the head convolution.
            drop_path_rate: upper end of the per-block drop-path ramp.
            input_resolution: resolution handed to the full-scale stages; it is
                halved for each deeper level.
        """
        super(SCUNet, self).__init__()
        if config is None:
            config = [2, 2, 2, 2, 2, 2, 2]
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer, ramping linearly over all blocks of all stages
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(config))]

        def stage(index, half_channels, resolution):
            # Blocks of stage `index`; their drop-path rates start right after
            # those consumed by the preceding stages. Window types alternate W, SW, W, ...
            first = sum(config[:index])
            return [
                ConvTransBlock(half_channels, half_channels, self.head_dim, self.window_size, rates[first + j],
                               'SW' if j % 2 else 'W', resolution)
                for j in range(config[index])
            ]

        # Construction and registration order match the checkpoint key layout
        # (m_head, m_down1..3, m_body, m_up3..1, m_tail).
        self.m_head = nn.Sequential(nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False))

        self.m_down1 = nn.Sequential(
            *stage(0, dim // 2, input_resolution),
            nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False),
        )
        self.m_down2 = nn.Sequential(
            *stage(1, dim, input_resolution // 2),
            nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False),
        )
        self.m_down3 = nn.Sequential(
            *stage(2, 2 * dim, input_resolution // 4),
            nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False),
        )

        self.m_body = nn.Sequential(*stage(3, 4 * dim, input_resolution // 8))

        self.m_up3 = nn.Sequential(
            nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
            *stage(4, 2 * dim, input_resolution // 4),
        )
        self.m_up2 = nn.Sequential(
            nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
            *stage(5, dim, input_resolution // 2),
        )
        self.m_up1 = nn.Sequential(
            nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
            *stage(6, dim // 2, input_resolution),
        )

        self.m_tail = nn.Sequential(nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False))
        # self.apply(self._init_weights)

    def forward(self, x0):
        """Run the network on ``x0`` and return a tensor with the same height and width."""
        height, width = x0.size()[-2:]

        # Pad bottom/right with edge replication up to the next multiple of 64.
        pad_bottom = int(np.ceil(height / 64) * 64 - height)
        pad_right = int(np.ceil(width / 64) * 64 - width)
        x0 = nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)

        out = self.m_body(x4)
        # Each decoder stage receives the running output plus the encoder feature of the same depth.
        for module, skip in ((self.m_up3, x4), (self.m_up2, x3), (self.m_up1, x2), (self.m_tail, x1)):
            out = module(out + skip)

        # Drop the padding added above.
        return out[..., :height, :width]

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` modules; other module types are left alone."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
|
@ -5,240 +5,44 @@ import traceback
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
from torch.nn.functional import silu
|
||||||
|
|
||||||
from modules import prompt_parser
|
import modules.textual_inversion.textual_inversion
|
||||||
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import opts, device, cmd_opts
|
||||||
|
|
||||||
from ldm.util import default
|
|
||||||
from einops import rearrange
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
|
||||||
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q = self.to_q(x)
|
|
||||||
context = default(context, x)
|
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
del context, x
|
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
|
||||||
for i in range(0, q.shape[0], 2):
|
|
||||||
end = i + 2
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
|
||||||
s1 *= self.scale
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
|
||||||
del s2
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/Doggettx/stable-diffusion
|
def apply_optimizations():
|
||||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
h = self.heads
|
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
if cmd_opts.opt_split_attention_v1:
|
||||||
context = default(context, x)
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||||
k_in = self.to_k(context) * self.scale
|
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||||
v_in = self.to_v(context)
|
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||||
del context, x
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
|
||||||
del q_in, k_in, v_in
|
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
def undo_optimizations():
|
||||||
|
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
|
||||||
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
gb = 1024 ** 3
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
|
||||||
mem_required = tensor_size * modifier
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
|
||||||
|
|
||||||
if steps > 64:
|
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
|
||||||
del s1
|
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
|
||||||
del s2
|
|
||||||
|
|
||||||
del q, k, v
|
|
||||||
|
|
||||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
|
||||||
del r1
|
|
||||||
|
|
||||||
return self.to_out(r2)
|
|
||||||
|
|
||||||
def nonlinearity_hijack(x):
|
|
||||||
# swish
|
|
||||||
t = torch.sigmoid(x)
|
|
||||||
x *= t
|
|
||||||
del t
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q1 = self.q(h_)
|
|
||||||
k1 = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
b, c, h, w = q1.shape
|
|
||||||
|
|
||||||
q2 = q1.reshape(b, c, h*w)
|
|
||||||
del q1
|
|
||||||
|
|
||||||
q = q2.permute(0, 2, 1) # b,hw,c
|
|
||||||
del q2
|
|
||||||
|
|
||||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
|
||||||
del k1
|
|
||||||
|
|
||||||
h_ = torch.zeros_like(k, device=q.device)
|
|
||||||
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
|
||||||
mem_active = stats['active_bytes.all.current']
|
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
|
||||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
|
||||||
mem_free_torch = mem_reserved - mem_active
|
|
||||||
mem_free_total = mem_free_cuda + mem_free_torch
|
|
||||||
|
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
|
||||||
mem_required = tensor_size * 2.5
|
|
||||||
steps = 1
|
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
|
||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
|
||||||
end = i + slice_size
|
|
||||||
|
|
||||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
|
||||||
w2 = w1 * (int(c)**(-0.5))
|
|
||||||
del w1
|
|
||||||
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
|
|
||||||
del w2
|
|
||||||
|
|
||||||
# attend to values
|
|
||||||
v1 = v.reshape(b, c, h*w)
|
|
||||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
|
||||||
del w3
|
|
||||||
|
|
||||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
|
||||||
del v1, w4
|
|
||||||
|
|
||||||
h2 = h_.reshape(b, c, h, w)
|
|
||||||
del h_
|
|
||||||
|
|
||||||
h3 = self.proj_out(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
h3 += x
|
|
||||||
|
|
||||||
return h3
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
ids_lookup = {}
|
|
||||||
word_embeddings = {}
|
|
||||||
word_embeddings_checksums = {}
|
|
||||||
fixes = None
|
fixes = None
|
||||||
comments = []
|
comments = []
|
||||||
dir_mtime = None
|
|
||||||
layers = None
|
layers = None
|
||||||
circular_enabled = False
|
circular_enabled = False
|
||||||
clip = None
|
clip = None
|
||||||
|
|
||||||
def load_textual_inversion_embeddings(self, dirname, model):
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
mt = os.path.getmtime(dirname)
|
|
||||||
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.dir_mtime = mt
|
|
||||||
self.ids_lookup.clear()
|
|
||||||
self.word_embeddings.clear()
|
|
||||||
|
|
||||||
tokenizer = model.cond_stage_model.tokenizer
|
|
||||||
|
|
||||||
def const_hash(a):
|
|
||||||
r = 0
|
|
||||||
for v in a:
|
|
||||||
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
|
||||||
return r
|
|
||||||
|
|
||||||
def process_file(path, filename):
|
|
||||||
name = os.path.splitext(filename)[0]
|
|
||||||
|
|
||||||
data = torch.load(path, map_location="cpu")
|
|
||||||
|
|
||||||
# textual inversion embeddings
|
|
||||||
if 'string_to_param' in data:
|
|
||||||
param_dict = data['string_to_param']
|
|
||||||
if hasattr(param_dict, '_parameters'):
|
|
||||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
|
||||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
|
||||||
emb = next(iter(param_dict.items()))[1]
|
|
||||||
# diffuser concepts
|
|
||||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
|
||||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
|
||||||
|
|
||||||
emb = next(iter(data.values()))
|
|
||||||
if len(emb.shape) == 1:
|
|
||||||
emb = emb.unsqueeze(0)
|
|
||||||
|
|
||||||
self.word_embeddings[name] = emb.detach().to(device)
|
|
||||||
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
|
|
||||||
|
|
||||||
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
|
||||||
|
|
||||||
first_id = ids[0]
|
|
||||||
if first_id not in self.ids_lookup:
|
|
||||||
self.ids_lookup[first_id] = []
|
|
||||||
self.ids_lookup[first_id].append((ids, name))
|
|
||||||
|
|
||||||
for fn in os.listdir(dirname):
|
|
||||||
try:
|
|
||||||
process_file(os.path.join(dirname, fn), fn)
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
@ -248,12 +52,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
if cmd_opts.opt_split_attention_v1:
|
apply_optimizations()
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
|
||||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
|
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
|
||||||
|
|
||||||
def flatten(el):
|
def flatten(el):
|
||||||
flattened = [flatten(children) for children in el.children()]
|
flattened = [flatten(children) for children in el.children()]
|
||||||
@ -291,7 +90,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
def __init__(self, wrapped, hijack):
|
def __init__(self, wrapped, hijack):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.hijack = hijack
|
self.hijack: StableDiffusionModelHijack = hijack
|
||||||
self.tokenizer = wrapped.tokenizer
|
self.tokenizer = wrapped.tokenizer
|
||||||
self.max_length = wrapped.max_length
|
self.max_length = wrapped.max_length
|
||||||
self.token_mults = {}
|
self.token_mults = {}
|
||||||
@ -312,7 +111,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
if mult != 1.0:
|
if mult != 1.0:
|
||||||
self.token_mults[ident] = mult
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
@ -334,28 +132,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
if possible_matches is None:
|
if embedding is None:
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(weight)
|
|
||||||
else:
|
|
||||||
found = False
|
|
||||||
for ids, word in possible_matches:
|
|
||||||
if tokens[i:i + len(ids)] == ids:
|
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
|
||||||
fixes.append((len(remade_tokens), word))
|
|
||||||
remade_tokens += [0] * emb_len
|
|
||||||
multipliers += [weight] * emb_len
|
|
||||||
i += len(ids) - 1
|
|
||||||
found = True
|
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(weight)
|
multipliers.append(weight)
|
||||||
i += 1
|
i += 1
|
||||||
|
else:
|
||||||
|
emb_len = int(embedding.vec.shape[0])
|
||||||
|
fixes.append((len(remade_tokens), embedding))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [weight] * emb_len
|
||||||
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
|
i += embedding_length_in_tokens
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
@ -426,32 +215,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
while i < len(tokens):
|
while i < len(tokens):
|
||||||
token = tokens[i]
|
token = tokens[i]
|
||||||
|
|
||||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||||
if mult_change is not None:
|
if mult_change is not None:
|
||||||
mult *= mult_change
|
mult *= mult_change
|
||||||
elif possible_matches is None:
|
i += 1
|
||||||
|
elif embedding is None:
|
||||||
remade_tokens.append(token)
|
remade_tokens.append(token)
|
||||||
multipliers.append(mult)
|
multipliers.append(mult)
|
||||||
|
i += 1
|
||||||
else:
|
else:
|
||||||
found = False
|
emb_len = int(embedding.vec.shape[0])
|
||||||
for ids, word in possible_matches:
|
fixes.append((len(remade_tokens), embedding))
|
||||||
if tokens[i:i+len(ids)] == ids:
|
|
||||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
|
||||||
fixes.append((len(remade_tokens), word))
|
|
||||||
remade_tokens += [0] * emb_len
|
remade_tokens += [0] * emb_len
|
||||||
multipliers += [mult] * emb_len
|
multipliers += [mult] * emb_len
|
||||||
i += len(ids) - 1
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
found = True
|
i += embedding_length_in_tokens
|
||||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(mult)
|
|
||||||
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
if len(remade_tokens) > maxlen - 2:
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
@ -459,6 +239,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
token_count = len(remade_tokens)
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
@ -479,7 +260,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
self.hijack.comments = hijack_comments
|
self.hijack.comments = hijack_comments
|
||||||
|
|
||||||
@ -512,15 +292,20 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||||||
|
|
||||||
inputs_embeds = self.wrapped(input_ids)
|
inputs_embeds = self.wrapped(input_ids)
|
||||||
|
|
||||||
if batch_fixes is not None:
|
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
|
||||||
for offset, word in fixes:
|
|
||||||
emb = self.embeddings.word_embeddings[word]
|
|
||||||
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
|
||||||
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
|
|
||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
|
vecs = []
|
||||||
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
for offset, embedding in fixes:
|
||||||
|
emb = embedding.vec
|
||||||
|
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
|
||||||
|
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
|
||||||
|
|
||||||
|
vecs.append(tensor)
|
||||||
|
|
||||||
|
return torch.stack(vecs)
|
||||||
|
|
||||||
|
|
||||||
def add_circular_option_to_conv_2d():
|
def add_circular_option_to_conv_2d():
|
||||||
conv2d_constructor = torch.nn.Conv2d.__init__
|
conv2d_constructor = torch.nn.Conv2d.__init__
|
||||||
|
156
modules/sd_hijack_optimizations.py
Normal file
156
modules/sd_hijack_optimizations.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
|
||||||
|
from ldm.util import default
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """Cross-attention forward pass that computes attention two (batch*head) rows at a time.

    Args:
        self: attention module providing ``heads``, ``scale``, ``to_q``,
            ``to_k``, ``to_v`` and ``to_out``.
        x: query input of shape (batch, tokens, channels).
        context: key/value input; ``x`` is used when it is ``None``.
        mask: accepted for signature compatibility; not used.

    Returns:
        ``self.to_out`` applied to the attended values, shape (batch, tokens, heads * head_dim).
    """
    h = self.heads

    q = self.to_q(x)
    context = x if context is None else context
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x

    def split_heads(t):
        # 'b n (h d) -> (b h) n d'
        b, n, inner = t.shape
        d = inner // h
        return t.reshape(b, n, h, d).permute(0, 2, 1, 3).reshape(b * h, n, d)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Allocate the result in the query dtype: a default (float32) buffer would be
    # handed to to_out with the wrong dtype when the model runs in half/double precision.
    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], 2):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2

    # '(b h) n d -> b n (h d)'
    bh, n, d = r1.shape
    r2 = r1.reshape(bh // h, h, n, d).permute(0, 2, 1, 3).reshape(bh // h, n, h * d)
    del r1

    return self.to_out(r2)
|
||||||
|
|
||||||
|
|
||||||
|
# taken from https://github.com/Doggettx/stable-diffusion
|
||||||
|
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass that slices the query axis to fit into free CUDA memory.

    The size of the full attention matrix is estimated and compared with the
    memory currently free on the tensor's device; if it does not fit, the
    queries are processed in a power-of-two number of slices.

    Args:
        self: attention module providing ``heads``, ``scale``, ``to_q``,
            ``to_k``, ``to_v`` and ``to_out``.
        x: query input.
        context: key/value input; ``x`` is used when it is ``None``.
        mask: accepted for signature compatibility; not used.

    Returns:
        ``self.to_out`` applied to the attended values.

    Raises:
        RuntimeError: if more than 64 slices would be needed.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)
    # The softmax scale is folded into the keys once instead of scaling every score slice.
    k_in = self.to_k(context) * self.scale
    v_in = self.to_v(context)
    del context, x

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    # Query the same device the allocator stats came from; the current device
    # may differ from q.device when several GPUs are in use.
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    # Ceiling division keeps the slicing in effect when the sequence length is not a
    # multiple of `steps` (the last slice is simply shorter) instead of falling back
    # to one full-size pass. Identical to exact division when it divides evenly.
    seq_len = q.shape[1]
    slice_size = max(1, -(-seq_len // steps))
    for i in range(0, seq_len, slice_size):
        end = i + slice_size
        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

        s2 = s1.softmax(dim=-1, dtype=q.dtype)
        del s1

        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
        del s2

    del q, k, v

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
|
||||||
|
|
||||||
|
def cross_attention_attnblock_forward(self, x):
    """Self-attention forward pass for an attention block, sliced to fit into free CUDA memory.

    Args:
        self: module providing ``norm``, ``q``, ``k``, ``v`` and ``proj_out``.
        x: 4-D feature map (b, c, h, w).

    Returns:
        ``proj_out`` of the attended values with ``x`` added as a residual.
    """
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    # Query the same device the allocator stats came from; the current device
    # may differ from q.device when several GPUs are in use.
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # Loop invariants: v.reshape yields the same tensor every iteration, and the
    # scale depends only on the channel count.
    v1 = v.reshape(b, c, h*w)
    scale = int(c)**(-0.5)

    # Ceiling division keeps the slicing in effect when h*w is not a multiple of
    # `steps` (the last slice is simply shorter) instead of falling back to one
    # full-size pass. Identical to exact division when it divides evenly.
    num_queries = q.shape[1]
    slice_size = max(1, -(-num_queries // steps))
    for i in range(0, num_queries, slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * scale
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4

    del v1

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
|
@ -8,14 +8,11 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared, modelloader
|
from modules import shared, modelloader, devices
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||||
model_name = "sd-v1-4.ckpt"
|
|
||||||
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
|
|
||||||
user_dir = None
|
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
@ -30,12 +27,10 @@ except Exception:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def setup_model(dirname):
|
def setup_model():
|
||||||
global user_dir
|
|
||||||
user_dir = dirname
|
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
os.makedirs(model_path)
|
os.makedirs(model_path)
|
||||||
checkpoints_list.clear()
|
|
||||||
list_models()
|
list_models()
|
||||||
|
|
||||||
|
|
||||||
@ -45,13 +40,13 @@ def checkpoint_tiles():
|
|||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
|
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
|
||||||
|
|
||||||
def modeltitle(path, shorthash):
|
def modeltitle(path, shorthash):
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
if user_dir is not None and abspath.startswith(user_dir):
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
name = abspath.replace(user_dir, '')
|
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||||
elif abspath.startswith(model_path):
|
elif abspath.startswith(model_path):
|
||||||
name = abspath.replace(model_path, '')
|
name = abspath.replace(model_path, '')
|
||||||
else:
|
else:
|
||||||
@ -69,6 +64,7 @@ def list_models():
|
|||||||
h = model_hash(cmd_ckpt)
|
h = model_hash(cmd_ckpt)
|
||||||
title, short_model_name = modeltitle(cmd_ckpt, h)
|
title, short_model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
|
||||||
|
shared.opts.data['sd_model_checkpoint'] = title
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
for filename in model_list:
|
for filename in model_list:
|
||||||
@ -105,7 +101,10 @@ def select_checkpoint():
|
|||||||
|
|
||||||
if len(checkpoints_list) == 0:
|
if len(checkpoints_list) == 0:
|
||||||
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||||
|
if shared.cmd_opts.ckpt is not None:
|
||||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
||||||
|
print(f" - directory {model_path}", file=sys.stderr)
|
||||||
|
if shared.cmd_opts.ckpt_dir is not None:
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
||||||
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
@ -133,6 +132,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
|
|||||||
if not shared.cmd_opts.no_half:
|
if not shared.cmd_opts.no_half:
|
||||||
model.half()
|
model.half()
|
||||||
|
|
||||||
|
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||||
|
|
||||||
model.sd_model_hash = sd_model_hash
|
model.sd_model_hash = sd_model_hash
|
||||||
model.sd_model_checkpint = checkpoint_file
|
model.sd_model_checkpint = checkpoint_file
|
||||||
|
|
||||||
|
@ -13,31 +13,57 @@ from modules.shared import opts, cmd_opts, state
|
|||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
|
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
samplers_k_diffusion = [
|
samplers_k_diffusion = [
|
||||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
|
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
|
||||||
('Euler', 'sample_euler', ['k_euler']),
|
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||||
('LMS', 'sample_lms', ['k_lms']),
|
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||||
('Heun', 'sample_heun', ['k_heun']),
|
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||||
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
|
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
||||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
|
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||||
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
||||||
|
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
||||||
]
|
]
|
||||||
|
|
||||||
samplers_data_k_diffusion = [
|
samplers_data_k_diffusion = [
|
||||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
|
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||||
for label, funcname, aliases in samplers_k_diffusion
|
for label, funcname, aliases, options in samplers_k_diffusion
|
||||||
if hasattr(k_diffusion.sampling, funcname)
|
if hasattr(k_diffusion.sampling, funcname)
|
||||||
]
|
]
|
||||||
|
|
||||||
samplers = [
|
all_samplers = [
|
||||||
*samplers_data_k_diffusion,
|
*samplers_data_k_diffusion,
|
||||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||||
]
|
]
|
||||||
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
|
|
||||||
|
samplers = []
|
||||||
|
samplers_for_img2img = []
|
||||||
|
|
||||||
|
|
||||||
|
def create_sampler_with_index(list_of_configs, index, model):
|
||||||
|
config = list_of_configs[index]
|
||||||
|
sampler = config.constructor(model)
|
||||||
|
sampler.config = config
|
||||||
|
|
||||||
|
return sampler
|
||||||
|
|
||||||
|
|
||||||
|
def set_samplers():
|
||||||
|
global samplers, samplers_for_img2img
|
||||||
|
|
||||||
|
hidden = set(opts.hide_samplers)
|
||||||
|
hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive'])
|
||||||
|
|
||||||
|
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||||
|
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||||
|
|
||||||
|
|
||||||
|
set_samplers()
|
||||||
|
|
||||||
sampler_extra_params = {
|
sampler_extra_params = {
|
||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
@ -77,7 +103,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
|||||||
state.sampling_steps = len(sequence)
|
state.sampling_steps = len(sequence)
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
|
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -102,14 +130,18 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.step = 0
|
self.step = 0
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 0.0
|
self.default_eta = 0.0
|
||||||
|
self.config = None
|
||||||
|
|
||||||
def number_of_needed_noises(self, p):
|
def number_of_needed_noises(self, p):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||||
|
|
||||||
|
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
|
||||||
|
cond = tensor
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||||
@ -125,7 +157,7 @@ class VanillaStableDiffusionSampler:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def initialize(self, p):
|
def initialize(self, p):
|
||||||
self.eta = p.eta or opts.eta_ddim
|
self.eta = p.eta if p.eta is not None else opts.eta_ddim
|
||||||
|
|
||||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||||
if hasattr(self.sampler, fieldname):
|
if hasattr(self.sampler, fieldname):
|
||||||
@ -181,19 +213,31 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
|
||||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||||
|
|
||||||
|
batch_size = len(conds_list)
|
||||||
|
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||||
|
|
||||||
|
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||||
|
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||||
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if shared.batch_cond_uncond:
|
||||||
x_in = torch.cat([x] * 2)
|
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
|
||||||
sigma_in = torch.cat([sigma] * 2)
|
|
||||||
cond_in = torch.cat([uncond, cond])
|
|
||||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
|
||||||
denoised = uncond + (cond - uncond) * cond_scale
|
|
||||||
else:
|
else:
|
||||||
uncond = self.inner_model(x, sigma, cond=uncond)
|
x_out = torch.zeros_like(x_in)
|
||||||
cond = self.inner_model(x, sigma, cond=cond)
|
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||||
denoised = uncond + (cond - uncond) * cond_scale
|
a = batch_offset
|
||||||
|
b = a + batch_size
|
||||||
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
|
||||||
|
|
||||||
|
denoised_uncond = x_out[-batch_size:]
|
||||||
|
denoised = torch.clone(denoised_uncond)
|
||||||
|
|
||||||
|
for i, conds in enumerate(conds_list):
|
||||||
|
for cond_index, weight in conds:
|
||||||
|
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
@ -207,7 +251,9 @@ def extended_trange(sampler, count, *args, **kwargs):
|
|||||||
state.sampling_steps = count
|
state.sampling_steps = count
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
|
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
|
||||||
|
|
||||||
|
for x in seq:
|
||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -246,6 +292,7 @@ class KDiffusionSampler:
|
|||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.default_eta = 1.0
|
self.default_eta = 1.0
|
||||||
|
self.config = None
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
store_latent(d["denoised"])
|
store_latent(d["denoised"])
|
||||||
@ -290,6 +337,9 @@ class KDiffusionSampler:
|
|||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
|
if p.sampler_noise_scheduler_override:
|
||||||
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
|
else:
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
noise = noise * sigmas[steps - t_enc - 1]
|
noise = noise * sigmas[steps - t_enc - 1]
|
||||||
@ -306,7 +356,13 @@ class KDiffusionSampler:
|
|||||||
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
|
if p.sampler_noise_scheduler_override:
|
||||||
|
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||||
|
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||||
|
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||||
|
else:
|
||||||
sigmas = self.model_wrap.get_sigmas(steps)
|
sigmas = self.model_wrap.get_sigmas(steps)
|
||||||
|
|
||||||
x = x * sigmas[0]
|
x = x * sigmas[0]
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
|
@ -12,15 +12,15 @@ import modules.interrogate
|
|||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.styles
|
import modules.styles
|
||||||
from modules.devices import get_optimal_device
|
import modules.devices as devices
|
||||||
from modules.paths import script_path, sd_path
|
from modules import sd_samplers
|
||||||
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
model_path = os.path.join(script_path, 'models')
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||||
@ -35,16 +35,18 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
|
|||||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
|
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
|
||||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
|
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
|
||||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||||
|
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
|
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
|
||||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||||
@ -53,13 +55,21 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire
|
|||||||
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
|
||||||
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
|
||||||
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
|
||||||
|
parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
|
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||||
|
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||||
|
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
device = get_optimal_device()
|
|
||||||
|
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||||
|
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
|
||||||
|
|
||||||
|
device = devices.device
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
||||||
@ -78,6 +88,7 @@ class State:
|
|||||||
current_latent = None
|
current_latent = None
|
||||||
current_image = None
|
current_image = None
|
||||||
current_image_sampling_step = 0
|
current_image_sampling_step = 0
|
||||||
|
textinfo = None
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
@ -88,7 +99,7 @@ class State:
|
|||||||
self.current_image_sampling_step = 0
|
self.current_image_sampling_step = 0
|
||||||
|
|
||||||
def get_job_timestamp(self):
|
def get_job_timestamp(self):
|
||||||
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
|
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
|
||||||
|
|
||||||
|
|
||||||
state = State()
|
state = State()
|
||||||
@ -165,9 +176,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
"save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
|
||||||
"grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"),
|
"grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
|
||||||
|
"use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
|
||||||
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
"directories_filename_pattern": OptionInfo("", "Directory name pattern"),
|
||||||
"directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
"directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
@ -177,7 +189,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
|||||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
@ -189,7 +201,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
|||||||
options_templates.update(options_section(('system', "System"), {
|
options_templates.update(options_section(('system', "System"), {
|
||||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
|
||||||
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
|
||||||
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
|
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
@ -198,7 +210,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
|
||||||
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
|
||||||
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
|
||||||
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
|
||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
||||||
@ -218,13 +230,16 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
"show_progressbar": OptionInfo(True, "Show progressbar"),
|
||||||
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
|
||||||
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
"return_grid": OptionInfo(True, "Show grid in results for web"),
|
||||||
|
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
|
||||||
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
|
||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
@ -233,6 +248,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
data = None
|
data = None
|
||||||
data_labels = options_templates
|
data_labels = options_templates
|
||||||
@ -318,14 +334,14 @@ class TotalTQDM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
if not opts.multiple_tqdm:
|
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||||
return
|
return
|
||||||
if self._tqdm is None:
|
if self._tqdm is None:
|
||||||
self.reset()
|
self.reset()
|
||||||
self._tqdm.update()
|
self._tqdm.update()
|
||||||
|
|
||||||
def updateTotal(self, new_total):
|
def updateTotal(self, new_total):
|
||||||
if not opts.multiple_tqdm:
|
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
|
||||||
return
|
return
|
||||||
if self._tqdm is None:
|
if self._tqdm is None:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
@ -5,6 +5,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
@ -122,6 +123,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||||
|
|
||||||
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||||
for h_idx in h_idx_list:
|
for h_idx in h_idx_list:
|
||||||
for w_idx in w_idx_list:
|
for w_idx in w_idx_list:
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
@ -134,6 +136,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||||||
W[
|
W[
|
||||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
].add_(out_patch_mask)
|
].add_(out_patch_mask)
|
||||||
|
pbar.update(1)
|
||||||
output = E.div_(W)
|
output = E.div_(W)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
81
modules/textual_inversion/dataset.py
Normal file
81
modules/textual_inversion/dataset.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import random
|
||||||
|
import tqdm
|
||||||
|
from modules import devices
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Matches word-like tags inside an image file name: one ASCII letter followed by
# at least one more word character, underscore, digit or parenthesis.
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
||||||
|
|
||||||
|
|
||||||
|
class PersonalizedBase(Dataset):
    """Dataset of pre-encoded image latents paired with templated prompts.

    Every file in ``data_root`` is opened as an image, resized to
    ``width`` x ``height``, encoded with ``model``'s first stage and kept on the
    CPU together with the tags extracted from its file name.  Items are
    ``(latent, prompt)`` pairs, where the prompt is a random line of
    ``template_file`` with ``[name]`` and ``[filewords]`` substituted.
    """

    def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
        """Load the prompt templates and encode every image found in ``data_root``.

        ``size`` and the ``flip`` transform are stored but not used elsewhere in
        this class.  ``repeats`` multiplies the reported dataset length.

        Raises AssertionError when ``data_root`` is empty/None or when the
        directory yields no images.
        """
        # Validate the cheap argument first so the error is reported before any file is touched.
        assert data_root, 'dataset directory not specified'

        self.placeholder_token = placeholder_token

        self.size = size
        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        self.dataset = []

        with open(template_file, "r") as file:
            lines = [x.strip() for x in file.readlines()]

        self.lines = lines

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            image = Image.open(path)
            image = image.convert('RGB')
            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)

            # Tags for the "[filewords]" placeholder come from the file name without its extension.
            filename = os.path.basename(path)
            filename_tokens = os.path.splitext(filename)[0]
            filename_tokens = re_tag.findall(filename_tokens)

            # Map 8-bit pixel values from [0, 255] to [-1, 1].
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            # Move the channel axis first (HWC -> CHW) before encoding.
            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
            init_latent = init_latent.to(devices.cpu)

            self.dataset.append((init_latent, filename_tokens))

        # Without this check an empty directory only fails later, inside __getitem__,
        # with an unhelpful "modulo by zero" error.
        assert self.dataset, f'no images found in {data_root}'

        self.length = len(self.dataset) * repeats

        self.initial_indexes = np.arange(self.length) % len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        """Re-draw a random permutation of the repeated dataset indexes."""
        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]

    def __len__(self):
        """Return the number of stored images multiplied by ``repeats``."""
        return self.length

    def __getitem__(self, i):
        """Return ``(latent, prompt)`` for position ``i``.

        The index wraps around, so no IndexError is raised by this method
        itself; the permutation is re-drawn whenever ``i`` is a multiple of
        the number of stored images.
        """
        if i % len(self.dataset) == 0:
            self.shuffle()

        index = self.indexes[i % len(self.indexes)]
        x, filename_tokens = self.dataset[index]

        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", ' '.join(filename_tokens))

        return x, text
|
104
modules/textual_inversion/preprocess.py
Normal file
104
modules/textual_inversion/preprocess.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from modules import shared, images
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
    """Convert every image in ``process_src`` into 512x512 PNG files in ``process_dst``.

    process_flip: additionally save a horizontally mirrored copy of each picture.
    process_split: save two square crops (both ends) of images whose aspect
        ratio exceeds 1.35 instead of a single resized picture.
    process_caption: name each output after a caption generated by the
        interrogator instead of the source file name.

    Raises AssertionError when source and destination are the same directory.
    """
    size = 512
    src = os.path.abspath(process_src)
    dst = os.path.abspath(process_dst)

    assert src != dst, 'same directory specified as source and destination'

    os.makedirs(dst, exist_ok=True)

    files = os.listdir(src)

    shared.state.textinfo = "Preprocessing..."
    shared.state.job_count = len(files)

    if process_caption:
        shared.interrogator.load()

    def save_pic_with_caption(image, index):
        # `subindex` and `filename` are rebound by the loop below for every source file.
        if process_caption:
            caption = "-" + shared.interrogator.generate_caption(image)
            caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
        else:
            caption = filename
            caption = os.path.splitext(caption)[0]
            caption = os.path.basename(caption)

        image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
        subindex[0] += 1

    def save_pic(image, index):
        save_pic_with_caption(image, index)

        if process_flip:
            save_pic_with_caption(ImageOps.mirror(image), index)

    try:
        for index, imagefile in enumerate(tqdm.tqdm(files)):
            # Honour an interrupt request before opening and decoding the next file.
            if shared.state.interrupted:
                break

            subindex = [0]
            filename = os.path.join(src, imagefile)
            img = Image.open(filename).convert("RGB")

            ratio = img.height / img.width
            is_tall = ratio > 1.35
            is_wide = ratio < 1 / 1.35

            if process_split and is_tall:
                img = img.resize((size, size * img.height // img.width))

                top = img.crop((0, 0, size, size))
                save_pic(top, index)

                bot = img.crop((0, img.height - size, size, img.height))
                save_pic(bot, index)
            elif process_split and is_wide:
                img = img.resize((size * img.width // img.height, size))

                left = img.crop((0, 0, size, size))
                save_pic(left, index)

                right = img.crop((img.width - size, 0, img.width, size))
                save_pic(right, index)
            else:
                img = images.resize_image(1, img, size, size)
                save_pic(img, index)

            shared.state.nextjob()
    finally:
        # Always undo interrogator.load(), even when a file fails to process.
        if process_caption:
            shared.interrogator.send_blip_to_ram()
|
||||||
|
|
||||||
|
def sanitize_caption(base_path, original_caption, suffix):
    """Return ``original_caption`` made safe to embed in a file path.

    Characters that are invalid in paths on the current operating system are
    removed.  If ``base_path + caption + suffix`` would exceed the maximum
    path length, the caption is cut at a word boundary so that the result
    fits, and a notice is printed to stderr.

    base_path: path prefix the caption will be appended to.
    original_caption: caption text to sanitize.
    suffix: text (such as the file extension) appended after the caption.
    """
    if platform.system().lower() == "windows":
        invalid_path_characters = "\\/:*?\"<>|"
        max_path_length = 259
    else:
        invalid_path_characters = "/"  # linux/macos
        max_path_length = 1023

    caption = original_caption
    for invalid_character in invalid_path_characters:
        caption = caption.replace(invalid_character, "")

    # Number of characters left for the caption once prefix and suffix are accounted for.
    available = max_path_length - len(base_path) - len(suffix)
    if len(caption) <= available:
        return caption

    # Keep whole words, joined by single spaces, for as long as they fit.
    # Tracking the kept words explicitly means a caption with no words yields ""
    # and a caption that fits after whitespace collapsing keeps its last word.
    kept_tokens = []
    used = 0
    for token in caption.split():
        needed = len(token) + (1 if kept_tokens else 0)
        if used + needed > available:
            break
        kept_tokens.append(token)
        used += needed

    last_caption = " ".join(kept_tokens)
    print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
    return last_caption
|
271
modules/textual_inversion/textual_inversion.py
Normal file
271
modules/textual_inversion/textual_inversion.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import html
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
from modules import shared, devices, sd_hijack, processing, sd_models
|
||||||
|
import modules.textual_inversion.dataset
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding:
    """A named embedding tensor together with its training metadata."""

    def __init__(self, vec, name, step=None):
        """Store the tensor ``vec`` under ``name``; ``step`` is the training step reached so far."""
        self.vec = vec
        self.name = name
        self.step = step
        # Filled lazily by checksum(); reset it to None after changing `vec`.
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

    def save(self, filename):
        """Write the embedding and its metadata to ``filename`` with ``torch.save``."""
        torch.save(
            {
                "string_to_token": {"*": 265},
                "string_to_param": {"*": self.vec},
                "name": self.name,
                "step": self.step,
                "sd_checkpoint": self.sd_checkpoint,
                "sd_checkpoint_name": self.sd_checkpoint_name,
            },
            filename,
        )

    def checksum(self):
        """Return a four-digit hex checksum of the tensor, computing it only once."""
        if self.cached_checksum is None:
            # Rolling 32-bit hash over every element scaled by 100 and truncated to int.
            acc = 0
            for value in self.vec.reshape(-1) * 100:
                acc = ((acc * 281) ^ (int(value) * 997)) & 0xFFFFFFFF

            self.cached_checksum = f'{acc & 0xffff:04x}'

        return self.cached_checksum
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingDatabase:
    """Registry of the embeddings loaded from a directory, indexed for prompt matching."""

    def __init__(self, embeddings_dir):
        """Create an empty database that loads its files from ``embeddings_dir``."""
        # first token id of an embedding name -> list of (token ids, embedding), longest first
        self.ids_lookup = {}
        # embedding name -> Embedding
        self.word_embeddings = {}
        # modification time of the directory at the last (re)load; None before the first load
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir

    def register_embedding(self, embedding, model):
        """Add ``embedding`` to the database, tokenizing its name with ``model``'s tokenizer.

        Returns the embedding that was passed in.
        """
        self.word_embeddings[embedding.name] = embedding

        ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]

        first_id = ids[0]
        candidates = self.ids_lookup.get(first_id, [])

        # Longest token sequences first, so find_embedding_at_position prefers the longest match.
        self.ids_lookup[first_id] = sorted(candidates + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def load_textual_inversion_embeddings(self):
        """Reload every embedding file when the directory changed since the last load.

        Files that fail to load are reported on stderr and skipped.
        """
        mt = os.path.getmtime(self.embeddings_dir)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = torch.load(path, map_location="cpu")

            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                if hasattr(param_dict, '_parameters'):
                    param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                emb = next(iter(param_dict.items()))[1]
            # diffuser concepts
            elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)
            else:
                raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
            # Embedding.save() writes this value under "sd_checkpoint"; "hash" is still
            # accepted as a fallback for files that used the other key.
            embedding.sd_checkpoint = data.get('sd_checkpoint', data.get('hash', None))
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            self.register_embedding(embedding, shared.sd_model)

        for fn in os.listdir(self.embeddings_dir):
            try:
                fullfn = os.path.join(self.embeddings_dir, fn)

                # Empty files are skipped without reporting an error.
                if os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                # Best effort: one broken file must not prevent the others from loading.
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")

    def find_embedding_at_position(self, tokens, offset):
        """Find the embedding whose tokenized name starts at ``tokens[offset]``.

        Returns ``(embedding, number_of_tokens)`` for the first registered
        candidate that matches (candidates are ordered longest first), or
        ``(None, None)`` when nothing matches.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, num_vectors_per_token, init_text='*'):
    """Create a new embedding file initialized from the token embeddings of ``init_text``.

    The new embedding has ``num_vectors_per_token`` vectors; each one is copied
    from the embedded ``init_text`` tokens, spread evenly over them.  The file
    is written to the configured embeddings directory and its path returned.

    Raises AssertionError when a file for ``name`` already exists.
    """
    text_encoder = shared.sd_model.cond_stage_model
    embeddings = text_encoder.wrapped.transformer.text_model.embeddings

    tokenized = text_encoder.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)
    token_ids = tokenized["input_ids"]
    embedded = embeddings.token_embedding.wrapped(token_ids.to(devices.device)).squeeze(0)

    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
    source_count = int(embedded.shape[0])
    for row in range(num_vectors_per_token):
        # Map each output row onto a source token, stretching the tokens over all rows.
        vec[row] = embedded[row * source_count // num_vectors_per_token]

    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    assert not os.path.exists(fn), f"file {fn} already exists"

    new_embedding = Embedding(vec, name)
    new_embedding.step = 0
    new_embedding.save(fn)

    return fn
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
    """Train the embedding called ``embedding_name`` on the images in ``data_root``.

    Training resumes from the step stored in the embedding and stops once the
    step counter exceeds ``steps`` or the shared state is interrupted.
    Intermediate embeddings and preview images are written below
    ``log_directory/<date>/<embedding_name>`` every ``save_embedding_every`` /
    ``create_image_every`` steps (a value of 0 or less disables each).

    Returns ``(embedding, filename)`` where ``filename`` is the embedding file
    in the configured embeddings directory.
    """
    assert embedding_name, 'embedding not selected'

    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)

    # Output folders are only created for the periodic outputs that are enabled.
    embedding_dir = os.path.join(log_directory, "embeddings") if save_embedding_every > 0 else None
    if embedding_dir is not None:
        os.makedirs(embedding_dir, exist_ok=True)

    images_dir = os.path.join(log_directory, "images") if create_image_every > 0 else None
    if images_dir is not None:
        os.makedirs(images_dir, exist_ok=True)

    cond_model = shared.sd_model.cond_stage_model

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)

    embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
    embedding.vec.requires_grad = True

    optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)

    # Ring buffer of the most recent losses; its mean is what gets reported.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = embedding.step or 0
    if initial_step > steps:
        return embedding, filename

    pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
    for i, (x, text) in pbar:
        embedding.step = i + initial_step

        if embedding.step > steps or shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            c = cond_model([text])

            x = x.to(devices.device)
            loss = shared.sd_model(x.unsqueeze(0), c)[0]
            del x

            losses[embedding.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f"loss: {losses.mean():.7f}")

        past_first_step = embedding.step > 0

        if past_first_step and embedding_dir is not None and embedding.step % save_embedding_every == 0:
            last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
            embedding.save(last_saved_file)

        if past_first_step and images_dir is not None and embedding.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')

            # Render a preview with the prompt that was just trained on.
            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                prompt=text,
                steps=20,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            processed = processing.process_images(p)
            image = processed.images[0]

            shared.state.current_image = image
            image.save(last_saved_image)

            last_saved_image += f", prompt: {text}"

        shared.state.job_no = embedding.step

        shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    checkpoint = sd_models.select_checkpoint()

    # Record the checkpoint and drop the cached checksum, since the vector has changed.
    embedding.sd_checkpoint = checkpoint.hash
    embedding.sd_checkpoint_name = checkpoint.model_name
    embedding.cached_checksum = None
    embedding.save(filename)

    return embedding, filename
|
||||||
|
|
40
modules/textual_inversion/ui.py
Normal file
40
modules/textual_inversion/ui.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
import html
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
import modules.textual_inversion.preprocess
|
||||||
|
from modules import sd_hijack, shared
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding(name, initialization_text, nvpt):
    """UI callback: create an embedding file and refresh the embedding dropdown.

    Returns the dropdown update listing all known embedding names, a status
    message with the created file name, and an empty string.
    """
    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)

    # Reload so that the new file is part of the choices returned below.
    database = sd_hijack.model_hijack.embedding_db
    database.load_textual_inversion_embeddings()

    known_names = sorted(database.word_embeddings.keys())
    return gr.Dropdown.update(choices=known_names), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess(*args):
    """UI callback: run image preprocessing with ``args`` and return a status message pair."""
    modules.textual_inversion.preprocess.preprocess(*args)

    status = "Preprocessing finished."
    return status, ""
|
||||||
|
|
||||||
|
|
||||||
|
def train_embedding(*args):
    """UI callback: train an embedding with optimizations undone for the duration.

    Returns a summary message and an empty string.  Exceptions raised by the
    training propagate to the caller; the optimizations are re-applied either way.
    """
    sd_hijack_module = sd_hijack
    try:
        sd_hijack_module.undo_optimizations()

        embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)

        outcome = 'interrupted' if shared.state.interrupted else 'finished'
        res = f"""
Training {outcome} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
        return res, ""
    finally:
        sd_hijack_module.apply_optimizations()
|
||||||
|
|
@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||||||
denoising_strength=denoising_strength if enable_hr else None,
|
denoising_strength=denoising_strength if enable_hr else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cmd_opts.enable_console_prompts:
|
||||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
processed = modules.scripts.scripts_txt2img.run(p, *args)
|
||||||
|
|
||||||
if processed is None:
|
if processed is None:
|
||||||
@ -46,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
|
|||||||
if opts.samples_log_stdout:
|
if opts.samples_log_stdout:
|
||||||
print(generation_info_js)
|
print(generation_info_js)
|
||||||
|
|
||||||
|
if opts.do_not_show_images:
|
||||||
|
processed.images = []
|
||||||
|
|
||||||
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
return processed.images, generation_info_js, plaintext_to_html(processed.info)
|
||||||
|
|
||||||
|
301
modules/ui.py
301
modules/ui.py
@ -11,6 +11,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import platform
|
import platform
|
||||||
import subprocess as sp
|
import subprocess as sp
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -21,6 +22,7 @@ import gradio as gr
|
|||||||
import gradio.utils
|
import gradio.utils
|
||||||
import gradio.routes
|
import gradio.routes
|
||||||
|
|
||||||
|
from modules import sd_hijack
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -32,6 +34,9 @@ import modules.gfpgan_model
|
|||||||
import modules.codeformer_model
|
import modules.codeformer_model
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.generation_parameters_copypaste
|
import modules.generation_parameters_copypaste
|
||||||
|
from modules import prompt_parser
|
||||||
|
from modules.images import save_image
|
||||||
|
import modules.textual_inversion.ui
|
||||||
|
|
||||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
@ -64,7 +69,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
|
|||||||
reuse_symbol = '\u267b\ufe0f' # ♻️
|
reuse_symbol = '\u267b\ufe0f' # ♻️
|
||||||
art_symbol = '\U0001f3a8' # 🎨
|
art_symbol = '\U0001f3a8' # 🎨
|
||||||
paste_symbol = '\u2199\ufe0f' # ↙
|
paste_symbol = '\u2199\ufe0f' # ↙
|
||||||
folder_symbol = '\uD83D\uDCC2'
|
folder_symbol = '\U0001f4c2' # 📂
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
|
||||||
@ -95,17 +100,27 @@ def send_gradio_gallery_to_image(x):
|
|||||||
|
|
||||||
def save_files(js_data, images, index):
|
def save_files(js_data, images, index):
|
||||||
import csv
|
import csv
|
||||||
|
|
||||||
os.makedirs(opts.outdir_save, exist_ok=True)
|
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
|
|
||||||
|
# quick dictionary-to-object conversion. It's necessary because apply_filename_pattern requires it
|
||||||
|
class MyObject:
|
||||||
|
def __init__(self, d=None):
|
||||||
|
if d is not None:
|
||||||
|
for key, value in d.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
data = json.loads(js_data)
|
data = json.loads(js_data)
|
||||||
|
|
||||||
|
p = MyObject(data)
|
||||||
|
path = opts.outdir_save
|
||||||
|
save_to_dirs = opts.use_save_to_dirs_for_ui
|
||||||
|
extension: str = opts.samples_format
|
||||||
|
start_index = 0
|
||||||
|
|
||||||
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
|
||||||
|
|
||||||
images = [images[index]]
|
images = [images[index]]
|
||||||
infotexts = [data["infotexts"][index]]
|
start_index = index
|
||||||
else:
|
|
||||||
infotexts = data["infotexts"]
|
|
||||||
|
|
||||||
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
|
||||||
at_start = file.tell() == 0
|
at_start = file.tell() == 0
|
||||||
@ -113,28 +128,18 @@ def save_files(js_data, images, index):
|
|||||||
if at_start:
|
if at_start:
|
||||||
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
|
||||||
|
|
||||||
filename_base = str(int(time.time() * 1000))
|
for image_index, filedata in enumerate(images, start_index):
|
||||||
extension = opts.samples_format.lower()
|
|
||||||
for i, filedata in enumerate(images):
|
|
||||||
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
|
|
||||||
filepath = os.path.join(opts.outdir_save, filename)
|
|
||||||
|
|
||||||
if filedata.startswith("data:image/png;base64,"):
|
if filedata.startswith("data:image/png;base64,"):
|
||||||
filedata = filedata[len("data:image/png;base64,"):]
|
filedata = filedata[len("data:image/png;base64,"):]
|
||||||
|
|
||||||
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
|
||||||
if opts.enable_pnginfo and extension == 'png':
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
|
||||||
pnginfo.add_text('parameters', infotexts[i])
|
|
||||||
image.save(filepath, pnginfo=pnginfo)
|
|
||||||
else:
|
|
||||||
image.save(filepath, quality=opts.jpeg_quality)
|
|
||||||
|
|
||||||
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
|
is_grid = image_index < p.index_of_first_image
|
||||||
piexif.insert(piexif.dump({"Exif": {
|
i = 0 if is_grid else (image_index - p.index_of_first_image)
|
||||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
|
|
||||||
}}), filepath)
|
|
||||||
|
|
||||||
|
fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
|
||||||
|
|
||||||
|
filename = os.path.relpath(fullfn, path)
|
||||||
filenames.append(filename)
|
filenames.append(filename)
|
||||||
|
|
||||||
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
|
||||||
@ -142,8 +147,8 @@ def save_files(js_data, images, index):
|
|||||||
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_call(func):
|
def wrap_gradio_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
|
||||||
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
shared.mem_mon.monitor()
|
shared.mem_mon.monitor()
|
||||||
@ -159,9 +164,17 @@ def wrap_gradio_call(func):
|
|||||||
shared.state.job = ""
|
shared.state.job = ""
|
||||||
shared.state.job_count = 0
|
shared.state.job_count = 0
|
||||||
|
|
||||||
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
if extra_outputs_array is None:
|
||||||
|
extra_outputs_array = [None, '']
|
||||||
|
|
||||||
|
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
|
||||||
|
|
||||||
elapsed = time.perf_counter() - t
|
elapsed = time.perf_counter() - t
|
||||||
|
elapsed_m = int(elapsed // 60)
|
||||||
|
elapsed_s = elapsed % 60
|
||||||
|
elapsed_text = f"{elapsed_s:.2f}s"
|
||||||
|
if (elapsed_m > 0):
|
||||||
|
elapsed_text = f"{elapsed_m}m "+elapsed_text
|
||||||
|
|
||||||
if run_memmon:
|
if run_memmon:
|
||||||
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
|
||||||
@ -176,9 +189,10 @@ def wrap_gradio_call(func):
|
|||||||
vram_html = ''
|
vram_html = ''
|
||||||
|
|
||||||
# last item is always HTML
|
# last item is always HTML
|
||||||
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
|
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
|
||||||
|
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.job_count = 0
|
||||||
|
|
||||||
return tuple(res)
|
return tuple(res)
|
||||||
|
|
||||||
@ -187,7 +201,7 @@ def wrap_gradio_call(func):
|
|||||||
|
|
||||||
def check_progress_call(id_part):
|
def check_progress_call(id_part):
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return "", gr_show(False), gr_show(False)
|
return "", gr_show(False), gr_show(False), gr_show(False)
|
||||||
|
|
||||||
progress = 0
|
progress = 0
|
||||||
|
|
||||||
@ -219,13 +233,19 @@ def check_progress_call(id_part):
|
|||||||
else:
|
else:
|
||||||
preview_visibility = gr_show(True)
|
preview_visibility = gr_show(True)
|
||||||
|
|
||||||
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
|
if shared.state.textinfo is not None:
|
||||||
|
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
|
||||||
|
else:
|
||||||
|
textinfo_result = gr_show(False)
|
||||||
|
|
||||||
|
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
|
||||||
|
|
||||||
|
|
||||||
def check_progress_call_initial(id_part):
|
def check_progress_call_initial(id_part):
|
||||||
shared.state.job_count = -1
|
shared.state.job_count = -1
|
||||||
shared.state.current_latent = None
|
shared.state.current_latent = None
|
||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
return check_progress_call(id_part)
|
return check_progress_call(id_part)
|
||||||
|
|
||||||
@ -345,11 +365,24 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
|
|||||||
outputs=[seed, dummy_component]
|
outputs=[seed, dummy_component]
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_token_counter(text):
|
|
||||||
tokens, token_count, max_length = model_hijack.tokenize(text)
|
def update_token_counter(text, steps):
|
||||||
|
try:
|
||||||
|
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||||
|
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# a parsing error can happen here during typing, and we don't want to bother the user with
|
||||||
|
# messages related to it in console
|
||||||
|
prompt_schedules = [[[steps, text]]]
|
||||||
|
|
||||||
|
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
|
||||||
|
prompts = [prompt_text for step, prompt_text in flat_prompts]
|
||||||
|
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
|
||||||
style_class = ' class="red"' if (token_count > max_length) else ""
|
style_class = ' class="red"' if (token_count > max_length) else ""
|
||||||
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
return f"<span {style_class}>{token_count}/{max_length}</span>"
|
||||||
|
|
||||||
|
|
||||||
def create_toprow(is_img2img):
|
def create_toprow(is_img2img):
|
||||||
id_part = "img2img" if is_img2img else "txt2img"
|
id_part = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
@ -364,8 +397,7 @@ def create_toprow(is_img2img):
|
|||||||
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
|
||||||
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
paste = gr.Button(value=paste_symbol, elem_id="paste")
|
||||||
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
|
||||||
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
|
|
||||||
|
|
||||||
with gr.Column(scale=10, elem_id="style_pos_col"):
|
with gr.Column(scale=10, elem_id="style_pos_col"):
|
||||||
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
|
||||||
@ -380,7 +412,7 @@ def create_toprow(is_img2img):
|
|||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
|
||||||
submit = gr.Button('Generate', elem_id="generate", variant='primary')
|
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||||
|
|
||||||
interrupt.click(
|
interrupt.click(
|
||||||
fn=lambda: shared.state.interrupt(),
|
fn=lambda: shared.state.interrupt(),
|
||||||
@ -396,16 +428,19 @@ def create_toprow(is_img2img):
|
|||||||
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
|
||||||
save_style = gr.Button('Create style', elem_id="style_create")
|
save_style = gr.Button('Create style', elem_id="style_create")
|
||||||
|
|
||||||
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
|
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(progressbar, preview, id_part):
|
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
|
||||||
|
if textinfo is None:
|
||||||
|
textinfo = gr.HTML(visible=False)
|
||||||
|
|
||||||
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
|
||||||
check_progress.click(
|
check_progress.click(
|
||||||
fn=lambda: check_progress_call(id_part),
|
fn=lambda: check_progress_call(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
|
||||||
@ -413,13 +448,16 @@ def setup_progressbar(progressbar, preview, id_part):
|
|||||||
fn=lambda: check_progress_call_initial(id_part),
|
fn=lambda: check_progress_call_initial(id_part),
|
||||||
show_progress=False,
|
show_progress=False,
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[progressbar, preview, preview],
|
outputs=[progressbar, preview, preview, textinfo],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
def create_ui(wrap_gradio_gpu_call):
|
||||||
|
import modules.img2img
|
||||||
|
import modules.txt2img
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
|
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
|
|
||||||
with gr.Row(elem_id='txt2img_progress_row'):
|
with gr.Row(elem_id='txt2img_progress_row'):
|
||||||
@ -483,7 +521,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=txt2img,
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
||||||
_js="submit",
|
_js="submit",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
@ -539,6 +577,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_txt2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
txt2img_prompt,
|
txt2img_prompt,
|
||||||
],
|
],
|
||||||
@ -567,9 +606,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True)
|
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
with gr.Row(elem_id='img2img_progress_row'):
|
with gr.Row(elem_id='img2img_progress_row'):
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
@ -585,7 +625,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
|
||||||
with gr.TabItem('img2img', id='img2img'):
|
with gr.TabItem('img2img', id='img2img'):
|
||||||
init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil")
|
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool)
|
||||||
|
|
||||||
with gr.TabItem('Inpaint', id='inpaint'):
|
with gr.TabItem('Inpaint', id='inpaint'):
|
||||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA")
|
||||||
@ -599,7 +639,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
|
||||||
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
|
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
|
||||||
|
|
||||||
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
|
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
|
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
|
||||||
@ -607,7 +647,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
with gr.TabItem('Batch img2img', id='batch'):
|
with gr.TabItem('Batch img2img', id='batch'):
|
||||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||||
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.{hidden}</p>")
|
gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
|
||||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
|
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs)
|
||||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
|
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
|
||||||
|
|
||||||
@ -675,7 +715,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
)
|
)
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=img2img,
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
@ -743,6 +783,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
roll.click(
|
roll.click(
|
||||||
fn=roll_artist,
|
fn=roll_artist,
|
||||||
|
_js="update_img2img_tokens",
|
||||||
inputs=[
|
inputs=[
|
||||||
img2img_prompt,
|
img2img_prompt,
|
||||||
],
|
],
|
||||||
@ -753,6 +794,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
|
||||||
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
|
||||||
|
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
|
||||||
|
|
||||||
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
|
||||||
button.click(
|
button.click(
|
||||||
@ -764,9 +806,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
|
||||||
)
|
)
|
||||||
|
|
||||||
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
|
for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
|
||||||
button.click(
|
button.click(
|
||||||
fn=apply_styles,
|
fn=apply_styles,
|
||||||
|
_js=js_func,
|
||||||
inputs=[prompt, negative_prompt, style1, style2],
|
inputs=[prompt, negative_prompt, style1, style2],
|
||||||
outputs=[prompt, negative_prompt, style1, style2],
|
outputs=[prompt, negative_prompt, style1, style2],
|
||||||
)
|
)
|
||||||
@ -789,6 +832,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
]
|
]
|
||||||
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
|
||||||
|
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
with gr.Blocks(analytics_enabled=False) as extras_interface:
|
||||||
with gr.Row().style(equal_height=False):
|
with gr.Row().style(equal_height=False):
|
||||||
@ -828,7 +872,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
|
||||||
|
|
||||||
submit.click(
|
submit.click(
|
||||||
fn=run_extras,
|
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
|
||||||
_js="get_extras_tab_index",
|
_js="get_extras_tab_index",
|
||||||
inputs=[
|
inputs=[
|
||||||
dummy_component,
|
dummy_component,
|
||||||
@ -878,7 +922,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
pnginfo_send_to_img2img = gr.Button('Send to img2img')
|
||||||
|
|
||||||
image.change(
|
image.change(
|
||||||
fn=wrap_gradio_call(run_pnginfo),
|
fn=wrap_gradio_call(modules.extras.run_pnginfo),
|
||||||
inputs=[image],
|
inputs=[image],
|
||||||
outputs=[html, generation_info, html2],
|
outputs=[html, generation_info, html2],
|
||||||
)
|
)
|
||||||
@ -900,6 +944,130 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
with gr.Column(variant='panel'):
|
with gr.Column(variant='panel'):
|
||||||
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
|
||||||
|
with gr.Blocks() as textual_inversion_interface:
|
||||||
|
with gr.Row().style(equal_height=False):
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
|
||||||
|
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
|
||||||
|
|
||||||
|
new_embedding_name = gr.Textbox(label="Name")
|
||||||
|
initialization_text = gr.Textbox(label="Initialization text", value="*")
|
||||||
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
create_embedding = gr.Button(value="Create", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
|
||||||
|
|
||||||
|
process_src = gr.Textbox(label='Source directory')
|
||||||
|
process_dst = gr.Textbox(label='Destination directory')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
process_flip = gr.Checkbox(label='Flip')
|
||||||
|
process_split = gr.Checkbox(label='Split into two')
|
||||||
|
process_caption = gr.Checkbox(label='Add caption')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
run_preprocess = gr.Button(value="Preprocess", variant='primary')
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
|
||||||
|
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
|
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
|
||||||
|
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
|
||||||
|
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
|
||||||
|
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
|
||||||
|
steps = gr.Number(label='Max steps', value=100000, precision=0)
|
||||||
|
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
gr.HTML(value="")
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
with gr.Row():
|
||||||
|
interrupt_training = gr.Button(value="Interrupt")
|
||||||
|
train_embedding = gr.Button(value="Train", variant='primary')
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
progressbar = gr.HTML(elem_id="ti_progressbar")
|
||||||
|
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||||
|
|
||||||
|
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
||||||
|
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
|
||||||
|
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||||
|
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||||
|
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
|
||||||
|
|
||||||
|
create_embedding.click(
|
||||||
|
fn=modules.textual_inversion.ui.create_embedding,
|
||||||
|
inputs=[
|
||||||
|
new_embedding_name,
|
||||||
|
initialization_text,
|
||||||
|
nvpt,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_preprocess.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
process_src,
|
||||||
|
process_dst,
|
||||||
|
process_flip,
|
||||||
|
process_split,
|
||||||
|
process_caption,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
train_embedding.click(
|
||||||
|
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
|
||||||
|
_js="start_training_textual_inversion",
|
||||||
|
inputs=[
|
||||||
|
train_embedding_name,
|
||||||
|
learn_rate,
|
||||||
|
dataset_directory,
|
||||||
|
log_directory,
|
||||||
|
steps,
|
||||||
|
create_image_every,
|
||||||
|
save_embedding_every,
|
||||||
|
template_file,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
ti_output,
|
||||||
|
ti_outcome,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
interrupt_training.click(
|
||||||
|
fn=lambda: shared.state.interrupt(),
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
)
|
||||||
|
|
||||||
def create_setting_component(key):
|
def create_setting_component(key):
|
||||||
def fun():
|
def fun():
|
||||||
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
return opts.data[key] if key in opts.data else opts.data_labels[key].default
|
||||||
@ -1002,6 +1170,32 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
_js='function(){}'
|
_js='function(){}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
|
||||||
|
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
|
||||||
|
|
||||||
|
|
||||||
|
def reload_scripts():
|
||||||
|
modules.scripts.reload_script_body_only()
|
||||||
|
|
||||||
|
reload_script_bodies.click(
|
||||||
|
fn=reload_scripts,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='function(){}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def request_restart():
|
||||||
|
shared.state.interrupt()
|
||||||
|
settings_interface.gradio_ref.do_restart = True
|
||||||
|
|
||||||
|
restart_gradio.click(
|
||||||
|
fn=request_restart,
|
||||||
|
inputs=[],
|
||||||
|
outputs=[],
|
||||||
|
_js='function(){restart_reload()}'
|
||||||
|
)
|
||||||
|
|
||||||
if column is not None:
|
if column is not None:
|
||||||
column.__exit__()
|
column.__exit__()
|
||||||
|
|
||||||
@ -1011,6 +1205,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
(extras_interface, "Extras", "extras"),
|
(extras_interface, "Extras", "extras"),
|
||||||
(pnginfo_interface, "PNG Info", "pnginfo"),
|
(pnginfo_interface, "PNG Info", "pnginfo"),
|
||||||
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
|
||||||
|
(textual_inversion_interface, "Textual inversion", "ti"),
|
||||||
(settings_interface, "Settings", "settings"),
|
(settings_interface, "Settings", "settings"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1027,6 +1222,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
|
|
||||||
|
settings_interface.gradio_ref = demo
|
||||||
|
|
||||||
with gr.Tabs() as tabs:
|
with gr.Tabs() as tabs:
|
||||||
for interface, label, ifid in interfaces:
|
for interface, label, ifid in interfaces:
|
||||||
with gr.TabItem(label, id=ifid):
|
with gr.TabItem(label, id=ifid):
|
||||||
@ -1044,11 +1241,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
|
|||||||
|
|
||||||
def modelmerger(*args):
|
def modelmerger(*args):
|
||||||
try:
|
try:
|
||||||
results = run_modelmerger(*args)
|
results = modules.extras.run_modelmerger(*args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error loading/saving model file:", file=sys.stderr)
|
print("Error loading/saving model file:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
modules.sd_models.list_models() #To remove the potentially missing models from the list
|
modules.sd_models.list_models() # to remove the potentially missing models from the list
|
||||||
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@ -1206,12 +1403,12 @@ for filename in sorted(os.listdir(jsdir)):
|
|||||||
javascript += f"\n<script>{jsfile.read()}</script>"
|
javascript += f"\n<script>{jsfile.read()}</script>"
|
||||||
|
|
||||||
|
|
||||||
|
if 'gradio_routes_templates_response' not in globals():
|
||||||
def template_response(*args, **kwargs):
|
def template_response(*args, **kwargs):
|
||||||
res = gradio_routes_templates_response(*args, **kwargs)
|
res = gradio_routes_templates_response(*args, **kwargs)
|
||||||
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
|
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
|
||||||
res.init_headers()
|
res.init_headers()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
|
||||||
gradio.routes.templates.TemplateResponse = template_response
|
gradio.routes.templates.TemplateResponse = template_response
|
||||||
|
@ -13,14 +13,13 @@ Pillow
|
|||||||
pytorch_lightning
|
pytorch_lightning
|
||||||
realesrgan
|
realesrgan
|
||||||
scikit-image>=0.19
|
scikit-image>=0.19
|
||||||
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
|
|
||||||
timm==0.4.12
|
timm==0.4.12
|
||||||
transformers==4.19.2
|
transformers==4.19.2
|
||||||
torch
|
torch
|
||||||
einops
|
einops
|
||||||
jsonmerge
|
jsonmerge
|
||||||
clean-fid
|
clean-fid
|
||||||
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
|
||||||
resize-right
|
resize-right
|
||||||
torchdiffeq
|
torchdiffeq
|
||||||
kornia
|
kornia
|
||||||
|
lark
|
||||||
|
@ -18,7 +18,7 @@ piexif==1.1.3
|
|||||||
einops==0.4.1
|
einops==0.4.1
|
||||||
jsonmerge==1.8.0
|
jsonmerge==1.8.0
|
||||||
clean-fid==0.1.29
|
clean-fid==0.1.29
|
||||||
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
|
|
||||||
resize-right==0.0.2
|
resize-right==0.0.2
|
||||||
torchdiffeq==0.2.3
|
torchdiffeq==0.2.3
|
||||||
kornia==0.6.7
|
kornia==0.6.7
|
||||||
|
lark==1.1.2
|
||||||
|
@ -8,7 +8,6 @@ import gradio as gr
|
|||||||
|
|
||||||
from modules import processing, shared, sd_samplers, prompt_parser
|
from modules import processing, shared, sd_samplers, prompt_parser
|
||||||
from modules.processing import Processed
|
from modules.processing import Processed
|
||||||
from modules.sd_samplers import samplers
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -159,7 +158,7 @@ class Script(scripts.Script):
|
|||||||
|
|
||||||
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
|
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
|
||||||
|
|
||||||
sampler = samplers[p.sampler_index].constructor(p.sd_model)
|
sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
|
||||||
|
|
||||||
sigmas = sampler.model_wrap.get_sigmas(p.steps)
|
sigmas = sampler.model_wrap.get_sigmas(p.steps)
|
||||||
|
|
||||||
|
@ -11,46 +11,8 @@ from modules import images, processing, devices
|
|||||||
from modules.processing import Processed, process_images
|
from modules.processing import Processed, process_images
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
# https://github.com/parlance-zz/g-diffuser-bot
|
|
||||||
def expand(x, dir, amount, power=0.75):
|
|
||||||
is_left = dir == 3
|
|
||||||
is_right = dir == 1
|
|
||||||
is_up = dir == 0
|
|
||||||
is_down = dir == 2
|
|
||||||
|
|
||||||
if is_left or is_right:
|
|
||||||
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
|
|
||||||
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
|
|
||||||
if is_right:
|
|
||||||
indexes = 1 - indexes
|
|
||||||
indexes = (indexes * (x.shape[1] - 1)).astype(int)
|
|
||||||
|
|
||||||
for row in range(x.shape[0]):
|
|
||||||
if is_left:
|
|
||||||
noise[row] = x[row][indexes[row]]
|
|
||||||
else:
|
|
||||||
noise[row] = np.flip(x[row][indexes[row]], axis=0)
|
|
||||||
|
|
||||||
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
|
|
||||||
return x
|
|
||||||
|
|
||||||
if is_up or is_down:
|
|
||||||
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
|
|
||||||
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
|
|
||||||
if is_down:
|
|
||||||
indexes = 1 - indexes
|
|
||||||
indexes = (indexes * x.shape[0] - 1).astype(int)
|
|
||||||
|
|
||||||
for row in range(x.shape[1]):
|
|
||||||
if is_up:
|
|
||||||
noise[:, row] = x[:, row][indexes[row]]
|
|
||||||
else:
|
|
||||||
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
|
|
||||||
|
|
||||||
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
|
||||||
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
|
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
|
||||||
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
# helper fft routines that keep ortho normalization and auto-shift before and after fft
|
||||||
def _fft2(data):
|
def _fft2(data):
|
||||||
@ -123,8 +85,11 @@ def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.0
|
|||||||
src_dist = np.absolute(src_fft)
|
src_dist = np.absolute(src_fft)
|
||||||
src_phase = src_fft / src_dist
|
src_phase = src_fft / src_dist
|
||||||
|
|
||||||
|
# create a generator with a static seed to make outpainting deterministic / only follow global seed
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
|
||||||
noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
|
noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
|
||||||
noise_rgb = np.random.random_sample((width, height, num_channels))
|
noise_rgb = rng.random((width, height, num_channels))
|
||||||
noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
|
noise_grey = (np.sum(noise_rgb, axis=2) / 3.)
|
||||||
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
|
noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
|
||||||
for c in range(num_channels):
|
for c in range(num_channels):
|
||||||
|
@ -34,7 +34,11 @@ class Script(scripts.Script):
|
|||||||
seed = p.seed
|
seed = p.seed
|
||||||
|
|
||||||
init_img = p.init_images[0]
|
init_img = p.init_images[0]
|
||||||
|
|
||||||
|
if(upscaler.name != "None"):
|
||||||
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
|
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
|
||||||
|
else:
|
||||||
|
img = init_img
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
from itertools import permutations, chain
|
||||||
import random
|
import random
|
||||||
|
import csv
|
||||||
|
from io import StringIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -29,6 +31,31 @@ def apply_prompt(p, x, xs):
|
|||||||
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
|
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_order(p, x, xs):
|
||||||
|
token_order = []
|
||||||
|
|
||||||
|
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
|
||||||
|
for token in x:
|
||||||
|
token_order.append((p.prompt.find(token), token))
|
||||||
|
|
||||||
|
token_order.sort(key=lambda t: t[0])
|
||||||
|
|
||||||
|
prompt_parts = []
|
||||||
|
|
||||||
|
# Split the prompt up, taking out the tokens
|
||||||
|
for _, token in token_order:
|
||||||
|
n = p.prompt.find(token)
|
||||||
|
prompt_parts.append(p.prompt[0:n])
|
||||||
|
p.prompt = p.prompt[n + len(token):]
|
||||||
|
|
||||||
|
# Rebuild the prompt with the tokens in the order we want
|
||||||
|
prompt_tmp = ""
|
||||||
|
for idx, part in enumerate(prompt_parts):
|
||||||
|
prompt_tmp += part
|
||||||
|
prompt_tmp += x[idx]
|
||||||
|
p.prompt = prompt_tmp + p.prompt
|
||||||
|
|
||||||
|
|
||||||
samplers_dict = {}
|
samplers_dict = {}
|
||||||
for i, sampler in enumerate(modules.sd_samplers.samplers):
|
for i, sampler in enumerate(modules.sd_samplers.samplers):
|
||||||
samplers_dict[sampler.name.lower()] = i
|
samplers_dict[sampler.name.lower()] = i
|
||||||
@ -60,16 +87,26 @@ def format_value_add_label(p, opt, x):
|
|||||||
def format_value(p, opt, x):
|
def format_value(p, opt, x):
|
||||||
if type(x) == float:
|
if type(x) == float:
|
||||||
x = round(x, 8)
|
x = round(x, 8)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def format_value_join_list(p, opt, x):
|
||||||
|
return ", ".join(x)
|
||||||
|
|
||||||
|
|
||||||
def do_nothing(p, x, xs):
|
def do_nothing(p, x, xs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def format_nothing(p, opt, x):
|
def format_nothing(p, opt, x):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def str_permutations(x):
|
||||||
|
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
||||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
||||||
|
|
||||||
@ -82,6 +119,7 @@ axis_options = [
|
|||||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
||||||
|
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
AxisOption("Sampler", str, apply_sampler, format_value),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
|
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
|
||||||
@ -159,7 +197,7 @@ class Script(scripts.Script):
|
|||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
valslist = [x.strip() for x in vals.split(",")]
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
|
||||||
|
|
||||||
if opt.type == int:
|
if opt.type == int:
|
||||||
valslist_ext = []
|
valslist_ext = []
|
||||||
@ -206,6 +244,8 @@ class Script(scripts.Script):
|
|||||||
valslist_ext.append(val)
|
valslist_ext.append(val)
|
||||||
|
|
||||||
valslist = valslist_ext
|
valslist = valslist_ext
|
||||||
|
elif opt.type == str_permutations:
|
||||||
|
valslist = list(permutations(valslist))
|
||||||
|
|
||||||
valslist = [opt.type(x) for x in valslist]
|
valslist = [opt.type(x) for x in valslist]
|
||||||
|
|
||||||
|
16
style.css
16
style.css
@ -23,7 +23,7 @@
|
|||||||
text-align: right;
|
text-align: right;
|
||||||
}
|
}
|
||||||
|
|
||||||
#generate{
|
#txt2img_generate, #img2img_generate {
|
||||||
min-height: 4.5em;
|
min-height: 4.5em;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -157,7 +157,7 @@ button{
|
|||||||
max-width: 10em;
|
max-width: 10em;
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview, #img2img_preview{
|
#txt2img_preview, #img2img_preview, #ti_preview{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
width: 320px;
|
width: 320px;
|
||||||
left: 0;
|
left: 0;
|
||||||
@ -172,18 +172,18 @@ button{
|
|||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (min-width: 768px) {
|
@media screen and (min-width: 768px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: absolute;
|
position: absolute;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@media screen and (max-width: 767px) {
|
@media screen and (max-width: 767px) {
|
||||||
#txt2img_preview, #img2img_preview {
|
#txt2img_preview, #img2img_preview, #ti_preview {
|
||||||
position: relative;
|
position: relative;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
|
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,7 +247,7 @@ input[type="range"]{
|
|||||||
#txt2img_negative_prompt, #img2img_negative_prompt{
|
#txt2img_negative_prompt, #img2img_negative_prompt{
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_progressbar, #img2img_progressbar{
|
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
|
||||||
position: absolute;
|
position: absolute;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
right: 0;
|
right: 0;
|
||||||
@ -407,3 +407,7 @@ input[type="range"]{
|
|||||||
.gallery-item {
|
.gallery-item {
|
||||||
--tw-bg-opacity: 0 !important;
|
--tw-bg-opacity: 0 !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#img2img_image div.h-60{
|
||||||
|
height: 480px;
|
||||||
|
}
|
19
textual_inversion_templates/style.txt
Normal file
19
textual_inversion_templates/style.txt
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
a painting, art by [name]
|
||||||
|
a rendering, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
the painting, art by [name]
|
||||||
|
a clean painting, art by [name]
|
||||||
|
a dirty painting, art by [name]
|
||||||
|
a dark painting, art by [name]
|
||||||
|
a picture, art by [name]
|
||||||
|
a cool painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a bright painting, art by [name]
|
||||||
|
a cropped painting, art by [name]
|
||||||
|
a good painting, art by [name]
|
||||||
|
a close-up painting, art by [name]
|
||||||
|
a rendition, art by [name]
|
||||||
|
a nice painting, art by [name]
|
||||||
|
a small painting, art by [name]
|
||||||
|
a weird painting, art by [name]
|
||||||
|
a large painting, art by [name]
|
19
textual_inversion_templates/style_filewords.txt
Normal file
19
textual_inversion_templates/style_filewords.txt
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
a painting of [filewords], art by [name]
|
||||||
|
a rendering of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
the painting of [filewords], art by [name]
|
||||||
|
a clean painting of [filewords], art by [name]
|
||||||
|
a dirty painting of [filewords], art by [name]
|
||||||
|
a dark painting of [filewords], art by [name]
|
||||||
|
a picture of [filewords], art by [name]
|
||||||
|
a cool painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a bright painting of [filewords], art by [name]
|
||||||
|
a cropped painting of [filewords], art by [name]
|
||||||
|
a good painting of [filewords], art by [name]
|
||||||
|
a close-up painting of [filewords], art by [name]
|
||||||
|
a rendition of [filewords], art by [name]
|
||||||
|
a nice painting of [filewords], art by [name]
|
||||||
|
a small painting of [filewords], art by [name]
|
||||||
|
a weird painting of [filewords], art by [name]
|
||||||
|
a large painting of [filewords], art by [name]
|
27
textual_inversion_templates/subject.txt
Normal file
27
textual_inversion_templates/subject.txt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
a photo of a [name]
|
||||||
|
a rendering of a [name]
|
||||||
|
a cropped photo of the [name]
|
||||||
|
the photo of a [name]
|
||||||
|
a photo of a clean [name]
|
||||||
|
a photo of a dirty [name]
|
||||||
|
a dark photo of the [name]
|
||||||
|
a photo of my [name]
|
||||||
|
a photo of the cool [name]
|
||||||
|
a close-up photo of a [name]
|
||||||
|
a bright photo of the [name]
|
||||||
|
a cropped photo of a [name]
|
||||||
|
a photo of the [name]
|
||||||
|
a good photo of the [name]
|
||||||
|
a photo of one [name]
|
||||||
|
a close-up photo of the [name]
|
||||||
|
a rendition of the [name]
|
||||||
|
a photo of the clean [name]
|
||||||
|
a rendition of a [name]
|
||||||
|
a photo of a nice [name]
|
||||||
|
a good photo of a [name]
|
||||||
|
a photo of the nice [name]
|
||||||
|
a photo of the small [name]
|
||||||
|
a photo of the weird [name]
|
||||||
|
a photo of the large [name]
|
||||||
|
a photo of a cool [name]
|
||||||
|
a photo of a small [name]
|
27
textual_inversion_templates/subject_filewords.txt
Normal file
27
textual_inversion_templates/subject_filewords.txt
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
a photo of a [name], [filewords]
|
||||||
|
a rendering of a [name], [filewords]
|
||||||
|
a cropped photo of the [name], [filewords]
|
||||||
|
the photo of a [name], [filewords]
|
||||||
|
a photo of a clean [name], [filewords]
|
||||||
|
a photo of a dirty [name], [filewords]
|
||||||
|
a dark photo of the [name], [filewords]
|
||||||
|
a photo of my [name], [filewords]
|
||||||
|
a photo of the cool [name], [filewords]
|
||||||
|
a close-up photo of a [name], [filewords]
|
||||||
|
a bright photo of the [name], [filewords]
|
||||||
|
a cropped photo of a [name], [filewords]
|
||||||
|
a photo of the [name], [filewords]
|
||||||
|
a good photo of the [name], [filewords]
|
||||||
|
a photo of one [name], [filewords]
|
||||||
|
a close-up photo of the [name], [filewords]
|
||||||
|
a rendition of the [name], [filewords]
|
||||||
|
a photo of the clean [name], [filewords]
|
||||||
|
a rendition of a [name], [filewords]
|
||||||
|
a photo of a nice [name], [filewords]
|
||||||
|
a good photo of a [name], [filewords]
|
||||||
|
a photo of the nice [name], [filewords]
|
||||||
|
a photo of the small [name], [filewords]
|
||||||
|
a photo of the weird [name], [filewords]
|
||||||
|
a photo of the large [name], [filewords]
|
||||||
|
a photo of a cool [name], [filewords]
|
||||||
|
a photo of a small [name], [filewords]
|
54
webui.py
54
webui.py
@ -1,34 +1,35 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from modules import devices
|
import importlib
|
||||||
from modules.paths import script_path
|
|
||||||
import signal
|
import signal
|
||||||
import threading
|
import threading
|
||||||
import modules.paths
|
|
||||||
|
from modules.paths import script_path
|
||||||
|
|
||||||
|
from modules import devices, sd_samplers
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.esrgan_model as esrgan
|
|
||||||
import modules.bsrgan_model as bsrgan
|
|
||||||
import modules.extras
|
import modules.extras
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.gfpgan_model as gfpgan
|
import modules.gfpgan_model as gfpgan
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
import modules.ldsr_model as ldsr
|
|
||||||
import modules.lowvram
|
import modules.lowvram
|
||||||
import modules.realesrgan_model as realesrgan
|
import modules.paths
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.swinir_model as swinir
|
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
|
|
||||||
import modules.ui
|
import modules.ui
|
||||||
|
from modules import devices
|
||||||
from modules import modelloader
|
from modules import modelloader
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
modelloader.cleanup_models()
|
modelloader.cleanup_models()
|
||||||
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
|
modules.sd_models.setup_model()
|
||||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
@ -46,7 +47,7 @@ def wrap_queued_call(func):
|
|||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func):
|
def wrap_gradio_gpu_call(func, extra_outputs=None):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -58,6 +59,7 @@ def wrap_gradio_gpu_call(func):
|
|||||||
shared.state.current_image = None
|
shared.state.current_image = None
|
||||||
shared.state.current_image_sampling_step = 0
|
shared.state.current_image_sampling_step = 0
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
shared.state.textinfo = None
|
||||||
|
|
||||||
with queue_lock:
|
with queue_lock:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
@ -69,7 +71,7 @@ def wrap_gradio_gpu_call(func):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return modules.ui.wrap_gradio_call(f)
|
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
|
||||||
|
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||||
@ -86,13 +88,9 @@ def webui():
|
|||||||
|
|
||||||
signal.signal(signal.SIGINT, sigint_handler)
|
signal.signal(signal.SIGINT, sigint_handler)
|
||||||
|
|
||||||
demo = modules.ui.create_ui(
|
while 1:
|
||||||
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
|
|
||||||
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
|
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
|
||||||
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
|
|
||||||
run_pnginfo=modules.extras.run_pnginfo,
|
|
||||||
run_modelmerger=modules.extras.run_modelmerger
|
|
||||||
)
|
|
||||||
|
|
||||||
demo.launch(
|
demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
@ -101,8 +99,26 @@ def webui():
|
|||||||
debug=cmd_opts.gradio_debug,
|
debug=cmd_opts.gradio_debug,
|
||||||
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
|
auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
|
||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
|
prevent_thread_lock=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
while 1:
|
||||||
|
time.sleep(0.5)
|
||||||
|
if getattr(demo, 'do_restart', False):
|
||||||
|
time.sleep(0.5)
|
||||||
|
demo.close()
|
||||||
|
time.sleep(0.5)
|
||||||
|
break
|
||||||
|
|
||||||
|
sd_samplers.set_samplers()
|
||||||
|
|
||||||
|
print('Reloading Custom Scripts')
|
||||||
|
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
print('Reloading modules: modules.ui')
|
||||||
|
importlib.reload(modules.ui)
|
||||||
|
print('Restarting Gradio')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
webui()
|
webui()
|
||||||
|
Loading…
Reference in New Issue
Block a user