mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 12:25:06 +08:00
textual inversion embeddings support
settings tab
This commit is contained in:
parent
ec8a252260
commit
91dc8710ec
30
README.md
30
README.md
@ -1,7 +1,7 @@
|
|||||||
# Stable Diffusion web UI
|
# Stable Diffusion web UI
|
||||||
A browser interface based on Gradio library for Stable Diffusion.
|
A browser interface based on Gradio library for Stable Diffusion.
|
||||||
|
|
||||||
Original script with Gradio UI was written by a kind anonymopus user. This is a modification.
|
Original script with Gradio UI was written by a kind anonymous user. This is a modification.
|
||||||
|
|
||||||
![](screenshot.png)
|
![](screenshot.png)
|
||||||
## Installing and running
|
## Installing and running
|
||||||
@ -128,7 +128,7 @@ Example:
|
|||||||
Gradio's loading graphic has a very negative effect on the processing speed of the neural network.
|
Gradio's loading graphic has a very negative effect on the processing speed of the neural network.
|
||||||
My RTX 3090 makes images about 10% faster when the tab with gradio is not active. By default, the UI
|
My RTX 3090 makes images about 10% faster when the tab with gradio is not active. By default, the UI
|
||||||
now hides loading progress animation and replaces it with static "Loading..." text, which achieves
|
now hides loading progress animation and replaces it with static "Loading..." text, which achieves
|
||||||
the same effect. Use the --no-progressbar-hiding commandline option to revert this and show loading animations.
|
the same effect. Use the `--no-progressbar-hiding` commandline option to revert this and show loading animations.
|
||||||
|
|
||||||
### Prompt validation
|
### Prompt validation
|
||||||
Stable Diffusion has a limit for input text length. If your prompt is too long, you will get a
|
Stable Diffusion has a limit for input text length. If your prompt is too long, you will get a
|
||||||
@ -152,6 +152,28 @@ Adds information about generation parameters to PNG as a text chunk. You
|
|||||||
can view this information later using any software that supports viewing
|
can view this information later using any software that supports viewing
|
||||||
PNG chunk info, for example: https://www.nayuki.io/page/png-file-chunk-inspector
|
PNG chunk info, for example: https://www.nayuki.io/page/png-file-chunk-inspector
|
||||||
|
|
||||||
This can be disabled using the `--disable-pnginfo` command line option.
|
|
||||||
|
|
||||||
![](images/pnginfo.png)
|
![](images/pnginfo.png)
|
||||||
|
|
||||||
|
### Textual Inversion
|
||||||
|
Allows you to use pretrained textual inversion embeddings.
|
||||||
|
See originial site for details: https://textual-inversion.github.io/.
|
||||||
|
I used lstein's repo for training embdedding: https://github.com/lstein/stable-diffusion; if
|
||||||
|
you want to train your own, I recommend following the guide on his site.
|
||||||
|
|
||||||
|
No additional libraries/repositories are required to use pretrained embeddings.
|
||||||
|
|
||||||
|
To make use of pretrained embeddings, create `embeddings` directory in the root dir of Stable
|
||||||
|
Diffusion and put your embeddings into it. They must be .pt files about 5Kb in size, each with only
|
||||||
|
one trained embedding, and the filename (without .pt) will be the term you'd use in prompt
|
||||||
|
to get that embedding.
|
||||||
|
|
||||||
|
As an example, I trained one for about 5000 steps: https://files.catbox.moe/e2ui6r.pt; it does
|
||||||
|
not produce very good results, but it does work. Download and rename it to `Usada Pekora.pt`,
|
||||||
|
and put it into `embeddings` dir and use Usada Pekora in prompt.
|
||||||
|
|
||||||
|
![](images/inversion.png)
|
||||||
|
|
||||||
|
### Settings
|
||||||
|
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
|
||||||
|
were commandline. Settings are saved to config.js file. Settings that remain as commandline
|
||||||
|
options are ones that are required at startup.
|
||||||
|
BIN
images/inversion.png
Normal file
BIN
images/inversion.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 678 KiB |
397
webui.py
397
webui.py
@ -8,17 +8,19 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import autocast
|
from torch import autocast
|
||||||
from contextlib import contextmanager, nullcontext
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
import html
|
import html
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
import k_diffusion as K
|
import k_diffusion as K
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
import ldm.modules.encoders.modules
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
@ -38,30 +40,18 @@ opt_f = 8
|
|||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
invalid_filename_chars = '<>:"/\|?*\n'
|
invalid_filename_chars = '<>:"/\|?*\n'
|
||||||
|
config_filename = "config.json"
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
|
|
||||||
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
|
|
||||||
parser.add_argument("--skip_save", action='store_true', help="do not save indiviual samples. For speed measurements.",)
|
|
||||||
parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; use -1 for autodetect and 0 for n_rows to be same as batch_size (default: -1)",)
|
|
||||||
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
|
||||||
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
|
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
|
||||||
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
|
||||||
parser.add_argument("--save-format", type=str, default='png', help="file format for saved indiviual samples; can be png or jpg")
|
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
|
||||||
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
|
|
||||||
parser.add_argument("--grid-extended-filename", action='store_true', help="save grid images to filenames with extended info: seed, prompt")
|
|
||||||
parser.add_argument("--jpeg-quality", type=int, default=80, help="quality for saved jpeg images")
|
|
||||||
parser.add_argument("--disable-pnginfo", action='store_true', help="disable saving text information about generation parameters as chunks to png files")
|
|
||||||
|
|
||||||
parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo")
|
cmd_opts = parser.parse_args()
|
||||||
opt = parser.parse_args()
|
|
||||||
|
|
||||||
GFPGAN_dir = opt.gfpgan_dir
|
|
||||||
|
|
||||||
css_hide_progressbar = """
|
css_hide_progressbar = """
|
||||||
.wrap .m-12 svg { display:none!important; }
|
.wrap .m-12 svg { display:none!important; }
|
||||||
@ -70,6 +60,49 @@ css_hide_progressbar = """
|
|||||||
.meta-text { display:none!important; }
|
.meta-text { display:none!important; }
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Options:
|
||||||
|
data = None
|
||||||
|
data_labels = {
|
||||||
|
"outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
|
||||||
|
"samples_save": (True, "Save indiviual samples"),
|
||||||
|
"samples_format": ('png', 'File format for indiviual samples'),
|
||||||
|
"grid_save": (True, "Save image grids"),
|
||||||
|
"grid_format": ('png', 'File format for grids'),
|
||||||
|
"grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||||
|
"n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
|
||||||
|
"jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
|
||||||
|
"verify_input": (True, "Check input, and produce warning if it's too long"),
|
||||||
|
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.data = {k: v[0] for k, v in self.data_labels.items()}
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if self.data is not None:
|
||||||
|
if key in self.data:
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
return super(Options, self).__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
if self.data is not None:
|
||||||
|
if item in self.data:
|
||||||
|
return self.data[item]
|
||||||
|
|
||||||
|
return super(Options, self).__getattribute__(item)
|
||||||
|
|
||||||
|
def save(self, filename):
|
||||||
|
with open(filename, "w", encoding="utf8") as file:
|
||||||
|
json.dump(self.data, file)
|
||||||
|
|
||||||
|
def load(self, filename):
|
||||||
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
def chunk(it, size):
|
def chunk(it, size):
|
||||||
it = iter(it)
|
it = iter(it)
|
||||||
return iter(lambda: tuple(islice(it, size)), ())
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
@ -154,13 +187,13 @@ def save_image(image, path, basename, seed, prompt, extension, info=None, short_
|
|||||||
else:
|
else:
|
||||||
filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
|
filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
|
||||||
|
|
||||||
if extension == 'png' and not opt.disable_pnginfo:
|
if extension == 'png' and opts.enable_pnginfo:
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
pnginfo.add_text("parameters", info)
|
pnginfo.add_text("parameters", info)
|
||||||
else:
|
else:
|
||||||
pnginfo = None
|
pnginfo = None
|
||||||
|
|
||||||
image.save(os.path.join(path, filename), quality=opt.jpeg_quality, pnginfo=pnginfo)
|
image.save(os.path.join(path, filename), quality=opts.jpeg_quality, pnginfo=pnginfo)
|
||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
@ -170,39 +203,22 @@ def plaintext_to_html(text):
|
|||||||
|
|
||||||
def load_GFPGAN():
|
def load_GFPGAN():
|
||||||
model_name = 'GFPGANv1.3'
|
model_name = 'GFPGANv1.3'
|
||||||
model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
|
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
|
||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise Exception("GFPGAN model not found at path "+model_path)
|
raise Exception("GFPGAN model not found at path "+model_path)
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(GFPGAN_dir))
|
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
|
||||||
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||||
|
|
||||||
|
|
||||||
GFPGAN = None
|
|
||||||
if os.path.exists(GFPGAN_dir):
|
|
||||||
try:
|
|
||||||
GFPGAN = load_GFPGAN()
|
|
||||||
print("Loaded GFPGAN")
|
|
||||||
except Exception:
|
|
||||||
import traceback
|
|
||||||
print("Error loading GFPGAN:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
config = OmegaConf.load(opt.config)
|
|
||||||
model = load_model_from_config(config, opt.ckpt)
|
|
||||||
|
|
||||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
||||||
model = (model if opt.no_half else model.half()).to(device)
|
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
|
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
|
||||||
if force_n_rows is not None:
|
if force_n_rows is not None:
|
||||||
rows = force_n_rows
|
rows = force_n_rows
|
||||||
elif opt.n_rows > 0:
|
elif opts.n_rows > 0:
|
||||||
rows = opt.n_rows
|
rows = opts.n_rows
|
||||||
elif opt.n_rows == 0:
|
elif opts.n_rows == 0:
|
||||||
rows = batch_size
|
rows = batch_size
|
||||||
else:
|
else:
|
||||||
rows = math.sqrt(len(imgs))
|
rows = math.sqrt(len(imgs))
|
||||||
@ -353,6 +369,163 @@ def wrap_gradio_call(func):
|
|||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
GFPGAN = None
|
||||||
|
if os.path.exists(cmd_opts.gfpgan_dir):
|
||||||
|
try:
|
||||||
|
GFPGAN = load_GFPGAN()
|
||||||
|
print("Loaded GFPGAN")
|
||||||
|
except Exception:
|
||||||
|
print("Error loading GFPGAN:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
class TextInversionEmbeddings:
|
||||||
|
ids_lookup = {}
|
||||||
|
word_embeddings = {}
|
||||||
|
word_embeddings_checksums = {}
|
||||||
|
fixes = []
|
||||||
|
used_custom_terms = []
|
||||||
|
dir_mtime = None
|
||||||
|
|
||||||
|
def load(self, dir, model):
|
||||||
|
mt = os.path.getmtime(dir)
|
||||||
|
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.dir_mtime = mt
|
||||||
|
self.ids_lookup.clear()
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
|
||||||
|
tokenizer = model.cond_stage_model.tokenizer
|
||||||
|
|
||||||
|
def const_hash(a):
|
||||||
|
r = 0
|
||||||
|
for v in a:
|
||||||
|
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
||||||
|
return r
|
||||||
|
|
||||||
|
def process_file(path, filename):
|
||||||
|
name = os.path.splitext(filename)[0]
|
||||||
|
|
||||||
|
data = torch.load(path)
|
||||||
|
param_dict = data['string_to_param']
|
||||||
|
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||||
|
emb = next(iter(param_dict.items()))[1].reshape(768)
|
||||||
|
self.word_embeddings[name] = emb
|
||||||
|
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
|
||||||
|
|
||||||
|
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
|
||||||
|
first_id = ids[0]
|
||||||
|
if first_id not in self.ids_lookup:
|
||||||
|
self.ids_lookup[first_id] = []
|
||||||
|
self.ids_lookup[first_id].append((ids, name))
|
||||||
|
|
||||||
|
for fn in os.listdir(dir):
|
||||||
|
try:
|
||||||
|
process_file(os.path.join(dir, fn), fn)
|
||||||
|
except:
|
||||||
|
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
|
||||||
|
|
||||||
|
def hijack(self, m):
|
||||||
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
|
|
||||||
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
|
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||||
|
def __init__(self, wrapped, embeddings):
|
||||||
|
super().__init__()
|
||||||
|
self.wrapped = wrapped
|
||||||
|
self.embeddings = embeddings
|
||||||
|
self.tokenizer = wrapped.tokenizer
|
||||||
|
self.max_length = wrapped.max_length
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
self.embeddings.fixes = []
|
||||||
|
self.embeddings.used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
|
id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
maxlen = self.wrapped.max_length - 2
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
for tokens in batch_tokens:
|
||||||
|
tuple_tokens = tuple(tokens)
|
||||||
|
|
||||||
|
if tuple_tokens in cache:
|
||||||
|
remade_tokens, fixes = cache[tuple_tokens]
|
||||||
|
else:
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
possible_matches = self.embeddings.ids_lookup.get(token, None)
|
||||||
|
|
||||||
|
if possible_matches is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
else:
|
||||||
|
found = False
|
||||||
|
for ids, word in possible_matches:
|
||||||
|
if tokens[i:i+len(ids)] == ids:
|
||||||
|
fixes.append((len(remade_tokens), word))
|
||||||
|
remade_tokens.append(777)
|
||||||
|
i += len(ids) - 1
|
||||||
|
found = True
|
||||||
|
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
|
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
|
||||||
|
cache[tuple_tokens] = (remade_tokens, fixes)
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
self.embeddings.fixes.append(fixes)
|
||||||
|
|
||||||
|
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
|
||||||
|
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsWithFixes(nn.Module):
|
||||||
|
def __init__(self, wrapped, embeddings):
|
||||||
|
super().__init__()
|
||||||
|
self.wrapped = wrapped
|
||||||
|
self.embeddings = embeddings
|
||||||
|
|
||||||
|
def forward(self, input_ids):
|
||||||
|
batch_fixes = self.embeddings.fixes
|
||||||
|
self.embeddings.fixes = []
|
||||||
|
|
||||||
|
inputs_embeds = self.wrapped(input_ids)
|
||||||
|
|
||||||
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
|
for offset, word in fixes:
|
||||||
|
tensor[offset] = self.embeddings.word_embeddings[word]
|
||||||
|
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def get_learned_conditioning_with_embeddings(model, prompts):
|
||||||
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
|
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
||||||
|
|
||||||
|
return model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
|
||||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
@ -392,7 +565,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if not opt.no_verify_input:
|
if opts.verify_input:
|
||||||
try:
|
try:
|
||||||
check_prompt_length(prompt, comments)
|
check_prompt_length(prompt, comments)
|
||||||
except:
|
except:
|
||||||
@ -403,27 +576,29 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
all_prompts = batch_size * n_iter * [prompt]
|
all_prompts = batch_size * n_iter * [prompt]
|
||||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||||
|
|
||||||
info = f"""
|
def infotext():
|
||||||
|
return f"""
|
||||||
{prompt}
|
{prompt}
|
||||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
|
||||||
""".strip() + "".join(["\n\n" + x for x in comments])
|
""".strip() + "".join(["\n\n" + x for x in comments])
|
||||||
|
|
||||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
|
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
|
||||||
|
|
||||||
output_images = []
|
output_images = []
|
||||||
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
|
with torch.no_grad(), autocast("cuda"), model.ema_scope():
|
||||||
init_data = func_init()
|
init_data = func_init()
|
||||||
|
|
||||||
for n in range(n_iter):
|
for n in range(n_iter):
|
||||||
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
|
||||||
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
|
||||||
|
|
||||||
uc = None
|
|
||||||
if cfg_scale != 1.0:
|
|
||||||
uc = model.get_learned_conditioning(len(prompts) * [""])
|
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||||
if isinstance(prompts, tuple):
|
|
||||||
prompts = list(prompts)
|
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
if len(text_inversion_embeddings.used_custom_terms) > 0:
|
||||||
|
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
|
||||||
|
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
||||||
|
|
||||||
@ -432,7 +607,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
if prompt_matrix or not opt.skip_save or not opt.skip_grid:
|
if prompt_matrix or opts.samples_save or opts.grid_save:
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
@ -442,12 +617,12 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
x_sample = restored_img
|
x_sample = restored_img
|
||||||
|
|
||||||
image = Image.fromarray(x_sample)
|
image = Image.fromarray(x_sample)
|
||||||
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opt.save_format, info=info)
|
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
|
||||||
|
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
base_count += 1
|
base_count += 1
|
||||||
|
|
||||||
if (prompt_matrix or not opt.skip_grid) and not do_not_save_grid:
|
if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
|
||||||
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
|
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
|
||||||
|
|
||||||
if prompt_matrix:
|
if prompt_matrix:
|
||||||
@ -461,23 +636,17 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
|||||||
|
|
||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
|
|
||||||
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
|
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return output_images, seed, info
|
return output_images, seed, infotext()
|
||||||
|
|
||||||
|
|
||||||
def load_embeddings(fp):
|
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
|
||||||
if fp is not None and hasattr(model, "embedding_manager"):
|
outpath = opts.outdir or "outputs/txt2img-samples"
|
||||||
# load the file
|
|
||||||
model.embedding_manager.load(fp.name)
|
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, embeddings_fp):
|
|
||||||
outpath = opt.outdir or "outputs/txt2img-samples"
|
|
||||||
|
|
||||||
load_embeddings(embeddings_fp)
|
|
||||||
|
|
||||||
if sampler_name == 'PLMS':
|
if sampler_name == 'PLMS':
|
||||||
sampler = PLMSSampler(model)
|
sampler = PLMSSampler(model)
|
||||||
@ -567,29 +736,25 @@ txt2img_interface = gr.Interface(
|
|||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
|
||||||
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
||||||
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Gallery(label="Images"),
|
gr.Gallery(label="Images"),
|
||||||
gr.Number(label='Seed'),
|
gr.Number(label='Seed'),
|
||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
title="Stable Diffusion Text-to-Image K",
|
title="Stable Diffusion Text-to-Image",
|
||||||
description="Generate images from text with Stable Diffusion (using K-LMS)",
|
|
||||||
flagging_callback=Flagging()
|
flagging_callback=Flagging()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, embeddings_fp):
|
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||||
outpath = opt.outdir or "outputs/img2img-samples"
|
outpath = opts.outdir or "outputs/img2img-samples"
|
||||||
|
|
||||||
load_embeddings(embeddings_fp)
|
|
||||||
|
|
||||||
sampler = KDiffusionSampler(model)
|
sampler = KDiffusionSampler(model)
|
||||||
|
|
||||||
@ -658,7 +823,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
|
|||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
grid = image_grid(history, batch_size, force_n_rows=1)
|
grid = image_grid(history, batch_size, force_n_rows=1)
|
||||||
|
|
||||||
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
|
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
output_images = history
|
output_images = history
|
||||||
seed = initial_seed
|
seed = initial_seed
|
||||||
@ -698,15 +863,14 @@ img2img_interface = gr.Interface(
|
|||||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||||
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
|
||||||
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
|
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
|
||||||
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
|
||||||
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
|
||||||
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
|
||||||
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
|
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
|
||||||
gr.Number(label='Seed', value=-1),
|
gr.Number(label='Seed', value=-1),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
|
||||||
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
|
||||||
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
|
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
|
||||||
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
gr.Gallery(),
|
gr.Gallery(),
|
||||||
@ -714,15 +878,9 @@ img2img_interface = gr.Interface(
|
|||||||
gr.HTML(),
|
gr.HTML(),
|
||||||
],
|
],
|
||||||
title="Stable Diffusion Image-to-Image",
|
title="Stable Diffusion Image-to-Image",
|
||||||
description="Generate images from images with Stable Diffusion",
|
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
)
|
)
|
||||||
|
|
||||||
interfaces = [
|
|
||||||
(txt2img_interface, "txt2img"),
|
|
||||||
(img2img_interface, "img2img")
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_GFPGAN(image, strength):
|
def run_GFPGAN(image, strength):
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
@ -735,8 +893,7 @@ def run_GFPGAN(image, strength):
|
|||||||
return res, 0, ''
|
return res, 0, ''
|
||||||
|
|
||||||
|
|
||||||
if GFPGAN is not None:
|
gfpgan_interface = gr.Interface(
|
||||||
interfaces.append((gr.Interface(
|
|
||||||
run_GFPGAN,
|
run_GFPGAN,
|
||||||
inputs=[
|
inputs=[
|
||||||
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
|
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
|
||||||
@ -750,13 +907,83 @@ if GFPGAN is not None:
|
|||||||
title="GFPGAN",
|
title="GFPGAN",
|
||||||
description="Fix faces on images",
|
description="Fix faces on images",
|
||||||
allow_flagging="never",
|
allow_flagging="never",
|
||||||
), "GFPGAN"))
|
)
|
||||||
|
|
||||||
|
opts = Options()
|
||||||
|
if os.path.exists(config_filename):
|
||||||
|
opts.load(config_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def run_settings(*args):
|
||||||
|
up = []
|
||||||
|
|
||||||
|
for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
|
||||||
|
opts.data[key] = value
|
||||||
|
up.append(comp.update(value=value))
|
||||||
|
|
||||||
|
opts.save(config_filename)
|
||||||
|
|
||||||
|
return 'Settings saved.', ''
|
||||||
|
|
||||||
|
|
||||||
|
def create_setting_component(key):
|
||||||
|
def fun():
|
||||||
|
return opts.data[key] if key in opts.data else opts.data_labels[key][0]
|
||||||
|
|
||||||
|
labelinfo = opts.data_labels[key]
|
||||||
|
t = type(labelinfo[0])
|
||||||
|
label = labelinfo[1]
|
||||||
|
if t == str:
|
||||||
|
item = gr.Textbox(label=label, value=fun, lines=1)
|
||||||
|
elif t == int:
|
||||||
|
if len(labelinfo) == 4:
|
||||||
|
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
|
||||||
|
else:
|
||||||
|
item = gr.Number(label=label, value=fun)
|
||||||
|
elif t == bool:
|
||||||
|
item = gr.Checkbox(label=label, value=fun)
|
||||||
|
else:
|
||||||
|
raise Exception(f'bad options item type: {str(t)} for key {key}')
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
|
||||||
|
settings_interface = gr.Interface(
|
||||||
|
run_settings,
|
||||||
|
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
|
||||||
|
outputs=[
|
||||||
|
gr.Textbox(label='Result'),
|
||||||
|
gr.HTML(),
|
||||||
|
],
|
||||||
|
title=None,
|
||||||
|
description=None,
|
||||||
|
allow_flagging="never",
|
||||||
|
)
|
||||||
|
|
||||||
|
interfaces = [
|
||||||
|
(txt2img_interface, "txt2img"),
|
||||||
|
(img2img_interface, "img2img"),
|
||||||
|
(gfpgan_interface, "GFPGAN"),
|
||||||
|
(settings_interface, "Settings"),
|
||||||
|
]
|
||||||
|
|
||||||
|
config = OmegaConf.load(cmd_opts.config)
|
||||||
|
model = load_model_from_config(config, cmd_opts.ckpt)
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = (model if cmd_opts.no_half else model.half()).to(device)
|
||||||
|
text_inversion_embeddings = TextInversionEmbeddings()
|
||||||
|
|
||||||
|
if os.path.exists(cmd_opts.embeddings_dir):
|
||||||
|
text_inversion_embeddings.hijack(model)
|
||||||
|
|
||||||
|
if GFPGAN is None:
|
||||||
|
interfaces = [x for x in interfaces if x[0] != gfpgan_interface]
|
||||||
|
|
||||||
demo = gr.TabbedInterface(
|
demo = gr.TabbedInterface(
|
||||||
interface_list=[x[0] for x in interfaces],
|
interface_list=[x[0] for x in interfaces],
|
||||||
tab_names=[x[1] for x in interfaces],
|
tab_names=[x[1] for x in interfaces],
|
||||||
css=("" if opt.no_progressbar_hiding else css_hide_progressbar) + """
|
css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
|
||||||
.output-html p {margin: 0 0.5em;}
|
.output-html p {margin: 0 0.5em;}
|
||||||
.performance { font-size: 0.85em; color: #444; }
|
.performance { font-size: 0.85em; color: #444; }
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user