2022-10-01 04:28:37 +08:00
import sys
import PIL . Image
2023-04-19 15:35:50 +08:00
2022-10-01 04:28:37 +08:00
import modules . upscaler
2023-12-31 22:11:18 +08:00
from modules import devices , errors , modelloader , script_callbacks , shared , upscaler_utils
2022-10-01 04:28:37 +08:00
class UpscalerScuNET(modules.upscaler.Upscaler):
    """Upscaler entry that exposes the ScuNET GAN and ScuNET PSNR models.

    The models are run with ``scale=1`` (see ``do_upscale``): ScuNET is used
    here as a denoiser, not as a resolution-increasing upscaler.
    """

    def __init__(self, dirname):
        """Discover ScuNET checkpoints and register one scaler per model.

        Args:
            dirname: Model directory supplied by the caller; stored as
                ``self.user_path`` before the base initializer runs.
        """
        self.name = "ScuNET"
        self.model_name = "ScuNET GAN"
        self.model_name2 = "ScuNET PSNR"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
        self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
        self.user_path = dirname
        super().__init__()

        model_paths = self.find_models(ext_filter=[".pth"])
        scalers = []
        # The PSNR model is appended as a downloadable entry unless one of the
        # discovered models already represents it.
        add_model2 = True
        for file in model_paths:
            if file.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            if name == self.model_name2 or file == self.model_url2:
                add_model2 = False

            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                # One bad entry must not prevent the remaining models from registering.
                errors.report(f"Error loading ScuNET model: {file}", exc_info=True)

        if add_model2:
            scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
            scalers.append(scaler_data2)
        self.scalers = scalers

    def do_upscale(self, img: PIL.Image.Image, selected_file):
        """Run the selected ScuNET model over ``img``.

        Args:
            img: Image to process.
            selected_file: Local path or URL of the model to use.

        Returns:
            The processed image, or ``img`` unchanged when the model cannot
            be loaded (the failure is reported on stderr).
        """
        devices.torch_gc()

        try:
            model = self.load_model(selected_file)
        except Exception as e:
            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
            return img

        img = upscaler_utils.upscale_2(
            img,
            model,
            tile_size=shared.opts.SCUNET_tile,
            tile_overlap=shared.opts.SCUNET_tile_overlap,
            scale=1,  # ScuNET is a denoising model, not an upscaler
            desc='ScuNET',
        )
        devices.torch_gc()
        return img

    def load_model(self, path: str):
        """Load a ScuNET model from a local file or a download URL.

        Args:
            path: Local checkpoint path, or an ``http``/``https`` URL that is
                downloaded into ``self.model_download_path`` first.

        Returns:
            Whatever ``modelloader.load_spandrel_model`` returns for the
            resolved file.
        """
        device = devices.get_device_for('scunet')
        if path.startswith("http"):
            if path == self.model_url:
                # Same file name the default model has always been saved under.
                # NOTE(review): presumably lets an existing download be reused - confirm.
                file_name = f"{self.name}.pth"
            else:
                # Derive the name from the URL so that different models (e.g. the
                # PSNR variant) never share a file with the default GAN model.
                file_name = path.split("?", 1)[0].rsplit("/", 1)[-1]
            # Download the URL that was actually requested; previously
            # self.model_url was fetched regardless of `path`.
            filename = modelloader.load_file_from_url(path, model_dir=self.model_download_path, file_name=file_name)
        else:
            filename = path
        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
2023-05-14 16:04:21 +08:00
def on_ui_settings():
    """Add the ScuNET tiling options to the "Upscaling" settings section.

    The option keys must match the attribute names read from ``shared.opts``
    when upscaling (``SCUNET_tile`` and ``SCUNET_tile_overlap``).
    """
    # Imported here rather than at module level, as in the original code.
    import gradio as gr

    shared.opts.add_option(
        "SCUNET_tile",
        shared.OptionInfo(
            256,
            "Tile size for SCUNET upscalers.",
            gr.Slider,
            {"minimum": 0, "maximum": 512, "step": 16},
            section=('upscaling', "Upscaling"),
        ).info("0 = no tiling"),
    )
    shared.opts.add_option(
        "SCUNET_tile_overlap",
        shared.OptionInfo(
            8,
            "Tile overlap for SCUNET upscalers.",
            gr.Slider,
            {"minimum": 0, "maximum": 64, "step": 1},
            section=('upscaling', "Upscaling"),
        ).info("Low values = visible seam"),
    )
# Hand the settings hook defined above to script_callbacks for registration
# (presumably invoked when the settings UI is assembled - confirm).
script_callbacks . on_ui_settings ( on_ui_settings )