add images.read to automatically fix all jpeg/png weirdness

This commit is contained in:
AUTOMATIC1111 2024-03-04 19:14:53 +03:00
parent 5625ce1b1a
commit 09b5ce68a9
6 changed files with 41 additions and 72 deletions

View File

@ -85,8 +85,7 @@ def decode_base64_to_image(encoding):
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
response = requests.get(encoding, timeout=30, headers=headers) response = requests.get(encoding, timeout=30, headers=headers)
try: try:
image = Image.open(BytesIO(response.content)) image = images.read(BytesIO(response.content))
image = images.apply_exif_orientation(image)
return image return image
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Invalid image url") from e raise HTTPException(status_code=500, detail="Invalid image url") from e
@ -94,8 +93,7 @@ def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"): if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1] encoding = encoding.split(";")[1].split(",")[1]
try: try:
image = Image.open(BytesIO(base64.b64decode(encoding))) image = images.read(BytesIO(base64.b64decode(encoding)))
image = images.apply_exif_orientation(image)
return image return image
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail="Invalid encoded image") from e raise HTTPException(status_code=500, detail="Invalid encoded image") from e

View File

@ -12,7 +12,7 @@ import re
import numpy as np import numpy as np
import piexif import piexif
import piexif.helper import piexif.helper
from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
import string import string
import json import json
import hashlib import hashlib
@ -551,12 +551,6 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
else: else:
pnginfo_data = None pnginfo_data = None
# Error handling for unsupported transparency in RGB mode
if (image.mode == "RGB" and
"transparency" in image.info and
isinstance(image.info["transparency"], bytes)):
del image.info["transparency"]
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data) image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
elif extension.lower() in (".jpg", ".jpeg", ".webp"): elif extension.lower() in (".jpg", ".jpeg", ".webp"):
@ -779,7 +773,7 @@ def image_data(data):
import gradio as gr import gradio as gr
try: try:
image = Image.open(io.BytesIO(data)) image = read(io.BytesIO(data))
textinfo, _ = read_info_from_image(image) textinfo, _ = read_info_from_image(image)
return textinfo, None return textinfo, None
except Exception: except Exception:
@ -807,51 +801,29 @@ def flatten(img, bgcolor):
return img.convert('RGB') return img.convert('RGB')
# https://www.exiv2.org/tags.html def read(fp, **kwargs):
_EXIF_ORIENT = 274 # exif 'Orientation' tag image = Image.open(fp, **kwargs)
image = fix_image(image)
def apply_exif_orientation(image):
"""
Applies the exif orientation correctly.
This code exists per the bug:
https://github.com/python-pillow/Pillow/issues/3973
with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
various methods, especially `tobytes`
Function based on:
https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
Args:
image (PIL.Image): a PIL image
Returns:
(PIL.Image): the PIL image with exif orientation applied, if applicable
"""
if not hasattr(image, "getexif"):
return image return image
def fix_image(image: Image.Image):
if image is None:
return None
try: try:
exif = image.getexif() image = ImageOps.exif_transpose(image)
except Exception: # https://github.com/facebookresearch/detectron2/issues/1885 image = fix_png_transparency(image)
exif = None except Exception:
pass
if exif is None:
return image return image
orientation = exif.get(_EXIF_ORIENT)
method = { def fix_png_transparency(image: Image.Image):
2: Image.FLIP_LEFT_RIGHT, if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
3: Image.ROTATE_180, return image
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE, image = image.convert("RGBA")
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
if method is not None:
return image.transpose(method)
return image return image

View File

@ -6,7 +6,7 @@ import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr import gradio as gr
from modules import images as imgutil from modules import images
from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
@ -21,7 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
output_dir = output_dir.strip() output_dir = output_dir.strip()
processing.fix_seed(p) processing.fix_seed(p)
images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff"))) batch_images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
is_inpaint_batch = False is_inpaint_batch = False
if inpaint_mask_dir: if inpaint_mask_dir:
@ -31,9 +31,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
if is_inpaint_batch: if is_inpaint_batch:
print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
state.job_count = len(images) * p.n_iter state.job_count = len(batch_images) * p.n_iter
# extract "default" params to use in case getting png info fails # extract "default" params to use in case getting png info fails
prompt = p.prompt prompt = p.prompt
@ -46,8 +46,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None)) sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
batch_results = None batch_results = None
discard_further_results = False discard_further_results = False
for i, image in enumerate(images): for i, image in enumerate(batch_images):
state.job = f"{i+1} out of {len(images)}" state.job = f"{i+1} out of {len(batch_images)}"
if state.skipped: if state.skipped:
state.skipped = False state.skipped = False
@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
break break
try: try:
img = Image.open(image) img = images.read(image)
except UnidentifiedImageError as e: except UnidentifiedImageError as e:
print(e) print(e)
continue continue
@ -86,7 +86,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
# otherwise user has many masks with the same name but different extensions # otherwise user has many masks with the same name but different extensions
mask_image_path = masks_found[0] mask_image_path = masks_found[0]
mask_image = Image.open(mask_image_path) mask_image = images.read(mask_image_path)
p.image_mask = mask_image p.image_mask = mask_image
if use_png_info: if use_png_info:
@ -94,8 +94,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
info_img = img info_img = img
if png_info_dir: if png_info_dir:
info_img_path = os.path.join(png_info_dir, os.path.basename(image)) info_img_path = os.path.join(png_info_dir, os.path.basename(image))
info_img = Image.open(info_img_path) info_img = images.read(info_img_path)
geninfo, _ = imgutil.read_info_from_image(info_img) geninfo, _ = images.read_info_from_image(info_img)
parsed_parameters = parse_generation_parameters(geninfo) parsed_parameters = parse_generation_parameters(geninfo)
parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})} parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
except Exception: except Exception:
@ -175,9 +175,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
image = None image = None
mask = None mask = None
# Use the EXIF orientation of photos taken by smartphones. image = images.fix_image(image)
if image is not None: mask = images.fix_image(mask)
image = ImageOps.exif_transpose(image)
if selected_scale_tab == 1 and not is_batch: if selected_scale_tab == 1 and not is_batch:
assert image, "Can't scale by because no image is selected" assert image, "Can't scale by because no image is selected"

View File

@ -8,7 +8,7 @@ import sys
import gradio as gr import gradio as gr
from modules.paths import data_path from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images
from PIL import Image from PIL import Image
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
@ -83,7 +83,7 @@ def image_from_url_text(filedata):
assert is_in_right_dir, 'trying to open image file outside of allowed directories' assert is_in_right_dir, 'trying to open image file outside of allowed directories'
filename = filename.rsplit('?', 1)[0] filename = filename.rsplit('?', 1)[0]
return Image.open(filename) return images.read(filename)
if type(filedata) == list: if type(filedata) == list:
if len(filedata) == 0: if len(filedata) == 0:
@ -95,7 +95,7 @@ def image_from_url_text(filedata):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
filedata = base64.decodebytes(filedata.encode('utf-8')) filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata)) image = images.read(io.BytesIO(filedata))
return image return image

View File

@ -17,10 +17,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode == 1: if extras_mode == 1:
for img in image_folder: for img in image_folder:
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
image = img image = images.fix_image(img)
fn = '' fn = ''
else: else:
image = Image.open(os.path.abspath(img.name)) image = images.read(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0] fn = os.path.splitext(img.orig_name)[0]
yield image, fn yield image, fn
elif extras_mode == 2: elif extras_mode == 2:
@ -56,7 +56,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if isinstance(image_placeholder, str): if isinstance(image_placeholder, str):
try: try:
image_data = Image.open(image_placeholder) image_data = images.read(image_placeholder)
except Exception: except Exception:
continue continue
else: else:

View File

@ -10,7 +10,7 @@ from random import shuffle, choices
import random import random
import tqdm import tqdm
from modules import devices, shared from modules import devices, shared, images
import re import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
@ -61,7 +61,7 @@ class PersonalizedBase(Dataset):
if shared.state.interrupted: if shared.state.interrupted:
raise Exception("interrupted") raise Exception("interrupted")
try: try:
image = Image.open(path) image = images.read(path)
#Currently does not work for single color transparency #Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that #We would need to read image.info['transparency'] for that
if use_weight and 'A' in image.getbands(): if use_weight and 'A' in image.getbands():