mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-30 18:22:56 +08:00
Refactor esrgan_upscale to more generic upscale_with_model
This commit is contained in:
parent
12c6f37f8e
commit
e472383acb
@ -1,13 +1,12 @@
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import modelloader, images, devices
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
def mod2normal(state_dict):
|
||||
@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler):
|
||||
return model
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img):
    """Upscale one image with ``model`` in a single forward pass.

    The image is turned into a float tensor with its channel axis reversed,
    scaled to the 0..1 range, given a leading batch axis and moved to
    ``devices.device_esrgan``. ``model`` is called under ``torch.no_grad()``;
    its output is clamped to 0..1, rescaled to 0..255, has its channel axis
    reversed again and is returned as an ``'RGB'`` PIL image.

    Args:
        model: Callable applied to the batched input tensor.
        img: Source image. The code indexes three axes, so a 3-channel image
            is expected (presumably RGB -- confirm against callers).

    Returns:
        A PIL image created with mode ``'RGB'``.
    """
    pixels = np.array(img)
    pixels = pixels[:, :, ::-1]  # reverse the channel axis
    # HWC -> CHW, then scale 0..255 down to 0..1.
    pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1))) / 255
    batch = torch.from_numpy(pixels).float()
    batch = batch.unsqueeze(0).to(devices.device_esrgan)

    with torch.no_grad():
        result = model(batch)

    # Drop only the batch axis. A bare squeeze() would also remove a height
    # or width of 1, leaving too few axes for the moveaxis call below.
    result = result.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    result = 255. * np.moveaxis(result, 0, 2)  # CHW -> HWC, back to 0..255
    result = result.astype(np.uint8)
    result = result[:, :, ::-1]  # undo the channel reversal
    return Image.fromarray(result, 'RGB')
|
||||
|
||||
|
||||
def esrgan_upscale(model, img):
    """Upscale ``img`` with ``model`` using the ESRGAN tiling options.

    Delegates to ``upscale_with_model``, passing ``opts.ESRGAN_tile`` as the
    tile size and ``opts.ESRGAN_tile_overlap`` as the tile overlap.

    Args:
        model: The loaded model, forwarded unchanged.
        img: The image to upscale, forwarded unchanged.

    Returns:
        Whatever ``upscale_with_model`` returns for these arguments.
    """
    return upscale_with_model(
        model,
        img,
        tile_size=opts.ESRGAN_tile,
        tile_overlap=opts.ESRGAN_tile_overlap,
    )
|
||||
|
66
modules/upscaler_utils.py
Normal file
66
modules/upscaler_utils.py
Normal file
@ -0,0 +1,66 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from modules import devices, images
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img: Image.Image):
    """Upscale one image with ``model`` in a single forward pass.

    The image is turned into a float tensor with its channel axis reversed,
    scaled to the 0..1 range, given a leading batch axis and moved to
    ``devices.device_esrgan``. ``model`` is called under ``torch.no_grad()``;
    its output is clamped to 0..1, rescaled to 0..255, has its channel axis
    reversed again and is returned as an ``'RGB'`` PIL image.

    Args:
        model: Callable applied to the batched input tensor.
        img: Source image. The code indexes three axes, so a 3-channel image
            is expected (presumably RGB -- confirm against callers).

    Returns:
        A PIL image created with mode ``'RGB'``.
    """
    pixels = np.array(img)
    pixels = pixels[:, :, ::-1]  # reverse the channel axis
    # HWC -> CHW, then scale 0..255 down to 0..1.
    pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1))) / 255
    batch = torch.from_numpy(pixels).float()
    batch = batch.unsqueeze(0).to(devices.device_esrgan)

    with torch.no_grad():
        result = model(batch)

    # Drop only the batch axis. A bare squeeze() would also remove a height
    # or width of 1, leaving too few axes for the moveaxis call below.
    result = result.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    result = 255. * np.moveaxis(result, 0, 2)  # CHW -> HWC, back to 0..255
    result = result.astype(np.uint8)
    result = result[:, :, ::-1]  # undo the channel reversal
    return Image.fromarray(result, 'RGB')
|
||||
|
||||
|
||||
def upscale_with_model(
    model: Callable[[torch.Tensor], torch.Tensor],
    img: Image.Image,
    *,
    tile_size: int,
    tile_overlap: int = 0,
    desc="tiled upscale",
) -> Image.Image:
    """Upscale ``img`` with ``model``, optionally processing it in tiles.

    When ``tile_size`` is zero or negative the whole image goes through
    ``upscale_without_tiling`` in one call. Otherwise the image is split with
    ``images.split_grid``, every tile is upscaled on its own, and the results
    are stitched back together with ``images.combine_grid``.

    Args:
        model: Callable applied to each image or tile tensor.
        img: The image to upscale.
        tile_size: Width and height passed to ``images.split_grid``; values
            ``<= 0`` disable tiling.
        tile_overlap: Overlap passed to ``images.split_grid``.
        desc: Label shown on the tqdm progress bar.

    Returns:
        The upscaled image.
    """
    if tile_size <= 0:
        logger.debug("Upscaling %s without tiling", img)
        output = upscale_without_tiling(model, img)
        logger.debug("=> %s", output)
        return output

    grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
    newtiles = []
    # Start from 1 so the code below still works when the grid yields no
    # tiles; otherwise scale_factor would be unbound after the loop.
    scale_factor = 1

    with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
        for y, h, row in grid.tiles:
            newrow = []
            for x, w, tile in row:
                logger.debug("Tile (%d, %d) %s...", x, y, tile)
                output = upscale_without_tiling(model, tile)
                # The ratio is measured from the tile actually produced.
                scale_factor = output.width // tile.width
                logger.debug("=> %s (scale factor %s)", output, scale_factor)
                newrow.append([x * scale_factor, w * scale_factor, output])
                p.update(1)
            newtiles.append([y * scale_factor, h * scale_factor, newrow])

    # Rebuild the grid description at the upscaled size.
    newgrid = images.Grid(
        newtiles,
        tile_w=grid.tile_w * scale_factor,
        tile_h=grid.tile_h * scale_factor,
        image_w=grid.image_w * scale_factor,
        image_h=grid.image_h * scale_factor,
        overlap=grid.overlap * scale_factor,
    )
    return images.combine_grid(newgrid)
|
Loading…
Reference in New Issue
Block a user