save image generation params into text chunks for png images

AUTOMATIC 2022-08-24 17:57:49 +03:00
parent da96bbf485
commit 29f7e7ab89


@@ -4,7 +4,7 @@ import torch.nn as nn
 import numpy as np
 import gradio as gr
 from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw
+from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
 from itertools import islice
 from einops import rearrange, repeat
 from torch import autocast
@@ -49,11 +49,13 @@ parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=(
 parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
 parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
 parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
 parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
 parser.add_argument("--save-format", type=str, default='png', help="file format for saved indiviual samples; can be png or jpg")
 parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
 parser.add_argument("--grid-extended-filename", action='store_true', help="save grid images to filenames with extended info: seed, prompt")
 parser.add_argument("--jpeg-quality", type=int, default=80, help="quality for saved jpeg images")
+parser.add_argument("--disable-pnginfo", action='store_true', help="disable saving text information about generation parameters as chunks to png files")
 opt = parser.parse_args()
 GFPGAN_dir = opt.gfpgan_dir
@@ -141,7 +143,7 @@ def sanitize_filename_part(text):
     return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
-def save_image(image, path, basename, seed, prompt, extension, short_filename=False):
+def save_image(image, path, basename, seed, prompt, extension, info=None, short_filename=False):
     prompt = sanitize_filename_part(prompt)
     if short_filename:
@@ -149,7 +151,13 @@ def save_image(image, path, basename, seed, prompt, extension, short_filename=Fa
     else:
         filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
-    image.save(os.path.join(path, filename), quality=opt.jpeg_quality)
+    if extension == 'png' and not opt.disable_pnginfo:
+        pnginfo = PngImagePlugin.PngInfo()
+        pnginfo.add_text("parameters", info)
+    else:
+        pnginfo = None
+    image.save(os.path.join(path, filename), quality=opt.jpeg_quality, pnginfo=pnginfo)
 def load_GFPGAN():
@@ -373,6 +381,11 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
     all_prompts = batch_size * n_iter * [prompt]
     all_seeds = [seed + x for x in range(len(all_prompts))]
+    info = f"""
+{prompt}
+Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
+""".strip() + "".join(["\n\n" + x for x in comments])
     precision_scope = autocast if opt.precision == "autocast" else nullcontext
     output_images = []
     with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
@@ -407,7 +420,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
                     x_sample = restored_img
                 image = Image.fromarray(x_sample)
-                save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opt.save_format)
+                save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opt.save_format, info=info)
                 output_images.append(image)
                 base_count += 1
@@ -426,16 +439,9 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
         output_images.insert(0, grid)
-        save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opt.grid_format, short_filename=not opt.grid_extended_filename)
+        save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
         grid_count += 1
-    info = f"""
-{prompt}
-Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
-""".strip()
-    for comment in comments:
-        info += "\n\n" + comment
     torch_gc()
     return output_images, seed, info
@@ -619,7 +625,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
         grid_count = len(os.listdir(outpath)) - 1
         grid = image_grid(history, batch_size, force_n_rows=1)
-        save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opt.grid_format, short_filename=not opt.grid_extended_filename)
+        save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename)
         output_images = history
         seed = initial_seed
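
For reference, a minimal sketch (not part of this commit) of reading the embedded "parameters" chunk back with Pillow; the default path "output.png" is only a placeholder:

import sys

from PIL import Image

# Read back the "parameters" text chunk that save_image() embeds into PNGs.
path = sys.argv[1] if len(sys.argv) > 1 else "output.png"

with Image.open(path) as im:
    # Pillow exposes PNG tEXt/iTXt chunks through the .text mapping
    # (they also show up in im.info); non-PNG images have no .text attribute.
    params = getattr(im, "text", {}).get("parameters")

print(params if params is not None else "no parameters chunk found")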