Holy $hit.

Yep.

Fix gfpgan_model_arch requirement(s).
Add Upscaler base class, move from images.
Add a lot of methods to Upscaler.
Re-work all the child upscalers to be proper classes.
Add BSRGAN scaler.
Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff.
Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated.
Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size.
Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size.
Add typehints for IDE sanity.
PEP-8 improvements.
Moar.
This commit is contained in:
d8ahazard 2022-09-29 17:46:23 -05:00
parent 31ad536c33
commit 0dce0df1ee
18 changed files with 1009 additions and 641 deletions

View File

@ -1,5 +1,5 @@
# this scripts installs necessary requirements and launches main program in webui.py # this scripts installs necessary requirements and launches main program in webui.py
import shutil
import subprocess import subprocess
import os import os
import sys import sys
@ -22,7 +22,6 @@ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "6
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args) args = shlex.split(commandline_args)
@ -122,9 +121,11 @@ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-di
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier if os.path.isdir(repo_dir('latent-diffusion')):
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash) try:
shutil.rmtree(repo_dir('latent-diffusion'))
except:
pass
if not is_installed("lpips"): if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

79
modules/bsrgan_model.py Normal file
View File

@ -0,0 +1,79 @@
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import shared, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
model.to(shared.device)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_dir, filename))
return None
print("Loading %s from %s" % (self.model_dir, filename))
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
return model

View File

@ -0,0 +1,103 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
print([in_nc, out_nc, nf, nb, gc, sf])
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.sf==4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
if self.sf==4:
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out

View File

@ -1,6 +1,4 @@
import os import os
import sys
import traceback
import numpy as np import numpy as np
import torch import torch
@ -8,32 +6,56 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgam_model_arch as arch
import modules.images from modules import shared, modelloader, images
from modules import shared
from modules import shared, modelloader
from modules.devices import has_mps from modules.devices import has_mps
from modules.paths import models_path from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts from modules.shared import opts
model_dir = "ESRGAN"
model_path = os.path.join(models_path, model_dir)
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
model_name = "ESRGAN_x4"
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
self.model_name = "ESRGAN 4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
print(f"File: {file}")
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
def load_model(path: str, name: str): scaler_data = UpscalerData(name, file, self, 4)
global model_path print(f"ESRGAN: Adding scaler {name}")
global model_url self.scalers.append(scaler_data)
global model_dir
global model_name def do_upscale(self, img, selected_model):
model = self.load_model(selected_model)
if model is None:
return img
model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
if "http" in path: if "http" in path:
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True) filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="%s.pth" % self.model_name,
progress=True)
else: else:
filename = path filename = path
if not os.path.exists(filename) or filename is None: if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (model_dir, filename)) print("Unable to load %s from %s" % (self.model_path, filename))
return None return None
print("Loading %s from %s" % (model_dir, filename))
# this code is adapted from https://github.com/xinntao/ESRGAN # this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@ -43,7 +65,8 @@ def load_model(path: str, name: str):
return crt_model return crt_model
if 'model.0.weight' not in pretrained_net: if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
"params_ema"]
if is_realesrgan: if is_realesrgan:
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
else: else:
@ -96,6 +119,7 @@ def load_model(path: str, name: str):
crt_model.eval() crt_model.eval()
return crt_model return crt_model
def upscale_without_tiling(model, img): def upscale_without_tiling(model, img):
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
@ -115,7 +139,7 @@ def esrgan_upscale(model, img):
if opts.ESRGAN_tile == 0: if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img) return upscale_without_tiling(model, img)
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = [] newtiles = []
scale_factor = 1 scale_factor = 1
@ -130,38 +154,7 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output]) newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow]) newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
output = modules.images.combine_grid(newgrid) grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid)
return output return output
class UpscalerESRGAN(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.filename = filename
def do_upscale(self, img):
model = load_model(self.filename, self.name)
if model is None:
return img
model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def setup_model(dirname):
global model_path
global model_name
if not os.path.exists(model_path):
os.makedirs(model_path)
model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
if len(model_paths) == 0:
modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
for file in model_paths:
name = modelloader.friendly_name(file)
try:
modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
except Exception:
print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)

View File

@ -66,7 +66,6 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res image = res
if upscaling_resize != 1.0:
def upscale(image, scaler_index, resize): def upscale(image, scaler_index, resize):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist()) pixels = tuple(np.array(small).flatten().tolist())
@ -75,7 +74,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
c = cached_images.get(key) c = cached_images.get(key)
if c is None: if c is None:
upscaler = shared.sd_upscalers[scaler_index] upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.upscale(image, image.width * resize, image.height * resize) c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c cached_images[key] = c
return c return c

View File

@ -1,24 +1,23 @@
import os import os
import sys import sys
import traceback import traceback
from glob import glob
from modules import shared, devices import facexlib
from modules.shared import cmd_opts import gfpgan
from modules.paths import script_path
import modules.face_restoration import modules.face_restoration
from modules import shared, devices, modelloader from modules import shared, devices, modelloader
from modules.paths import models_path from modules.paths import models_path
model_dir = "GFPGAN" model_dir = "GFPGAN"
cmd_dir = None user_path = None
model_path = os.path.join(models_path, model_dir) model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None loaded_gfpgan_model = None
def gfpgan(): def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
@ -28,14 +27,16 @@ def gfpgan():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
models = modelloader.load_models(model_path, model_url, cmd_dir) models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) != 0: if len(models) == 1 and "http" in models[0]:
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) latest_file = max(models, key=os.path.getctime)
model_file = latest_file model_file = latest_file
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2, model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
bg_upsampler=None) bg_upsampler=None)
model.gfpgan.to(shared.device) model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
@ -44,11 +45,12 @@ def gfpgan():
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
model = gfpgan() model = gfpgann()
if model is None: if model is None:
return np_image return np_image
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload: if shared.opts.face_restoration_unload:
@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image):
return np_image return np_image
have_gfpgan = False
gfpgan_constructor = None gfpgan_constructor = None
@ -67,14 +68,33 @@ def setup_model(dirname):
os.makedirs(model_path) os.makedirs(model_path)
try: try:
from modules.gfpgan_model_arch import GFPGANerr from gfpgan import GFPGANer
global cmd_dir from facexlib import detection, parsing
global user_path
global have_gfpgan global have_gfpgan
global gfpgan_constructor global gfpgan_constructor
cmd_dir = dirname load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs):
print("Setting model_dir to " + model_path)
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
def facex_load_file_from_url(**kwargs):
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
def facex_load_file_from_url2(**kwargs):
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname
print("Have gfpgan should be true?")
have_gfpgan = True have_gfpgan = True
gfpgan_constructor = GFPGANerr gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self): def name(self):
@ -82,7 +102,9 @@ def setup_model(dirname):
def restore(self, np_image): def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
only_center_face=False,
paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image return np_image

View File

@ -1,150 +0,0 @@
# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANerr():
"""Helper for restoration with GFPGAN.
It will detect and crop faces, and then resize the faces to 512x512.
GFPGAN is used to restored the resized faces.
The background is upsampled with the bg_upsampler.
Finally, the faces will be pasted back to the upsample background image.
Args:
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
upscale (float): The upscale of the final output. Default: 2.
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
"""
def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'bilinear':
self.gfpgan = GFPGANBilinear(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'original':
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
elif arch == 'RestoreFormer':
from gfpgan.archs.restoreformer_arch import RestoreFormer
self.gfpgan = RestoreFormer()
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True,
device=self.device,
model_rootpath=model_dir)
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir=model_dir, progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
self.face_helper.clean_all()
if has_aligned: # the inputs are already aligned
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
# align and warp each face
self.face_helper.align_warp_face()
# face restoration
for cropped_face in self.face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
# upsample the background
if self.bg_upsampler is not None:
# Now only support RealESRGAN for upsampling background
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

View File

@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto from fonts.ttf import Roboto
import string import string
import modules.shared
from modules import sd_samplers, shared from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
cols = math.ceil((w - overlap) / non_overlap_width) cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height) rows = math.ceil((h - overlap) / non_overlap_height)
dx = (w - tile_w) / (cols-1) if cols > 1 else 0 dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
dy = (h - tile_h) / (rows-1) if rows > 1 else 0 dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap) grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows): for row in range(rows):
@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
for col in range(cols): for col in range(cols):
x = int(col * dx) x = int(col * dx)
if x+tile_w >= w: if x + tile_w >= w:
x = w - tile_w x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h)) tile = image.crop((x, y, x + tile_w, y + tile_h))
@ -85,8 +84,10 @@ def combine_grid(grid):
r = r.astype(np.uint8) r = r.astype(np.uint8)
return Image.fromarray(r, 'L') return Image.fromarray(r, 'L')
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) mask_w = make_mask_image(
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
mask_h = make_mask_image(
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
combined_image = Image.new("RGB", (grid.image_w, grid.image_h)) combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
for y, h, row in grid.tiles: for y, h, row in grid.tiles:
@ -129,10 +130,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
def draw_texts(drawing, draw_x, draw_y, lines): def draw_texts(drawing, draw_x, draw_y, lines):
for i, line in enumerate(lines): for i, line in enumerate(lines):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active: if not line.is_active:
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4) drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing draw_y += line.size[1] + line_spacing
@ -171,7 +174,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1]) line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts] hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2 pad_top = max(hor_text_heights) + line_spacing * 2
@ -202,8 +206,10 @@ def draw_prompt_matrix(im, width, height, all_prompts):
prompts_horiz = prompts[:boundary] prompts_horiz = prompts[:boundary]
prompts_vert = prompts[boundary:] prompts_vert = prompts[boundary:]
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))] hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))] range(1 << len(prompts_horiz))]
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
range(1 << len(prompts_vert))]
return draw_grid_annotations(im, width, height, hor_texts, ver_texts) return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
@ -214,7 +220,8 @@ def resize_image(resize_mode, im, width, height):
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0] upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
return upscaler.upscale(im, w, h) scale = w / im.width
return upscaler.scaler.upscale(im, scale)
if resize_mode == 0: if resize_mode == 0:
res = resize(im, width, height) res = resize(im, width, height)
@ -244,11 +251,13 @@ def resize_image(resize_mode, im, width, height):
if ratio < src_ratio: if ratio < src_ratio:
fill_height = height // 2 - src_h // 2 fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
box=(0, fill_height + src_h))
elif ratio > src_ratio: elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2 fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
box=(fill_width + src_w, 0))
return res return res
@ -256,7 +265,7 @@ def resize_image(resize_mode, im, width, height):
invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' ' invalid_filename_prefix = ' '
invalid_filename_postfix = ' .' invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s'+string.punctuation+']+') re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128 max_filename_part_length = 128
@ -283,7 +292,8 @@ def apply_filename_pattern(x, p, seed, prompt):
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
if len(words) == 0: if len(words) == 0:
words = ["empty"] words = ["empty"]
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) x = x.replace("[prompt_words]",
sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None: if p is not None:
x = x.replace("[steps]", str(p.steps)) x = x.replace("[steps]", str(p.steps))
@ -291,7 +301,8 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[width]", str(p.width)) x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height)) x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[sampler]",
sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat()) x = x.replace("[date]", datetime.date.today().isoformat())
@ -303,6 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
return x return x
def get_next_sequence_number(path, basename): def get_next_sequence_number(path, basename):
""" """
Determines and returns the next sequence number to use when saving an image in the specified directory. Determines and returns the next sequence number to use when saving an image in the specified directory.
@ -316,7 +328,8 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
if p.startswith(basename): if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) l = os.path.splitext(p[prefix_length:])[0].split(
'-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try: try:
result = max(int(l[0]), result) result = max(int(l[0]), result)
except ValueError: except ValueError:
@ -324,7 +337,10 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
elif opts.save_to_dirs: elif opts.save_to_dirs:
@ -361,7 +377,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = "a.png" fullfn = "a.png"
fullfn_without_extension = "a" fullfn_without_extension = "a"
for i in range(500): for i in range(500):
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}" fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn): if not os.path.exists(fullfn):
@ -403,31 +419,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file.write(info + "\n") file.write(info + "\n")
class Upscaler:
name = "Lanczos"
def do_upscale(self, img):
return img
def upscale(self, img, w, h):
for i in range(3):
if img.width >= w and img.height >= h:
break
img = self.do_upscale(img)
if img.width != w or img.height != h:
img = img.resize((int(w), int(h)), resample=LANCZOS)
return img
class UpscalerNone(Upscaler):
name = "None"
def upscale(self, img, w, h):
return img
modules.shared.sd_upscalers.append(UpscalerNone())
modules.shared.sd_upscalers.append(Upscaler())

View File

@ -1,74 +1,45 @@
import os import os
import sys import sys
import traceback import traceback
from collections import namedtuple
from modules import shared, images, modelloader, paths from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared
from modules.paths import models_path from modules.paths import models_path
model_dir = "LDSR"
model_path = os.path.join(models_path, model_dir)
cmd_path = None
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) class UpscalerLDSR(Upscaler):
def __init__(self, user_path):
ldsr_models = []
have_ldsr = False
LDSR_obj = None
class UpscalerLDSR(images.Upscaler):
def __init__(self, steps):
self.steps = steps
self.name = "LDSR" self.name = "LDSR"
self.model_path = os.path.join(models_path, self.name)
self.user_path = user_path
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
super().__init__()
scaler_data = UpscalerData("LDSR", None, self)
self.scalers = [scaler_data]
def do_upscale(self, img): def load_model(self, path: str):
return upscale_with_ldsr(img) model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.pth", progress=True)
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="project.yaml", progress=True)
def setup_model(dirname):
global cmd_path
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
cmd_path = dirname
shared.sd_upscalers.append(UpscalerLDSR(100))
def prepare_ldsr():
path = paths.paths.get("LDSR", None)
if path is None:
return
global have_ldsr
global LDSR_obj
try: try:
from LDSR import LDSR return LDSR(model, yaml)
model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
if len(model_files) != 0 and len(yaml_files) != 0:
model_file = model_files[0]
yaml_file = yaml_files[0]
have_ldsr = True
LDSR_obj = LDSR(model_file, yaml_file)
else:
return
except Exception: except Exception:
print("Error importing LDSR:", file=sys.stderr) print("Error importing LDSR:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
have_ldsr = False return None
def upscale_with_ldsr(image):
prepare_ldsr()
if not have_ldsr or LDSR_obj is None:
return image
def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
pre_scale = shared.opts.ldsr_pre_down pre_scale = shared.opts.ldsr_pre_down
post_scale = shared.opts.ldsr_post_down return ldsr.super_resolution(img, ddim_steps, self.scale)
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
return image

223
modules/ldsr_model_arch.py Normal file
View File

@ -0,0 +1,223 @@
import gc
import time
import warnings
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
warnings.filterwarnings("ignore", category=UserWarning)
# Create LDSR Class
class LDSR:
def load_model_from_config(self, half_attention):
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
config = OmegaConf.load(self.yamlPath)
model = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False)
model.cuda()
if half_attention:
model = model.half()
model.eval()
return {"model": model}
def __init__(self, model_path, yaml_path):
self.modelPath = model_path
self.yamlPath = yaml_path
@staticmethod
def run(model, selected_path, custom_steps, eta):
example = get_cond(selected_path)
n_runs = 1
guider = None
ckwargs = None
ddim_use_x0_pred = False
temperature = 1.
eta = eta
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
x_t = None
logs = None
for n in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
custom_steps=custom_steps,
eta=eta, quantize_x0=False,
custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
model = self.load_model_from_config(half_attention)
# Run settings
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
torch.cuda.empty_cache()
im_og = image
width_og, height_og = im_og.size
# If we can adjust the max upscale size, then the 4 below should be our variable
print("Foo")
down_sample_rate = target_scale / 4
print(f"Downsample rate is {down_sample_rate}")
width_downsampled_pre = width_og * down_sample_rate
height_downsampled_pre = height_og * down_sample_method
if down_sample_rate != 1:
print(
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
del model
gc.collect()
torch.cuda.empty_cache()
print(f'Processing finished!')
return a
def get_cond(selected_path):
example = dict()
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
example["LR_image"] = c
example["image"] = c_up
return example
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
corrector_kwargs=None, x_t=None
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_t=x_t)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, mask=None, x0=z0,
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_t=x_T)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log

View File

@ -1,34 +1,36 @@
import os import os
import shutil import shutil
import importlib
from urllib.parse import urlparse from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None, def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
ext_filter=None) -> list:
""" """
A one-and done loader to try finding the desired models in specified directories. A one-and done loader to try finding the desired models in specified directories.
@param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL. @param download_name: Specify to download from model_url immediately.
@param model_url: If specified, attempt to download model from the given URL. @param model_url: If no other models are found, this will be downloaded on upscale.
@param model_path: The location to store/find models in. @param model_path: The location to store/find models in.
@param command_path: A command-line argument to search for models in first. @param command_path: A command-line argument to search for models in first.
@param existing: An array of existing model paths.
@param ext_filter: An optional list of filename extensions to filter by @param ext_filter: An optional list of filename extensions to filter by
@return: A list of paths containing the desired model(s) @return: A list of paths containing the desired model(s)
""" """
output = []
if ext_filter is None: if ext_filter is None:
ext_filter = [] ext_filter = []
if existing is None:
existing = []
try: try:
places = [] places = []
if command_path is not None and command_path != model_path: if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path): if os.path.exists(pretrained_path):
print(f"Appending path: {pretrained_path}")
places.append(pretrained_path) places.append(pretrained_path)
elif os.path.exists(command_path): elif os.path.exists(command_path):
places.append(command_path) places.append(command_path)
@ -36,26 +38,24 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places: for place in places:
if os.path.exists(place): if os.path.exists(place):
for file in os.listdir(place): for file in os.listdir(place):
if os.path.isdir(file): full_path = os.path.join(place, file)
if os.path.isdir(full_path):
continue continue
if len(ext_filter) != 0: if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file) model_name, extension = os.path.splitext(file)
if extension not in ext_filter: if extension not in ext_filter:
continue continue
if file not in existing: if file not in output:
path = os.path.join(place, file) output.append(full_path)
existing.append(path) if model_url is not None and len(output) == 0:
if model_url is not None and len(existing) == 0: if download_name is not None:
if dl_name is not None: dl = load_file_from_url(model_url, model_path, True, download_name)
model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True) output.append(dl)
else: else:
model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True) output.append(model_url)
if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing:
existing.append(model_file)
except: except:
pass pass
return existing return output
def friendly_name(file: str): def friendly_name(file: str):
@ -111,3 +111,37 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
shutil.rmtree(src_path, True) shutil.rmtree(src_path, True)
except: except:
pass pass
def load_upscalers():
datas = []
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
print(f"Class: {name} and {module_name}")
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
print(f"CMD Name: {cmd_name}")
opt_string = None
try:
opt_string = shared.opts.__getattr__(cmd_name)
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
print(f"Appending {child.name}")
datas.append(child)
shared.sd_upscalers = datas
# for scaler in subclasses:
# print(f"Found scaler: {type(scaler).__name__}")
# try:
# scaler = scaler()
# for child in scaler.scalers:
# print(f"Appending {child.name}")
# datas.append[child]
# except:
# pass
# shared.sd_upscalers = datas

View File

@ -1,64 +1,135 @@
import os import os
import sys import sys
import traceback import traceback
from collections import namedtuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
import modules.images from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path from modules.paths import models_path
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
model_dir = "RealESRGAN"
model_path = os.path.join(models_path, model_dir) class UpscalerRealESRGAN(Upscaler):
cmd_dir = None def __init__(self, path):
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) self.name = "RealESRGAN"
realesrgan_models = [] self.model_path = os.path.join(models_path, self.name)
have_realesrgan = False self.user_path = path
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
self.enable = True
self.scalers = []
scalers = self.load_models(path)
for scaler in scalers:
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False
self.scalers = []
def do_upscale(self, img, path):
if not self.enable:
return img
info = self.load_model(path)
if not os.path.exists(info.data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
model_path=info.data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
image = Image.fromarray(upsampled)
return image
def load_model(self, path):
try:
info = None
for scaler in self.scalers:
if scaler.data_path == path:
info = scaler
if info is None:
print(f"Unable to find model info: {path}")
return None
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
info.data_path = model_file
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
def load_models(self, _):
return get_realesrgan_models(self)
def get_realesrgan_models(): def get_realesrgan_models(scaler):
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [ models = [
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN General x4x3", name="R-ESRGAN General 4xV3",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
netscale=4, ".pth",
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') scale=4,
upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
act_type='prelu')
), ),
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN General WDN x4x3", name="R-ESRGAN General WDN 4xV3",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
netscale=4, scale=4,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
act_type='prelu')
), ),
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN AnimeVideo", name="R-ESRGAN AnimeVideo",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
netscale=4, scale=4,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
act_type='prelu')
), ),
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN 4x plus", name="R-ESRGAN 4x+",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
netscale=4, scale=4,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
), ),
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN 4x plus anime 6B", name="R-ESRGAN 4x+ Anime6B",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
netscale=4, scale=4,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
), ),
RealesrganModelInfo( UpscalerData(
name="Real-ESRGAN 2x plus", name="R-ESRGAN 2x+",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
netscale=2, scale=2,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
), ),
] ]
@ -66,69 +137,3 @@ def get_realesrgan_models():
except Exception as e: except Exception as e:
print("Error making Real-ESRGAN models list:", file=sys.stderr) print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
class UpscalerRealESRGAN(modules.images.Upscaler):
def __init__(self, upscaling, model_index):
self.upscaling = upscaling
self.model_index = model_index
self.name = realesrgan_models[model_index].name
def do_upscale(self, img):
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
global realesrgan_models
global have_realesrgan
if model_path != dirname:
model_path = dirname
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = get_realesrgan_models()
have_realesrgan = True
for i, model in enumerate(realesrgan_models):
if model.name in opts.realesrgan_enabled_models:
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
have_realesrgan = False
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
if not have_realesrgan:
return image
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
if not os.path.exists(model_file):
print("Unable to load RealESRGAN model: %s" % info.name)
return image
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image

View File

@ -50,7 +50,7 @@ def setup_model(dirname):
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(model_path) os.makedirs(model_path)
checkpoints_list.clear() checkpoints_list.clear()
model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt") model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt")
cmd_ckpt = shared.cmd_opts.ckpt cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt): if os.path.exists(cmd_ckpt):
@ -68,6 +68,7 @@ def setup_model(dirname):
def model_hash(filename): def model_hash(filename):
try: try:
print(f"Opening: {filename}")
with open(filename, "rb") as file: with open(filename, "rb") as file:
import hashlib import hashlib
m = hashlib.sha256() m = hashlib.sha256()

View File

@ -154,9 +154,9 @@ class VanillaStableDiffusionSampler:
# existing code fails with cetin step counts, like 9 # existing code fails with cetin step counts, like 9
try: try:
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
except Exception: except Exception:
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
return samples_ddim return samples_ddim

View File

@ -1,18 +1,19 @@
import sys
import argparse import argparse
import datetime
import json import json
import os import os
import sys
import gradio as gr import gradio as gr
import tqdm import tqdm
import datetime
import modules.artists import modules.artists
from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles
from modules.devices import get_optimal_device
from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt') sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file default_sd_model_file = sd_model_file
@ -38,6 +39,7 @@ parser.add_argument("--share", action='store_true', help="use share=True for gra
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
@ -111,7 +113,7 @@ face_restorers = []
def realesrgan_models_names(): def realesrgan_models_names():
import modules.realesrgan_model import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models()] return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo: class OptionInfo:
@ -176,13 +178,11 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
})) }))

View File

@ -1,35 +1,59 @@
import contextlib import contextlib
import os import os
import sys
import traceback
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.images
from modules import modelloader from modules import modelloader
from modules.paths import models_path from modules.paths import models_path
from modules.shared import cmd_opts, opts, device from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch import SwinIR as net
from modules.upscaler import Upscaler, UpscalerData
model_dir = "SwinIR"
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
model_name = "SwinIR x4"
model_path = os.path.join(models_path, model_dir)
cmd_path = ""
precision_scope = ( precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
) )
def load_model(path, scale=4): class UpscalerSwinIR(Upscaler):
global model_path def __init__(self, dirname):
global model_name self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x"
self.model_path = os.path.join(models_path, self.name)
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
if "http" in model:
name = self.model_name
else:
name = modelloader.friendly_name(model)
model_data = UpscalerData(name, model, self)
scalers.append(model_data)
self.scalers = scalers
def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
return img
model = model.to(device)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
except:
pass
return img
def load_model(self, path, scale=4):
if "http" in path: if "http" in path:
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth") dl_name = "%s%s" % (self.name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True) filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
else: else:
filename = path filename = path
if filename is None or not os.path.exists(filename): if filename is None or not os.path.exists(filename):
@ -55,31 +79,6 @@ def load_model(path, scale=4):
return model return model
def setup_model(dirname):
global model_path
global model_name
global cmd_path
if not os.path.exists(model_path):
os.makedirs(model_path)
cmd_path = dirname
model_file = ""
try:
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
if len(models) != 0:
model_file = models[0]
name = modelloader.friendly_name(model_file)
else:
# Add the "default" model if none are found.
model_file = model_url
name = model_name
modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
except Exception:
print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def upscale( def upscale(
img, img,
model, model,
@ -125,34 +124,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for h_idx in h_idx_list: for h_idx in h_idx_list:
for w_idx in w_idx_list: for w_idx in w_idx_list:
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile] in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)
E[ E[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch) ].add_(out_patch)
W[ W[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask) ].add_(out_patch_mask)
output = E.div_(W) output = E.div_(W)
return output return output
class UpscalerSwin(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.filename = filename
def do_upscale(self, img):
model = load_model(self.filename)
if model is None:
return img
model = model.to(device)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
except:
pass
return img

121
modules/upscaler.py Normal file
View File

@ -0,0 +1,121 @@
import os
from abc import abstractmethod
import PIL
import numpy as np
import torch
from PIL import Image
import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
from modules.paths import models_path
class Upscaler:
name = None
model_path = None
model_name = None
model_url = None
enable = True
filter = None
model = None
user_path = None
scalers: []
tile = True
def __init__(self, create_dirs=False):
self.mod_pad_h = None
self.tile_size = modules.shared.opts.ESRGAN_tile
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
self.device = modules.shared.device
self.img = None
self.output = None
self.scale = 1
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
if self.name is not None and create_dirs:
self.model_path = os.path.join(models_path, self.name)
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
try:
import cv2
self.can_tile = True
except:
pass
@abstractmethod
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width >= dest_w and img.height >= dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
img = img.resize(dest_w, dest_h, resample=LANCZOS)
return img
@abstractmethod
def load_model(self, path: str):
pass
def find_models(self, ext_filter=None) -> list:
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)
class UpscalerData:
name = None
data_path = None
scale: int = 4
scaler: Upscaler = None
model: None
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
class UpscalerNone(Upscaler):
name = "None"
scalers = []
def load_model(self, path):
pass
def do_upscale(self, img, selected_model=None):
return img
def __init__(self, dirname=None):
super().__init__(False)
self.scalers = [UpscalerData("None", None, self)]
class UpscalerLanczos(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]

View File

@ -1,9 +1,10 @@
import os import os
import signal import signal
import threading import threading
import modules.paths
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan import modules.esrgan_model as esrgan
import modules.bsrgan_model as bsrgan
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
import modules.gfpgan_model as gfpgan import modules.gfpgan_model as gfpgan
@ -27,11 +28,7 @@ modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path)
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
esrgan.setup_model(cmd_opts.esrgan_models_path)
swinir.setup_model(cmd_opts.swinir_models_path)
realesrgan.setup_model(cmd_opts.realesrgan_models_path)
ldsr.setup_model(cmd_opts.ldsr_models_path)
queue_lock = threading.Lock() queue_lock = threading.Lock()