Prompt matrix now draws text like in demo.

This commit is contained in:
AUTOMATIC 2022-08-23 18:04:13 +03:00
parent 61bfa6c16b
commit cb118c4036
4 changed files with 90 additions and 51 deletions

View File

@ -75,17 +75,17 @@ Pick out of three sampling methods for txt2img: DDIM, PLMS, k-diffusion:
### Prompt matrix ### Prompt matrix
Separate multiple prompts using the `|` character, and the system will produce an image for every combination of them. Separate multiple prompts using the `|` character, and the system will produce an image for every combination of them.
For example, if you use `a house in a field of grass|at dawn|illustration` prompt, there are four combinations possible (first part of prompt is always kept): For example, if you use `a busy city street in a modern city|illustration|cinematic lighting` prompt, there are four combinations possible (first part of prompt is always kept):
- `a house in a field of grass` - `a busy city street in a modern city`
- `a house in a field of grass, at dawn` - `a busy city street in a modern city, illustration`
- `a house in a field of grass, illustration` - `a busy city street in a modern city, cinematic lighting`
- `a house in a field of grass, at dawn, illustration` - `a busy city street in a modern city, illustration, cinematic lighting`
Four images will be produced, in this order, all with same seed and each with corresponding prompt: Four images will be produced, in this order, all with same seed and each with corresponding prompt:
![](images/prompt-matrix.png) ![](images/prompt-matrix.png)
Another example, this time with 5 prompts and 16 variations, (text added manually): Another example, this time with 5 prompts and 16 variations:
![](images/prompt_matrix.jpg) ![](images/prompt_matrix.jpg)
### Flagging ### Flagging

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

After

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 770 KiB

After

Width:  |  Height:  |  Size: 1.2 MiB

127
webui.py
View File

@ -1,11 +1,10 @@
import PIL
import argparse, os, sys, glob import argparse, os, sys, glob
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import gradio as gr import gradio as gr
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image, ImageFont, ImageDraw
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import autocast from torch import autocast
@ -76,23 +75,6 @@ def load_model_from_config(config, ckpt, verbose=False):
return model return model
def load_img_pil(img_pil):
image = img_pil.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
print(f"cropped image to size ({w}, {h})")
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2. * image - 1.
def load_img(path):
return load_img_pil(Image.open(path))
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
@ -179,6 +161,71 @@ def image_grid(imgs, batch_size, round_down=False):
return grid return grid
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
if d.textlength(line, font=font) <= line_length:
lines[-1] = line
else:
lines.append(word)
return '\n'.join(lines)
def draw_texts(pos, x, y, texts, sizes):
for i, (text, size) in enumerate(zip(texts, sizes)):
active = pos & (1 << i) != 0
if not active:
text = '\u0336'.join(text) + '\u0336'
d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
y += size[1] + line_spacing
fontsize = (width + height) // 25
line_spacing = fontsize // 2
fnt = ImageFont.truetype("arial.ttf", fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
pad_top = height // 4
pad_left = width * 3 // 4
cols = im.width // width
rows = im.height // height
prompts = all_prompts[1:]
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
result.paste(im, (pad_left, pad_top))
d = ImageDraw.Draw(result)
boundary = math.ceil(len(prompts) / 2)
prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
for col in range(cols):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_height / 2
draw_texts(col, x, y, prompts_horiz, sizes_hor)
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_height / 2
draw_texts(row, x, y, prompts_vert, sizes_ver)
return result
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -212,30 +259,23 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
prompt_matrix_prompts = [] prompt_matrix_prompts = []
comment = "" prompt_matrix_parts = []
if prompt_matrix: if prompt_matrix:
keep_same_seed = True keep_same_seed = True
comment = "Image prompts:\n\n"
items = prompt.split("|") prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(items)-1) combination_count = 2 ** (len(prompt_matrix_parts)-1)
for combination_num in range(combination_count): for combination_num in range(combination_count):
current = items[0] current = prompt_matrix_parts[0]
label = 'A'
for n, text in enumerate(items[1:]): for n, text in enumerate(prompt_matrix_parts[1:]):
if combination_num & (2**n) > 0: if combination_num & (2**n) > 0:
current += ("" if text.strip().startswith(",") else ", ") + text current += ("" if text.strip().startswith(",") else ", ") + text
label += chr(ord('B') + n)
comment += " - " + label + "\n"
prompt_matrix_prompts.append(current) prompt_matrix_prompts.append(current)
n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size) n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
comment += "\nwhere:\n" print(f"Prompt matrix will create {len(prompt_matrix_prompts)} images using a total of {n_iter} batches.")
for n, text in enumerate(items):
comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n"
precision_scope = autocast if opt.precision == "autocast" else nullcontext precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = [] output_images = []
@ -262,7 +302,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save or not opt.skip_grid: if prompt_matrix or not opt.skip_save or not opt.skip_grid:
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -279,14 +319,16 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
output_images.append(image) output_images.append(image)
base_count += 1 base_count += 1
if not opt.skip_grid: if prompt_matrix or not opt.skip_grid:
# additionally, save as grid
grid = image_grid(output_images, batch_size, round_down=prompt_matrix) grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
if prompt_matrix:
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
output_images.insert(0, grid)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1 grid_count += 1
if sampler is not None:
del sampler del sampler
info = f""" info = f"""
@ -294,9 +336,6 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip() """.strip()
if len(comment) > 0:
info += "\n\n" + comment
return output_images, seed, info return output_images, seed, info
class Flagging(gr.FlaggingCallback): class Flagging(gr.FlaggingCallback):
@ -350,7 +389,7 @@ dream_interface = gr.Interface(
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1), gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
@ -389,7 +428,7 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
image = init_img.convert("RGB") image = init_img.convert("RGB")
image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS) image = image.resize((width, height), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
@ -466,7 +505,7 @@ img2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1), gr.Number(label='Seed', value=-1),
@ -494,7 +533,7 @@ def run_GFPGAN(image, strength):
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if strength < 1.0: if strength < 1.0:
res = PIL.Image.blend(image, res, strength) res = Image.blend(image, res, strength)
return res return res