import logging from typing import Callable import numpy as np import torch import tqdm from PIL import Image from modules import devices, images, shared, torch_utils logger = logging.getLogger(__name__) def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: img = np.array(img.convert("RGB")) img = img[:, :, ::-1] # flip RGB to BGR img = np.transpose(img, (2, 0, 1)) # HWC to CHW img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] return torch.from_numpy(img) def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: if tensor.ndim == 4: # If we're given a tensor with a batch dimension, squeeze it out # (but only if it's a batch of size 1). if tensor.shape[0] != 1: raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") tensor = tensor.squeeze(0) assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale arr = arr.round().astype(np.uint8) arr = arr[:, :, ::-1] # flip BGR to RGB return Image.fromarray(arr, "RGB") def upscale_pil_patch(model, img: Image.Image) -> Image.Image: """ Upscale a given PIL image using the given model. """ param = torch_utils.get_param(model) with torch.inference_mode(): tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension tensor = tensor.to(device=param.device, dtype=param.dtype) with devices.without_autocast(): return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( model: Callable[[torch.Tensor], torch.Tensor], img: Image.Image, *, tile_size: int, tile_overlap: int = 0, desc="tiled upscale", ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output grid = images.split_grid(img, tile_size, tile_size, tile_overlap) newtiles = [] with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p: for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: if shared.state.interrupted: return img output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width newrow.append([x * scale_factor, w * scale_factor, output]) p.update(1) newtiles.append([y * scale_factor, h * scale_factor, newrow]) newgrid = images.Grid( newtiles, tile_w=grid.tile_w * scale_factor, tile_h=grid.tile_h * scale_factor, image_w=grid.image_w * scale_factor, image_h=grid.image_h * scale_factor, overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) def tiled_upscale_2( img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. b, c, h, w = img.size() tile_size = min(tile_size, h, w) if tile_size <= 0: logger.debug("Upscaling %s without tiling", img.shape) return model(img) stride = tile_size - tile_overlap h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] result = torch.zeros( b, c, h * scale, w * scale, device=device, dtype=img.dtype, ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: for h_idx in h_idx_list: if shared.state.interrupted or shared.state.skipped: break for w_idx in w_idx_list: if shared.state.interrupted or shared.state.skipped: break # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, ].to(device=device) out_patch = model(in_patch) result[ ..., h_idx * scale : (h_idx + tile_size) * scale, w_idx * scale : (w_idx + tile_size) * scale, ].add_(out_patch) out_patch_mask = torch.ones_like(out_patch) weights[ ..., h_idx * scale : (h_idx + tile_size) * scale, w_idx * scale : (w_idx + tile_size) * scale, ].add_(out_patch_mask) pbar.update(1) output = result.div_(weights) return output def upscale_2( img: Image.Image, model, *, tile_size: int, tile_overlap: int, scale: int, desc: str, ): """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. """ param = torch_utils.get_param(model) tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( tensor, model, tile_size=tile_size, tile_overlap=tile_overlap, scale=scale, desc=desc, device=param.device, ) return torch_bgr_to_pil_image(output)