2022-09-13 05:44:08 +08:00
import datetime
2022-09-03 17:08:45 +08:00
import math
import os
from collections import namedtuple
import re
import numpy as np
2022-09-12 19:40:02 +08:00
import piexif
import piexif . helper
2022-09-03 17:08:45 +08:00
from PIL import Image , ImageFont , ImageDraw , PngImagePlugin
2022-09-14 00:53:42 +08:00
from fonts . ttf import Roboto
2022-09-10 13:45:16 +08:00
import string
2022-09-03 17:08:45 +08:00
2022-09-04 23:54:12 +08:00
import modules . shared
2022-09-13 01:47:46 +08:00
from modules import sd_samplers , shared
2022-09-15 09:05:00 +08:00
from modules . shared import opts , cmd_opts
2022-09-03 17:08:45 +08:00
# Newer Pillow releases expose resampling filters through the Image.Resampling enum;
# older ones only provide the module-level Image.LANCZOS constant.
if hasattr(Image, 'Resampling'):
    LANCZOS = Image.Resampling.LANCZOS
else:
    LANCZOS = Image.LANCZOS
def image_grid(imgs, batch_size=1, rows=None):
    """Paste equally sized images into a single grid image, filling it row by row.

    Args:
        imgs: non-empty sequence of images; the first image's size is used for every cell.
        batch_size: row count used when ``opts.n_rows`` is 0.
        rows: explicit row count; when None it is derived from ``opts.n_rows``
            (positive: used as-is, 0: ``batch_size``, negative: round(sqrt(len(imgs)))).

    Returns:
        A new RGB image with a black background.
    """
    if rows is None:
        configured = opts.n_rows
        if configured > 0:
            rows = configured
        elif configured == 0:
            rows = batch_size
        else:
            # roughly square layout
            rows = round(math.sqrt(len(imgs)))

    cols = math.ceil(len(imgs) / rows)

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h), color='black')

    for index, img in enumerate(imgs):
        row_idx, col_idx = divmod(index, cols)
        canvas.paste(img, box=(col_idx * cell_w, row_idx * cell_h))

    return canvas
# Tiling produced by split_grid: ``tiles`` holds [y, tile_h, row] entries, each ``row``
# being a list of [x, tile_w, tile_image]; the other fields record the tiling geometry.
Grid = namedtuple("Grid", "tiles tile_w tile_h image_w image_h overlap")
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut an image into overlapping tiles that together cover the whole image.

    Tile origins are spread evenly so the first tile starts at the left/top edge and
    the last one ends at the right/bottom edge. Neighbouring tiles share roughly
    ``overlap`` pixels or more.

    Args:
        image: object with ``width``, ``height`` and a PIL-style ``crop(box)`` method.
        tile_w, tile_h: tile size in pixels.
        overlap: pixels shared by neighbouring tiles; must be smaller than both tile sides.

    Returns:
        Grid whose ``tiles`` is a list of ``[y, tile_h, row]`` entries, each ``row`` being
        a list of ``[x, tile_w, tile]``. At least one tile is always produced.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``tile_w`` and ``tile_h``.
    """
    # Without this check the step sizes below are zero (ZeroDivisionError) or negative
    # (a negative tile count, i.e. a silently empty grid).
    if overlap >= tile_w or overlap >= tile_h:
        raise ValueError(f"overlap ({overlap}) must be smaller than the tile size ({tile_w}x{tile_h})")

    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    # max(1, ...) keeps one tile even when the image is no larger than the overlap;
    # otherwise the count would be zero or negative and the grid empty.
    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))

    # Even spacing of tile origins between 0 and (image size - tile size).
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        # Clamp the last tile flush with the bottom edge (negative if the image is
        # shorter than a tile).
        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid
def combine_grid(grid):
    """Reassemble the image described by a Grid, blending tile overlaps.

    Overlapping strips are pasted through linear ramp masks (0 up to just below 255
    across ``grid.overlap`` pixels), so each later tile fades in over the previous one:
    horizontally inside a row, then vertically between rows.

    Returns:
        A new RGB image of size (grid.image_w, grid.image_h).
    """
    def ramp_to_mask(values):
        # scale 0..overlap-1 to the 0..255 range of an 8-bit greyscale mask
        scaled = (values * 255 / grid.overlap).astype(np.uint8)
        return Image.fromarray(scaled, 'L')

    ramp = np.arange(grid.overlap, dtype=np.float32)
    # left-to-right ramp, one row repeated for every pixel row of a tile
    blend_mask_x = ramp_to_mask(ramp.reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    # top-to-bottom ramp, one column repeated across the full image width
    blend_mask_y = ramp_to_mask(ramp.reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    output = Image.new("RGB", (grid.image_w, grid.image_h))

    for tile_y, row_h, tiles_in_row in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))

        for tile_x, tile_w, tile in tiles_in_row:
            if tile_x == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, grid.overlap, row_h)), (tile_x, 0), mask=blend_mask_x)
                strip.paste(tile.crop((grid.overlap, 0, tile_w, row_h)), (tile_x + grid.overlap, 0))

        if tile_y == 0:
            output.paste(strip, (0, 0))
        else:
            output.paste(strip.crop((0, 0, strip.width, grid.overlap)), (0, tile_y), mask=blend_mask_y)
            output.paste(strip.crop((0, grid.overlap, strip.width, row_h)), (0, tile_y + grid.overlap))

    return output
class GridAnnotation:
    """One caption line drawn next to an image grid.

    Attributes:
        text: caption text.
        is_active: when False the text is drawn greyed out and struck through.
        size: (width, height) of the rendered text; None until it has been measured.
    """

    def __init__(self, text='', is_active=True):
        self.text, self.is_active, self.size = text, is_active, None
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    """Add column captions above and row captions to the left of an image grid.

    Args:
        im: the grid image; ``im.width // width`` columns and ``im.height // height`` rows.
        width, height: size of one grid cell in pixels.
        hor_texts: one list of GridAnnotation per column, drawn above the grid.
        ver_texts: one list of GridAnnotation per row, drawn left of the grid.

    The inner annotation lists are rewritten in place: every entry is replaced by its
    word-wrapped lines, and each resulting line gets its measured ``size``.

    Returns:
        A new RGB image: ``im`` on a white canvas enlarged to make room for the captions.

    Raises:
        AssertionError: if the number of caption lists does not match the grid shape.
    """
    def wrap(drawing, text, font, line_length):
        # Greedy word wrap. A word is always placed on an empty line, even if it is too
        # wide on its own; this avoids emitting a blank leading line.
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if lines[-1] == '' or drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing, draw_x, draw_y, lines):
        # Stack lines downward from draw_y, each centered horizontally on draw_x.
        for line in lines:
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                # strike through inactive captions
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    def block_height(lines):
        # Vertical extent draw_texts covers: all line heights plus one gap between
        # consecutive lines. Used for both directions so labels are truly centered.
        return sum(line.size[1] for line in lines) + line_spacing * (len(lines) - 1)

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    try:
        fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
    except Exception:
        # fall back to the bundled font when the configured one cannot be loaded
        fnt = ImageFont.truetype(Roboto, fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    # Reserve a left margin only if at least one row caption has any text.
    has_row_text = any(line.text for lines in ver_texts for line in lines)
    pad_left = width * 3 // 4 if has_row_text else 0

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    # scratch canvas used only for text measurement
    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [block_height(lines) for lines in hor_texts]
    ver_text_heights = [block_height(lines) for lines in ver_texts]

    pad_top = max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col])

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row])

    return result
def draw_prompt_matrix(im, width, height, all_prompts):
    """Caption a prompt-matrix grid with the prompt parts that are on/off in each cell.

    ``all_prompts[0]`` is skipped (presumably the shared base prompt — confirm with caller).
    The first half (rounded up) of the remaining parts labels the columns, the rest the
    rows. Column/row index ``n`` enables part ``i`` when bit ``i`` of ``n`` is set.
    """
    variable_parts = all_prompts[1:]
    split_at = math.ceil(len(variable_parts) / 2)

    def annotate(parts):
        # one caption list per bitmask over ``parts``
        return [
            [GridAnnotation(part, is_active=bool((mask >> bit) & 1)) for bit, part in enumerate(parts)]
            for mask in range(2 ** len(parts))
        ]

    column_captions = annotate(variable_parts[:split_at])
    row_captions = annotate(variable_parts[split_at:])

    return draw_grid_annotations(im, width, height, column_captions, row_captions)
def resize_image(resize_mode, im, width, height):
    """Resize ``im`` to exactly ``width`` x ``height`` pixels.

    resize_mode:
        0 -- stretch to the target size, ignoring aspect ratio (keeps ``im``'s mode).
        1 -- keep aspect ratio and scale so the image covers the whole target; the
             overflow is cropped roughly equally on both sides.
        other -- keep aspect ratio and scale so the whole image fits, centered; the empty
             bands are filled by stretching the image's border row/column.

    Modes other than 0 return a new RGB image.
    """
    if resize_mode == 0:
        return im.resize((width, height), resample=LANCZOS)

    ratio = width / height
    src_ratio = im.width / im.height

    if resize_mode == 1:
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
    else:
        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

    resized = im.resize((src_w, src_h), resample=LANCZOS)
    res = Image.new("RGB", (width, height))

    # Offset of the scaled image inside the canvas (negative when it overflows, mode 1).
    left = width // 2 - src_w // 2
    top = height // 2 - src_h // 2
    res.paste(resized, box=(left, top))

    if resize_mode != 1:
        # Fill bands are sized from the actual gaps. A band can be 0 px (e.g. odd target
        # sizes), and Image.resize rejects a zero dimension, so such bands are skipped.
        # The far band may be 1 px wider than the near one; sizing it separately avoids
        # leaving an unfilled black strip.
        if ratio < src_ratio:
            bottom = top + src_h
            bottom_fill = height - bottom
            if top > 0:
                res.paste(resized.resize((width, top), box=(0, 0, width, 0)), box=(0, 0))
            if bottom_fill > 0:
                res.paste(resized.resize((width, bottom_fill), box=(0, resized.height, width, resized.height)), box=(0, bottom))
        elif ratio > src_ratio:
            right = left + src_w
            right_fill = width - right
            if left > 0:
                res.paste(resized.resize((left, height), box=(0, 0, 0, height)), box=(0, 0))
            if right_fill > 0:
                res.paste(resized.resize((right_fill, height), box=(resized.width, 0, resized.width, height)), box=(right, 0))

    return res
invalid_filename_chars = ' <>: " / \\ |?* \n '
2022-09-10 13:45:16 +08:00
re_nonletters = re . compile ( r ' [ \ s ' + string . punctuation + ' ]+ ' )
2022-09-03 17:08:45 +08:00
2022-09-12 20:41:30 +08:00
def sanitize_filename_part(text, replace_spaces=True):
    """Make ``text`` usable as part of a filename.

    Every character from ``invalid_filename_chars`` becomes '_', as do spaces unless
    ``replace_spaces`` is False. The result is truncated to 128 characters.
    """
    table = {ord(ch): '_' for ch in invalid_filename_chars}
    if replace_spaces:
        table[ord(' ')] = '_'

    return text.translate(table)[:128]
2022-09-03 17:08:45 +08:00
2022-09-12 20:41:30 +08:00
2022-09-13 05:44:08 +08:00
def apply_filename_pattern(x, p, seed, prompt):
    """Expand ``[tag]`` placeholders in a file or directory name pattern.

    Tags:
        [seed]                      -- ``seed`` (only when not None)
        [prompt], [prompt_spaces]   -- sanitized prompt, spaces replaced / kept (prompt not None)
        [prompt_words]              -- first 8 words of the prompt, or "empty"
        [steps] [cfg] [width] [height] [sampler] [model_hash]  -- from ``p`` (p not None)
        [date]                      -- today's date, ISO format

    Tags whose data is unavailable, and unknown tags, are left unchanged. Substitution is
    one pass over the pattern, so text inserted for a tag (e.g. a prompt that itself
    contains "[steps]") is never expanded again.

    When ``cmd_opts.hide_ui_dir_config`` is set, leading slashes and ".." path segments
    are stripped from the result.
    """
    replacements = {"date": datetime.date.today().isoformat()}

    if seed is not None:
        replacements["seed"] = str(seed)

    if prompt is not None:
        replacements["prompt"] = sanitize_filename_part(prompt)[:128]
        replacements["prompt_spaces"] = sanitize_filename_part(prompt, replace_spaces=False)[:128]

    if "[prompt_words]" in x:
        words = [word for word in re_nonletters.split(prompt or "") if len(word) > 0]
        if len(words) == 0:
            words = ["empty"]
        replacements["prompt_words"] = " ".join(words[0:8]).strip()

    if p is not None:
        replacements["steps"] = str(p.steps)
        replacements["cfg"] = str(p.cfg_scale)
        replacements["width"] = str(p.width)
        replacements["height"] = str(p.height)
        replacements["sampler"] = sd_samplers.samplers[p.sampler_index].name
        replacements["model_hash"] = shared.sd_model.sd_model_hash

    x = re.sub(r"\[(\w+)\]", lambda m: replacements.get(m.group(1), m.group(0)), x)

    if cmd_opts.hide_ui_dir_config:
        # keep generated paths inside the configured output directory
        x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)

    return x
2022-09-14 20:40:16 +08:00
def get_next_sequence_number(path, basename):
    """Return the number to use for the next image saved in directory ``path``.

    Only files whose names start with ``basename + "-"`` are considered (every file when
    ``basename`` is ''). After that prefix, the text up to the first '-' (extension
    removed) is parsed as an int, and non-numeric names are ignored. The result is one more
    than the largest number found, or 0 if there is none.
    """
    prefix = basename + "-" if basename != '' else ''

    highest = -1
    for filename in os.listdir(path):
        if not filename.startswith(prefix):
            continue

        stem = os.path.splitext(filename[len(prefix):])[0]
        leading = stem.split('-')[0]
        try:
            number = int(leading)
        except ValueError:
            continue

        if number > highest:
            highest = number

    return highest + 1
2022-09-13 05:44:08 +08:00
2022-09-14 04:28:03 +08:00
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None):
    """Save ``image`` under ``path`` using a numbered, pattern-decorated filename.

    Args:
        image: PIL image to save (the format is inferred from the filename extension).
        path: output directory; created if missing.
        basename: filename prefix; '' gives names like ``00042-...``, otherwise
            ``basename-0042-...``.
        seed, prompt: values for the filename pattern. If either is None, or
            ``short_filename`` is set, no pattern decoration is appended.
        extension: file extension without the dot.
        info: generation-parameters text. It is stored as a PNG text chunk or JPEG/WebP
            EXIF UserComment when ``opts.enable_pnginfo`` is set, and written to a ``.txt``
            sidecar when ``opts.save_txt`` is set.
        short_filename: omit the pattern decoration.
        no_prompt: never save a non-grid image into a pattern-named subdirectory.
        grid: the image is a grid; ``opts.grid_save_to_dirs`` controls subdirectories.
        pnginfo_section_name: PNG text key under which ``info`` is stored.
        p: object supplying steps, cfg_scale, width, height and sampler_index for
            pattern tags; may be None.
        existing_info: extra key/value pairs copied into the PNG text chunks.
    """
    if short_filename or prompt is None or seed is None:
        file_decoration = ""
    elif opts.save_to_dirs:
        file_decoration = opts.samples_filename_pattern or "[seed]"
    else:
        file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

    if file_decoration != "":
        file_decoration = "-" + file_decoration.lower()

    file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt)

    # Case-insensitive, matching the JPEG/WebP check below, so "PNG" keeps its metadata too.
    if extension.lower() == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()

        if existing_info is not None:
            for k, v in existing_info.items():
                pnginfo.add_text(k, str(v))

        pnginfo.add_text(pnginfo_section_name, info)
    else:
        pnginfo = None

    save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    basecount = get_next_sequence_number(path, basename)
    fullfn = "a.png"
    fullfn_without_extension = "a"
    # Probe successive numbers until an unused filename is found (bounded at 500 tries).
    for i in range(500):
        fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
        fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
        fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
        if not os.path.exists(fullfn):
            break

    def exif_bytes():
        # EXIF block carrying ``info`` as a Unicode UserComment
        return piexif.dump({
            "Exif": {
                piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
            },
        })

    if extension.lower() in ("jpg", "jpeg", "webp"):
        image.save(fullfn, quality=opts.jpeg_quality)
        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn)
    else:
        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)

    # Optional extra JPEG copy limited to 4000 px per side, written when the image is
    # oversized or the saved file exceeds 4 MiB.
    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
        elif oversize:
            image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)

        image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")

    if opts.save_txt and info is not None:
        with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
            file.write(info + "\n")
2022-09-04 23:54:12 +08:00
class Upscaler:
    """Base upscaler: applies ``do_upscale`` up to three times, then fixes the size with Lanczos.

    Subclasses override ``do_upscale``; the base implementation returns the image
    unchanged, so on its own this class is a plain Lanczos resize.
    """
    name = "Lanczos"

    def do_upscale(self, img):
        return img

    def upscale(self, img, w, h):
        remaining_passes = 3
        while remaining_passes > 0 and (img.width < w or img.height < h):
            img = self.do_upscale(img)
            remaining_passes -= 1

        # Snap to the exact requested size (scales down if the passes overshot).
        if (img.width, img.height) != (w, h):
            img = img.resize((int(w), int(h)), resample=LANCZOS)

        return img
class UpscalerNone(Upscaler):
    """Pass-through upscaler, listed under the name "None"."""
    name = "None"

    def upscale(self, img, w, h):
        # The requested size is intentionally ignored; the input comes back unchanged.
        return img
# Register the built-in upscalers in this order: "None" first, then "Lanczos".
for _builtin_upscaler in (UpscalerNone(), Upscaler()):
    modules.shared.sd_upscalers.append(_builtin_upscaler)
del _builtin_upscaler