mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-02-12 00:22:55 +08:00
Refactor esrgan_upscale to more generic upscale_with_model
This commit is contained in:
parent
12c6f37f8e
commit
e472383acb
@ -1,13 +1,12 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import modelloader, images, devices
|
from modules import modelloader, devices
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.upscaler_utils import upscale_with_model
|
||||||
|
|
||||||
|
|
||||||
def mod2normal(state_dict):
|
def mod2normal(state_dict):
|
||||||
@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def upscale_without_tiling(model, img):
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(devices.device_esrgan)
|
|
||||||
with torch.no_grad():
|
|
||||||
output = model(img)
|
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
|
||||||
return Image.fromarray(output, 'RGB')
|
|
||||||
|
|
||||||
|
|
||||||
def esrgan_upscale(model, img):
|
def esrgan_upscale(model, img):
|
||||||
if opts.ESRGAN_tile == 0:
|
return upscale_with_model(
|
||||||
return upscale_without_tiling(model, img)
|
model,
|
||||||
|
img,
|
||||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
tile_size=opts.ESRGAN_tile,
|
||||||
newtiles = []
|
tile_overlap=opts.ESRGAN_tile_overlap,
|
||||||
scale_factor = 1
|
)
|
||||||
|
|
||||||
for y, h, row in grid.tiles:
|
|
||||||
newrow = []
|
|
||||||
for tiledata in row:
|
|
||||||
x, w, tile = tiledata
|
|
||||||
|
|
||||||
output = upscale_without_tiling(model, tile)
|
|
||||||
scale_factor = output.width // tile.width
|
|
||||||
|
|
||||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
|
||||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
|
||||||
|
|
||||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
|
||||||
output = images.combine_grid(newgrid)
|
|
||||||
return output
|
|
||||||
|
66
modules/upscaler_utils.py
Normal file
66
modules/upscaler_utils.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from modules import devices, images
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_without_tiling(model, img: Image.Image):
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(devices.device_esrgan)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(img)
|
||||||
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output = 255. * np.moveaxis(output, 0, 2)
|
||||||
|
output = output.astype(np.uint8)
|
||||||
|
output = output[:, :, ::-1]
|
||||||
|
return Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_with_model(
|
||||||
|
model: Callable[[torch.Tensor], torch.Tensor],
|
||||||
|
img: Image.Image,
|
||||||
|
*,
|
||||||
|
tile_size: int,
|
||||||
|
tile_overlap: int = 0,
|
||||||
|
desc="tiled upscale",
|
||||||
|
) -> Image.Image:
|
||||||
|
if tile_size <= 0:
|
||||||
|
logger.debug("Upscaling %s without tiling", img)
|
||||||
|
output = upscale_without_tiling(model, img)
|
||||||
|
logger.debug("=> %s", output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
|
||||||
|
newtiles = []
|
||||||
|
|
||||||
|
with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
|
||||||
|
for y, h, row in grid.tiles:
|
||||||
|
newrow = []
|
||||||
|
for x, w, tile in row:
|
||||||
|
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
||||||
|
output = upscale_without_tiling(model, tile)
|
||||||
|
scale_factor = output.width // tile.width
|
||||||
|
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
||||||
|
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||||
|
p.update(1)
|
||||||
|
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||||
|
|
||||||
|
newgrid = images.Grid(
|
||||||
|
newtiles,
|
||||||
|
tile_w=grid.tile_w * scale_factor,
|
||||||
|
tile_h=grid.tile_h * scale_factor,
|
||||||
|
image_w=grid.image_w * scale_factor,
|
||||||
|
image_h=grid.image_h * scale_factor,
|
||||||
|
overlap=grid.overlap * scale_factor,
|
||||||
|
)
|
||||||
|
return images.combine_grid(newgrid)
|
Loading…
Reference in New Issue
Block a user