2023-12-25 20:43:51 +08:00
import logging
2023-05-29 15:38:51 +08:00
import sys
2022-09-26 22:29:50 +08:00
2024-01-04 04:38:13 +08:00
import torch
2022-09-26 22:29:50 +08:00
from PIL import Image
2023-12-31 22:11:18 +08:00
from modules import devices , modelloader , script_callbacks , shared , upscaler_utils
2022-09-30 06:46:23 +08:00
from modules . upscaler import Upscaler , UpscalerData
2022-09-26 22:29:22 +08:00
2023-06-13 18:00:05 +08:00
SWINIR_MODEL_URL = " https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth "
2022-09-26 22:29:22 +08:00
2023-12-25 20:43:51 +08:00
logger = logging . getLogger ( __name__ )
2022-12-03 23:06:33 +08:00
2022-09-30 06:46:23 +08:00
class UpscalerSwinIR(Upscaler):
    """Upscaler that runs SwinIR checkpoints loaded through spandrel.

    Every entry returned by ``find_models`` becomes an ``UpscalerData`` in
    ``self.scalers``. The most recently loaded model is kept in memory and
    reused while its cache key (see ``_current_config``) stays the same.
    """

    def __init__(self, dirname):
        # Last loaded model, reused across runs so that, with
        # SWIN_torch_compile on, it is not re-compiled on every upscale.
        self._cached_model = None
        # Cache key the cached model was built for; a different key forces a
        # reload in do_upscale.
        self._cached_model_config = None

        self.name = "SwinIR"
        self.model_url = SWINIR_MODEL_URL
        self.model_name = "SwinIR 4x"
        self.user_path = dirname
        super().__init__()

        scalers = []
        for model in self.find_models(ext_filter=[".pt", ".pth"]):
            # URL entries (downloaded on first use by load_model) get the
            # generic display name; local files go through friendly_name.
            if model.startswith("http"):
                name = self.model_name
            else:
                name = modelloader.friendly_name(model)
            scalers.append(UpscalerData(name, model, self))
        self.scalers = scalers

    def _current_config(self, model_file):
        """Return the cache key for ``model_file`` under the current settings.

        SWIN_torch_compile is part of the key because load_model only
        compiles at load time; without it, toggling the option would keep
        reusing a model built under the previous setting.
        """
        return (
            model_file,
            shared.opts.SWIN_tile,
            getattr(shared.opts, 'SWIN_torch_compile', False),
        )

    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:
        """Upscale ``img`` with the SwinIR model at ``model_file``.

        Reuses the cached model when the cache key matches, otherwise loads
        and caches a fresh one. If loading fails, the error is printed to
        stderr and ``img`` is returned unchanged.
        """
        current_config = self._current_config(model_file)

        if self._cached_model_config == current_config:
            model = self._cached_model
        else:
            # Release the previous model before loading the next one so the
            # two are never held at the same time.
            self._cached_model = None
            self._cached_model_config = None
            try:
                model = self.load_model(model_file)
            except Exception as e:
                print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
                return img
            self._cached_model = model
            self._cached_model_config = current_config

        img = upscaler_utils.upscale_2(
            img,
            model,
            tile_size=shared.opts.SWIN_tile,
            tile_overlap=shared.opts.SWIN_tile_overlap,
            # Use the checkpoint's own scale rather than assuming 4x; fall
            # back to the previously hard-coded 4 if it is not exposed.
            scale=getattr(model, "scale", 4),
            desc="SwinIR",
        )
        devices.torch_gc()
        return img

    def load_model(self, path, scale=4):
        """Load a SwinIR checkpoint and return its spandrel model descriptor.

        ``path`` may be a local file or a URL (anything starting with
        "http"); URLs are first downloaded into ``self.model_download_path``.
        ``scale`` is unused and kept only for backward compatibility;
        do_upscale takes the scale from the returned descriptor.
        """
        if path.startswith("http"):
            filename = modelloader.load_file_from_url(
                url=path,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name.replace(' ', '_')}.pth",
            )
        else:
            filename = path

        model_descriptor = modelloader.load_spandrel_model(
            filename,
            device=self._get_device(),
            prefer_half=(devices.dtype == torch.float16),
            expected_architecture="SwinIR",
        )
        if getattr(shared.opts, 'SWIN_torch_compile', False):
            try:
                model_descriptor.model.compile()
            except Exception:
                # Best-effort: on failure the model is simply used uncompiled.
                logger.warning("Failed to compile SwinIR model, running it uncompiled", exc_info=True)
        return model_descriptor

    def _get_device(self):
        """Return the torch device configured for the 'swinir' task."""
        return devices.get_device_for('swinir')
2022-09-26 22:29:22 +08:00
2022-12-03 23:06:33 +08:00
def on_ui_settings():
    """Register the SwinIR tiling and torch.compile options with the settings UI.

    gradio is imported inside the function rather than at module level. The
    function is registered via ``script_callbacks.on_ui_settings`` right below.
    """
    import gradio as gr

    add = shared.opts.add_option
    upscaling_section = ('upscaling', "Upscaling")

    tile_slider = {"minimum": 16, "maximum": 512, "step": 16}
    add("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, tile_slider, section=upscaling_section))

    overlap_slider = {"minimum": 0, "maximum": 48, "step": 1}
    add("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, overlap_slider, section=upscaling_section))

    compile_option = shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=upscaling_section)
    add("SWIN_torch_compile", compile_option.info("Takes longer on first run"))


script_callbacks.on_ui_settings(on_ui_settings)