2022-09-21 21:06:37 +08:00
import os
import sys
import traceback
2022-09-30 06:46:23 +08:00
from basicsr . utils . download_util import load_file_from_url
2022-09-21 21:06:37 +08:00
2022-09-30 06:46:23 +08:00
from modules . upscaler import Upscaler , UpscalerData
2022-12-03 23:06:33 +08:00
from ldsr_model_arch import LDSR
from modules import shared , script_callbacks
2022-12-04 21:42:19 +08:00
import sd_hijack_autoencoder , sd_hijack_ddpm_v1
2022-09-21 21:06:37 +08:00
2022-09-30 06:46:23 +08:00
class UpscalerLDSR(Upscaler):
    """Upscaler that runs the LDSR model (built via ``ldsr_model_arch.LDSR``).

    The model weights and their ``project.yaml`` config live in
    ``self.model_path`` and are fetched from ``model_url`` / ``yaml_url``
    when not already present.
    """

    # A valid project.yaml is small; one at or above this size (10 MiB) is
    # treated as invalid and deleted so that it gets re-downloaded.
    _MAX_YAML_SIZE = 10 * 1024 * 1024

    def __init__(self, user_path):
        # name/user_path are set before super().__init__(), which presumably
        # derives model_path from them -- keep this order (TODO confirm).
        self.name = "LDSR"
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
        super().__init__()
        scaler_data = UpscalerData("LDSR", None, self)
        self.scalers = [scaler_data]

    def load_model(self, path: str):
        """Make sure the LDSR weights and config are on disk, then build the model.

        Args:
            path: Not used by this implementation; files are always resolved
                under ``self.model_path``. Kept for interface compatibility.

        Returns:
            An ``LDSR`` instance, or ``None`` if constructing it raised (the
            traceback is printed to stderr).
        """
        self._migrate_legacy_files()
        model = self._resolve_model_file()
        yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
                                  file_name="project.yaml", progress=True)
        try:
            return LDSR(model, yaml)
        except Exception:
            print("Error importing LDSR:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            return None

    def _migrate_legacy_files(self):
        """Delete an oversized project.yaml and rename a legacy model.pth to model.ckpt."""
        yaml_path = os.path.join(self.model_path, "project.yaml")
        if os.path.exists(yaml_path) and os.stat(yaml_path).st_size >= self._MAX_YAML_SIZE:
            print("Removing invalid LDSR YAML file.")
            os.remove(yaml_path)

        old_model_path = os.path.join(self.model_path, "model.pth")
        new_model_path = os.path.join(self.model_path, "model.ckpt")
        # os.rename raises FileExistsError on Windows (and silently overwrites
        # on POSIX) when the target exists; in that case keep the existing
        # model.ckpt rather than crashing every upscale.
        if os.path.exists(old_model_path) and not os.path.exists(new_model_path):
            print("Renaming model from model.pth to model.ckpt")
            os.rename(old_model_path, new_model_path)

    def _resolve_model_file(self):
        """Return the local model.safetensors if present, else the path to model.ckpt.

        basicsr's ``load_file_from_url`` downloads ``model.ckpt`` only when it
        is not already in ``self.model_path``.
        """
        safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
        if os.path.exists(safetensors_model_path):
            return safetensors_model_path
        return load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                  file_name="model.ckpt", progress=True)

    def do_upscale(self, img, path):
        """Upscale ``img`` with LDSR, or return it unchanged if the model failed to load.

        The number of diffusion steps comes from the ``ldsr_steps`` setting; the
        factor is ``self.scale`` (presumably set by the base ``Upscaler``).
        """
        ldsr = self.load_model(path)
        if ldsr is None:
            print("NO LDSR!")
            return img
        ddim_steps = shared.opts.ldsr_steps
        return ldsr.super_resolution(img, ddim_steps, self.scale)
2022-12-03 23:06:33 +08:00
def on_ui_settings():
    """Add the LDSR options (step count, in-memory caching) to the "Upscaling" settings section."""
    import gradio as gr

    upscaling_section = ('upscaling', "Upscaling")

    steps_option = shared.OptionInfo(
        100,
        "LDSR processing steps. Lower = faster",
        gr.Slider,
        {"minimum": 1, "maximum": 200, "step": 1},
        section=upscaling_section,
    )
    shared.opts.add_option("ldsr_steps", steps_option)

    cache_option = shared.OptionInfo(
        False,
        "Cache LDSR model in memory",
        gr.Checkbox,
        {"interactive": True},
        section=upscaling_section,
    )
    shared.opts.add_option("ldsr_cached", cache_option)
2022-12-03 23:06:33 +08:00
# Register on_ui_settings with the web UI's settings callback hook so the LDSR
# options are added (presumably invoked when the settings UI is built).
script_callbacks.on_ui_settings(on_ui_settings)