From 948eff4b3caa237334389a5a08adda130e2b43a5 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:36:20 +0300 Subject: [PATCH] make swinir actually useful --- swinir.py => modules/swinir.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) rename swinir.py => modules/swinir.py (77%) diff --git a/swinir.py b/modules/swinir.py similarity index 77% rename from swinir.py rename to modules/swinir.py index cb2bbe3dd..6c7f0a2db 100644 --- a/swinir.py +++ b/modules/swinir.py @@ -12,7 +12,13 @@ import modules.images from modules.shared import cmd_opts, opts, device from modules.swinir_arch import SwinIR as net precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext -def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(cmd_opts.esrgan_models_path))): +def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4): + + try: + modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin")) + except Exception: + print(f"Error loading ESRGAN model", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) if not large_model: # use 'nearest+conv' to avoid block artifacts model = net(upscale=scale, in_chans=3, img_size=64, window_size=8, @@ -26,12 +32,16 @@ def load_model(task = "realsr", large_model = True, model_path=next(os.listdir(c mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv') pretrained_model = torch.load(model_path) - model.load_state_dict(pretrained_model, strict=True) + model.load_state_dict(pretrained_model["params_ema"], strict=True) return model.half().to(device) def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4): - img = cv2.imread(img, cv2.IMREAD_COLOR).astype(np.float16) / 255. + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(device) model = load_model() with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() @@ -45,7 +55,7 @@ def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, w if output.ndim == 3: output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return output + return Image.fromarray(output, 'RGB') def inference(img, model, tile, tile_overlap, window_size, scale): @@ -71,4 +81,12 @@ def inference(img, model, tile, tile_overlap, window_size, scale): W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask) output = E.div_(W) - return output \ No newline at end of file + return output + +class UpscalerSwin(modules.images.Upscaler): + def __init__(self, title): + self.name = title + + def do_upscale(self, img): + img = upscale(img) + return img