2022-09-22 17:11:48 +08:00
import math
import modules . scripts as scripts
import gradio as gr
from PIL import Image
from modules import processing , shared , sd_samplers , images , devices
from modules . processing import Processed
from modules . shared import opts , cmd_opts , state
class Script(scripts.Script):
    """Script that upscales the init image and re-renders it tile by tile.

    The source image is optionally enlarged with one of ``shared.sd_upscalers``,
    cut into overlapping tiles, each tile is sent through
    ``processing.process_images`` and the results are stitched back together.
    """

    def title(self):
        """Return the name under which this script is listed."""
        return "SD upscale"

    def show(self, is_img2img):
        """Return ``is_img2img`` unchanged: the script is shown only when it is truthy."""
        return is_img2img

    def ui(self, is_img2img):
        """Build the Gradio controls for this script.

        Returns the components as ``[info, overlap, upscaler_index,
        scale_factor]``; ``run`` declares its parameters after ``p`` in this
        same order.
        """
        # Prefix keeps element ids distinct between the img2img and txt2img tabs.
        elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_sd_upscale_'

        info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
        overlap = gr.Slider(
            minimum=0,
            maximum=256,
            step=16,
            label='Tile overlap',
            value=64,
            elem_id=elem_prefix + "overlap",
        )
        scale_factor = gr.Slider(
            minimum=1.0,
            maximum=4.0,
            step=0.05,
            label='Scale Factor',
            value=2.0,
            elem_id=elem_prefix + "scale_factor",
        )
        # type="index" makes the component hand over the position of the chosen
        # entry rather than its label.
        upscaler_index = gr.Radio(
            label='Upscaler',
            choices=[x.name for x in shared.sd_upscalers],
            value=shared.sd_upscalers[0].name,
            type="index",
            elem_id=elem_prefix + "upscaler_index",
        )

        return [info, overlap, upscaler_index, scale_factor]

    def run(self, p, _, overlap, upscaler_index, scale_factor):
        """Upscale ``p.init_images[0]`` and re-process it in tiles.

        Args:
            p: processing object; its ``seed``, ``batch_size``, ``n_iter``,
                ``init_images`` and ``do_not_save_*`` attributes are modified
                here and are not restored afterwards.
            _: unused; presumably the value of the ``info`` HTML component
                that ``ui`` returns first.
            overlap: tile overlap passed to ``images.split_grid``.
            upscaler_index: index into ``shared.sd_upscalers``; an upscaler
                name (matched case-insensitively) is accepted as well.
            scale_factor: factor handed to the upscaler's ``upscale`` call.

        Returns:
            A ``Processed`` holding one stitched image per upscale pass
            (``p.n_iter`` passes as requested on entry).

        Raises:
            ValueError: if ``upscaler_index`` is a string that matches no
                entry of ``shared.sd_upscalers``.
        """
        # Resolve a name to its index before touching ``p`` so that bad input
        # fails without side effects.
        if isinstance(upscaler_index, str):
            wanted = upscaler_index.lower()
            known_names = [x.name.lower() for x in shared.sd_upscalers]
            if wanted not in known_names:
                raise ValueError(f"Unknown upscaler: {upscaler_index!r}")
            upscaler_index = known_names.index(wanted)

        processing.fix_seed(p)
        upscaler = shared.sd_upscalers[upscaler_index]

        p.extra_generation_params["SD upscale overlap"] = overlap
        p.extra_generation_params["SD upscale upscaler"] = upscaler.name

        initial_info = None
        seed = p.seed

        init_img = p.init_images[0]
        init_img = images.flatten(init_img, opts.img2img_background_color)

        # The entry named "None" means: keep the image at its current size.
        if upscaler.name != "None":
            img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
        else:
            img = init_img

        devices.torch_gc()

        grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)

        # Remember the caller's batching, then take over iteration ourselves:
        # each requested iteration becomes one full upscale pass over all tiles.
        batch_size = p.batch_size
        upscale_count = p.n_iter
        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        # Each grid row is (y, h, row) and each row entry keeps its image at
        # index 2.
        work = [tiledata[2] for _y, _h, row in grid.tiles for tiledata in row]

        batch_count = math.ceil(len(work) / batch_size)
        state.job_count = batch_count * upscale_count

        tiles_per_row = len(grid.tiles[0][2]) if grid.tiles else 0
        print(f"SD upscaling will process a total of {len(work)} images tiled as {tiles_per_row}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")

        result_images = []
        for n in range(upscale_count):
            # Every pass starts from its own seed, offset from the fixed one.
            start_seed = seed + n
            p.seed = start_seed

            work_results = []
            for i in range(batch_count):
                # Reassigned on every batch, as the original code does.
                p.batch_size = batch_size
                p.init_images = work[i * batch_size:(i + 1) * batch_size]

                state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
                processed = processing.process_images(p)

                # Only the info of the very first batch is kept for the result.
                if initial_info is None:
                    initial_info = processed.info

                p.seed = processed.seed + 1
                work_results += processed.images

            # Write the processed tiles back into the grid in the same order
            # they were collected; missing results become blank tiles.
            image_index = 0
            for _y, _h, row in grid.tiles:
                for tiledata in row:
                    if image_index < len(work_results):
                        tiledata[2] = work_results[image_index]
                    else:
                        tiledata[2] = Image.new("RGB", (p.width, p.height))
                    image_index += 1

            combined_image = images.combine_grid(grid)
            result_images.append(combined_image)

            if opts.samples_save:
                images.save_image(combined_image, p.outpath_samples, "", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)

        processed = Processed(p, result_images, seed, initial_info)

        return processed