mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-19 21:00:14 +08:00
Merge pull request #1109 from d8ahazard/ModelLoader
Model Loader, Fixes
This commit is contained in:
commit
25414bcd05
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,10 +1,13 @@
|
|||||||
__pycache__
|
__pycache__
|
||||||
/ESRGAN
|
*.ckpt
|
||||||
|
*.pth
|
||||||
|
/ESRGAN/*
|
||||||
|
/SwinIR/*
|
||||||
/repositories
|
/repositories
|
||||||
/venv
|
/venv
|
||||||
/tmp
|
/tmp
|
||||||
/model.ckpt
|
/model.ckpt
|
||||||
/models/**/*.ckpt
|
/models/**/*
|
||||||
/GFPGANv1.3.pth
|
/GFPGANv1.3.pth
|
||||||
/gfpgan/weights/*.pth
|
/gfpgan/weights/*.pth
|
||||||
/ui-config.json
|
/ui-config.json
|
||||||
|
@ -1 +0,0 @@
|
|||||||
|
|
11
launch.py
11
launch.py
@ -1,5 +1,5 @@
|
|||||||
# this scripts installs necessary requirements and launches main program in webui.py
|
# this scripts installs necessary requirements and launches main program in webui.py
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -22,7 +22,6 @@ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HAS
|
|||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
|
|
||||||
|
|
||||||
args = shlex.split(commandline_args)
|
args = shlex.split(commandline_args)
|
||||||
|
|
||||||
@ -120,9 +119,11 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
|
|||||||
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
|
if os.path.isdir(repo_dir('latent-diffusion')):
|
||||||
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
|
try:
|
||||||
|
shutil.rmtree(repo_dir('latent-diffusion'))
|
||||||
|
except:
|
||||||
|
pass
|
||||||
if not is_installed("lpips"):
|
if not is_installed("lpips"):
|
||||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||||
|
|
||||||
|
79
modules/bsrgan_model.py
Normal file
79
modules/bsrgan_model.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
import modules.upscaler
|
||||||
|
from modules import shared, modelloader
|
||||||
|
from modules.bsrgan_model_arch import RRDBNet
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "BSRGAN"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.model_name = "BSRGAN 4x"
|
||||||
|
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
|
scalers = []
|
||||||
|
if len(model_paths) == 0:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
for file in model_paths:
|
||||||
|
if "http" in file:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(file)
|
||||||
|
try:
|
||||||
|
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_file):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model = self.load_model(selected_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
model.to(shared.device)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(shared.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(img)
|
||||||
|
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
output = 255. * np.moveaxis(output, 0, 2)
|
||||||
|
output = output.astype(np.uint8)
|
||||||
|
output = output[:, :, ::-1]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return PIL.Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
|
if "http" in path:
|
||||||
|
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||||
|
progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if not os.path.exists(filename) or filename is None:
|
||||||
|
print("Unable to load %s from %s" % (self.model_dir, filename))
|
||||||
|
return None
|
||||||
|
print("Loading %s from %s" % (self.model_dir, filename))
|
||||||
|
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
|
||||||
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
|
model.eval()
|
||||||
|
for k, v in model.named_parameters():
|
||||||
|
v.requires_grad = False
|
||||||
|
return model
|
||||||
|
|
103
modules/bsrgan_model_arch.py
Normal file
103
modules/bsrgan_model_arch.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import functools
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn.init as init
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_weights(net_l, scale=1):
|
||||||
|
if not isinstance(net_l, list):
|
||||||
|
net_l = [net_l]
|
||||||
|
for net in net_l:
|
||||||
|
for m in net.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale # for residual block
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.Linear):
|
||||||
|
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
m.weight.data *= scale
|
||||||
|
if m.bias is not None:
|
||||||
|
m.bias.data.zero_()
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
init.constant_(m.weight, 1)
|
||||||
|
init.constant_(m.bias.data, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def make_layer(block, n_layers):
|
||||||
|
layers = []
|
||||||
|
for _ in range(n_layers):
|
||||||
|
layers.append(block())
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualDenseBlock_5C(nn.Module):
|
||||||
|
def __init__(self, nf=64, gc=32, bias=True):
|
||||||
|
super(ResidualDenseBlock_5C, self).__init__()
|
||||||
|
# gc: growth channel, i.e. intermediate channels
|
||||||
|
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||||
|
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
# initialization
|
||||||
|
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.lrelu(self.conv1(x))
|
||||||
|
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||||
|
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||||
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
return x5 * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDB(nn.Module):
|
||||||
|
'''Residual in Residual Dense Block'''
|
||||||
|
|
||||||
|
def __init__(self, nf, gc=32):
|
||||||
|
super(RRDB, self).__init__()
|
||||||
|
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.RDB1(x)
|
||||||
|
out = self.RDB2(out)
|
||||||
|
out = self.RDB3(out)
|
||||||
|
return out * 0.2 + x
|
||||||
|
|
||||||
|
|
||||||
|
class RRDBNet(nn.Module):
|
||||||
|
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||||
|
super(RRDBNet, self).__init__()
|
||||||
|
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||||
|
self.sf = sf
|
||||||
|
print([in_nc, out_nc, nf, nb, gc, sf])
|
||||||
|
|
||||||
|
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||||
|
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||||
|
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
#### upsampling
|
||||||
|
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
if self.sf==4:
|
||||||
|
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||||
|
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fea = self.conv_first(x)
|
||||||
|
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||||
|
fea = fea + trunk
|
||||||
|
|
||||||
|
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
if self.sf==4:
|
||||||
|
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||||
|
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||||
|
|
||||||
|
return out
|
@ -5,31 +5,31 @@ import traceback
|
|||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import shared, devices
|
|
||||||
from modules.paths import script_path
|
|
||||||
import modules.shared
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
from importlib import reload
|
import modules.shared
|
||||||
|
from modules import shared, devices, modelloader
|
||||||
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
|
# codeformer people made a choice to include modified basicsr library to their project which makes
|
||||||
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
|
||||||
# I am making a choice to include some files from codeformer to work around this issue.
|
# I am making a choice to include some files from codeformer to work around this issue.
|
||||||
|
model_dir = "Codeformer"
|
||||||
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_path = os.path.join(models_path, model_dir)
|
||||||
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
|
|
||||||
have_codeformer = False
|
have_codeformer = False
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
def setup_codeformer():
|
|
||||||
|
def setup_model(dirname):
|
||||||
|
global model_path
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
os.makedirs(model_path)
|
||||||
|
|
||||||
path = modules.paths.paths.get("CodeFormer", None)
|
path = modules.paths.paths.get("CodeFormer", None)
|
||||||
if path is None:
|
if path is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
|
|
||||||
#stored_sys_path = sys.path
|
|
||||||
#sys.path = [path] + sys.path
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
from modules.codeformer.codeformer_arch import CodeFormer
|
||||||
@ -44,18 +44,23 @@ def setup_codeformer():
|
|||||||
def name(self):
|
def name(self):
|
||||||
return "CodeFormer"
|
return "CodeFormer"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, dirname):
|
||||||
self.net = None
|
self.net = None
|
||||||
self.face_helper = None
|
self.face_helper = None
|
||||||
|
self.cmd_dir = dirname
|
||||||
|
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
|
|
||||||
if self.net is not None and self.face_helper is not None:
|
if self.net is not None and self.face_helper is not None:
|
||||||
self.net.to(devices.device_codeformer)
|
self.net.to(devices.device_codeformer)
|
||||||
return self.net, self.face_helper
|
return self.net, self.face_helper
|
||||||
|
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
|
||||||
|
if len(model_paths) != 0:
|
||||||
|
ckpt_path = model_paths[0]
|
||||||
|
else:
|
||||||
|
print("Unable to load codeformer model.")
|
||||||
|
return None, None
|
||||||
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
|
||||||
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
|
|
||||||
checkpoint = torch.load(ckpt_path)['params_ema']
|
checkpoint = torch.load(ckpt_path)['params_ema']
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
@ -74,6 +79,9 @@ def setup_codeformer():
|
|||||||
original_resolution = np_image.shape[0:2]
|
original_resolution = np_image.shape[0:2]
|
||||||
|
|
||||||
self.create_models()
|
self.create_models()
|
||||||
|
if self.net is None or self.face_helper is None:
|
||||||
|
return np_image
|
||||||
|
|
||||||
self.face_helper.clean_all()
|
self.face_helper.clean_all()
|
||||||
self.face_helper.read_image(np_image)
|
self.face_helper.read_image(np_image)
|
||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||||
@ -114,7 +122,7 @@ def setup_codeformer():
|
|||||||
have_codeformer = True
|
have_codeformer = True
|
||||||
|
|
||||||
global codeformer
|
global codeformer
|
||||||
codeformer = FaceRestorerCodeFormer()
|
codeformer = FaceRestorerCodeFormer(dirname)
|
||||||
shared.face_restorers.append(codeformer)
|
shared.face_restorers.append(codeformer)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -1,80 +1,124 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgam_model_arch as arch
|
import modules.esrgam_model_arch as arch
|
||||||
from modules import shared
|
from modules import shared, modelloader, images
|
||||||
from modules.shared import opts
|
|
||||||
from modules.devices import has_mps
|
from modules.devices import has_mps
|
||||||
import modules.images
|
from modules.paths import models_path
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
def load_model(filename):
|
class UpscalerESRGAN(Upscaler):
|
||||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
def __init__(self, dirname):
|
||||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
self.name = "ESRGAN"
|
||||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||||
|
self.model_name = "ESRGAN 4x"
|
||||||
|
self.scalers = []
|
||||||
|
self.user_path = dirname
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
super().__init__()
|
||||||
|
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
|
scalers = []
|
||||||
|
if len(model_paths) == 0:
|
||||||
|
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||||
|
scalers.append(scaler_data)
|
||||||
|
for file in model_paths:
|
||||||
|
print(f"File: {file}")
|
||||||
|
if "http" in file:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(file)
|
||||||
|
|
||||||
if 'conv_first.weight' in pretrained_net:
|
scaler_data = UpscalerData(name, file, self, 4)
|
||||||
crt_model.load_state_dict(pretrained_net)
|
print(f"ESRGAN: Adding scaler {name}")
|
||||||
|
self.scalers.append(scaler_data)
|
||||||
|
|
||||||
|
def do_upscale(self, img, selected_model):
|
||||||
|
model = self.load_model(selected_model)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
model.to(shared.device)
|
||||||
|
img = esrgan_upscale(model, img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_model(self, path: str):
|
||||||
|
if "http" in path:
|
||||||
|
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||||
|
file_name="%s.pth" % self.model_name,
|
||||||
|
progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if not os.path.exists(filename) or filename is None:
|
||||||
|
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||||
|
return None
|
||||||
|
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||||
|
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||||
|
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||||
|
|
||||||
|
if 'conv_first.weight' in pretrained_net:
|
||||||
|
crt_model.load_state_dict(pretrained_net)
|
||||||
|
return crt_model
|
||||||
|
|
||||||
|
if 'model.0.weight' not in pretrained_net:
|
||||||
|
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
|
||||||
|
"params_ema"]
|
||||||
|
if is_realesrgan:
|
||||||
|
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||||
|
else:
|
||||||
|
raise Exception("The file is not a ESRGAN model.")
|
||||||
|
|
||||||
|
crt_net = crt_model.state_dict()
|
||||||
|
load_net_clean = {}
|
||||||
|
for k, v in pretrained_net.items():
|
||||||
|
if k.startswith('module.'):
|
||||||
|
load_net_clean[k[7:]] = v
|
||||||
|
else:
|
||||||
|
load_net_clean[k] = v
|
||||||
|
pretrained_net = load_net_clean
|
||||||
|
|
||||||
|
tbd = []
|
||||||
|
for k, v in crt_net.items():
|
||||||
|
tbd.append(k)
|
||||||
|
|
||||||
|
# directly copy
|
||||||
|
for k, v in crt_net.items():
|
||||||
|
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
||||||
|
crt_net[k] = pretrained_net[k]
|
||||||
|
tbd.remove(k)
|
||||||
|
|
||||||
|
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
||||||
|
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
||||||
|
|
||||||
|
for k in tbd.copy():
|
||||||
|
if 'RDB' in k:
|
||||||
|
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||||
|
if '.weight' in k:
|
||||||
|
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||||
|
elif '.bias' in k:
|
||||||
|
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||||
|
crt_net[k] = pretrained_net[ori_k]
|
||||||
|
tbd.remove(k)
|
||||||
|
|
||||||
|
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
||||||
|
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
||||||
|
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
||||||
|
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
||||||
|
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
||||||
|
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
||||||
|
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
||||||
|
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
||||||
|
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||||
|
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||||
|
|
||||||
|
crt_model.load_state_dict(crt_net)
|
||||||
|
crt_model.eval()
|
||||||
return crt_model
|
return crt_model
|
||||||
|
|
||||||
if 'model.0.weight' not in pretrained_net:
|
|
||||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
|
||||||
if is_realesrgan:
|
|
||||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
|
||||||
else:
|
|
||||||
raise Exception("The file is not a ESRGAN model.")
|
|
||||||
|
|
||||||
crt_net = crt_model.state_dict()
|
|
||||||
load_net_clean = {}
|
|
||||||
for k, v in pretrained_net.items():
|
|
||||||
if k.startswith('module.'):
|
|
||||||
load_net_clean[k[7:]] = v
|
|
||||||
else:
|
|
||||||
load_net_clean[k] = v
|
|
||||||
pretrained_net = load_net_clean
|
|
||||||
|
|
||||||
tbd = []
|
|
||||||
for k, v in crt_net.items():
|
|
||||||
tbd.append(k)
|
|
||||||
|
|
||||||
# directly copy
|
|
||||||
for k, v in crt_net.items():
|
|
||||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
|
||||||
crt_net[k] = pretrained_net[k]
|
|
||||||
tbd.remove(k)
|
|
||||||
|
|
||||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
|
||||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
|
||||||
|
|
||||||
for k in tbd.copy():
|
|
||||||
if 'RDB' in k:
|
|
||||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
|
||||||
if '.weight' in k:
|
|
||||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
|
||||||
elif '.bias' in k:
|
|
||||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
|
||||||
crt_net[k] = pretrained_net[ori_k]
|
|
||||||
tbd.remove(k)
|
|
||||||
|
|
||||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
|
||||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
|
||||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
|
||||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
|
||||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
|
||||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
|
||||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
|
||||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
|
||||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
|
||||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
|
||||||
|
|
||||||
crt_model.load_state_dict(crt_net)
|
|
||||||
crt_model.eval()
|
|
||||||
return crt_model
|
|
||||||
|
|
||||||
def upscale_without_tiling(model, img):
|
def upscale_without_tiling(model, img):
|
||||||
img = np.array(img)
|
img = np.array(img)
|
||||||
@ -95,7 +139,7 @@ def esrgan_upscale(model, img):
|
|||||||
if opts.ESRGAN_tile == 0:
|
if opts.ESRGAN_tile == 0:
|
||||||
return upscale_without_tiling(model, img)
|
return upscale_without_tiling(model, img)
|
||||||
|
|
||||||
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||||
newtiles = []
|
newtiles = []
|
||||||
scale_factor = 1
|
scale_factor = 1
|
||||||
|
|
||||||
@ -110,32 +154,7 @@ def esrgan_upscale(model, img):
|
|||||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||||
|
|
||||||
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
|
||||||
output = modules.images.combine_grid(newgrid)
|
grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||||
|
output = images.combine_grid(newgrid)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class UpscalerESRGAN(modules.images.Upscaler):
|
|
||||||
def __init__(self, filename, title):
|
|
||||||
self.name = title
|
|
||||||
self.model = load_model(filename)
|
|
||||||
|
|
||||||
def do_upscale(self, img):
|
|
||||||
model = self.model.to(shared.device)
|
|
||||||
img = esrgan_upscale(model, img)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(dirname):
|
|
||||||
for file in os.listdir(dirname):
|
|
||||||
path = os.path.join(dirname, file)
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
|
|
||||||
if extension != '.pt' and extension != '.pth':
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for image, image_name in zip(imageArr, imageNameArr):
|
for image, image_name in zip(imageArr, imageNameArr):
|
||||||
|
if image is None:
|
||||||
|
return outputs, "Please select an input image.", ''
|
||||||
existing_pnginfo = image.info or {}
|
existing_pnginfo = image.info or {}
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
@ -65,29 +67,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
|||||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||||
image = res
|
image = res
|
||||||
|
|
||||||
if upscaling_resize != 1.0:
|
def upscale(image, scaler_index, resize):
|
||||||
def upscale(image, scaler_index, resize):
|
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
pixels = tuple(np.array(small).flatten().tolist())
|
||||||
pixels = tuple(np.array(small).flatten().tolist())
|
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
||||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
|
||||||
|
|
||||||
c = cached_images.get(key)
|
c = cached_images.get(key)
|
||||||
if c is None:
|
if c is None:
|
||||||
upscaler = shared.sd_upscalers[scaler_index]
|
upscaler = shared.sd_upscalers[scaler_index]
|
||||||
c = upscaler.upscale(image, image.width * resize, image.height * resize)
|
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||||
cached_images[key] = c
|
cached_images[key] = c
|
||||||
|
|
||||||
return c
|
return c
|
||||||
|
|
||||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||||
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
||||||
|
|
||||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
||||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||||
|
|
||||||
image = res
|
image = res
|
||||||
|
|
||||||
while len(cached_images) > 2:
|
while len(cached_images) > 2:
|
||||||
del cached_images[next(iter(cached_images.keys()))]
|
del cached_images[next(iter(cached_images.keys()))]
|
||||||
|
@ -1,39 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
from modules import shared, devices
|
import facexlib
|
||||||
from modules.shared import cmd_opts
|
import gfpgan
|
||||||
from modules.paths import script_path
|
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
|
from modules import shared, devices, modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
model_dir = "GFPGAN"
|
||||||
def gfpgan_model_path():
|
user_path = None
|
||||||
from modules.shared import cmd_opts
|
model_path = os.path.join(models_path, model_dir)
|
||||||
|
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||||
filemask = 'GFPGAN*.pth'
|
have_gfpgan = False
|
||||||
|
|
||||||
if cmd_opts.gfpgan_model is not None:
|
|
||||||
return cmd_opts.gfpgan_model
|
|
||||||
|
|
||||||
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
|
|
||||||
|
|
||||||
filename = None
|
|
||||||
for place in places:
|
|
||||||
filename = next(iter(glob(os.path.join(place, filemask))), None)
|
|
||||||
if filename is not None:
|
|
||||||
break
|
|
||||||
|
|
||||||
return filename
|
|
||||||
|
|
||||||
|
|
||||||
loaded_gfpgan_model = None
|
loaded_gfpgan_model = None
|
||||||
|
|
||||||
|
|
||||||
def gfpgan():
|
def gfpgann():
|
||||||
global loaded_gfpgan_model
|
global loaded_gfpgan_model
|
||||||
|
global model_path
|
||||||
if loaded_gfpgan_model is not None:
|
if loaded_gfpgan_model is not None:
|
||||||
loaded_gfpgan_model.gfpgan.to(shared.device)
|
loaded_gfpgan_model.gfpgan.to(shared.device)
|
||||||
return loaded_gfpgan_model
|
return loaded_gfpgan_model
|
||||||
@ -41,7 +27,17 @@ def gfpgan():
|
|||||||
if gfpgan_constructor is None:
|
if gfpgan_constructor is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||||
|
if len(models) == 1 and "http" in models[0]:
|
||||||
|
model_file = models[0]
|
||||||
|
elif len(models) != 0:
|
||||||
|
latest_file = max(models, key=os.path.getctime)
|
||||||
|
model_file = latest_file
|
||||||
|
else:
|
||||||
|
print("Unable to load gfpgan model!")
|
||||||
|
return None
|
||||||
|
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
|
||||||
|
bg_upsampler=None)
|
||||||
model.gfpgan.to(shared.device)
|
model.gfpgan.to(shared.device)
|
||||||
loaded_gfpgan_model = model
|
loaded_gfpgan_model = model
|
||||||
|
|
||||||
@ -49,10 +45,12 @@ def gfpgan():
|
|||||||
|
|
||||||
|
|
||||||
def gfpgan_fix_faces(np_image):
|
def gfpgan_fix_faces(np_image):
|
||||||
model = gfpgan()
|
model = gfpgann()
|
||||||
|
if model is None:
|
||||||
|
return np_image
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
np_image_bgr = np_image[:, :, ::-1]
|
||||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
|
||||||
|
only_center_face=False, paste_back=True)
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||||
|
|
||||||
if shared.opts.face_restoration_unload:
|
if shared.opts.face_restoration_unload:
|
||||||
@ -61,21 +59,41 @@ def gfpgan_fix_faces(np_image):
|
|||||||
return np_image
|
return np_image
|
||||||
|
|
||||||
|
|
||||||
have_gfpgan = False
|
|
||||||
gfpgan_constructor = None
|
gfpgan_constructor = None
|
||||||
|
|
||||||
def setup_gfpgan():
|
|
||||||
|
def setup_model(dirname):
|
||||||
|
global model_path
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
os.makedirs(model_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gfpgan_model_path()
|
|
||||||
|
|
||||||
if os.path.exists(cmd_opts.gfpgan_dir):
|
|
||||||
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
|
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
|
from facexlib import detection, parsing
|
||||||
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
have_gfpgan = True
|
|
||||||
|
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
|
||||||
|
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||||
|
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||||
|
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||||
|
|
||||||
|
def my_load_file_from_url(**kwargs):
|
||||||
|
print("Setting model_dir to " + model_path)
|
||||||
|
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
||||||
|
|
||||||
|
def facex_load_file_from_url(**kwargs):
|
||||||
|
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||||
|
|
||||||
|
def facex_load_file_from_url2(**kwargs):
|
||||||
|
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||||
|
|
||||||
|
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||||
|
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||||
|
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||||
|
user_path = dirname
|
||||||
|
print("Have gfpgan should be true?")
|
||||||
|
have_gfpgan = True
|
||||||
gfpgan_constructor = GFPGANer
|
gfpgan_constructor = GFPGANer
|
||||||
|
|
||||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||||
@ -84,7 +102,9 @@ def setup_gfpgan():
|
|||||||
|
|
||||||
def restore(self, np_image):
|
def restore(self, np_image):
|
||||||
np_image_bgr = np_image[:, :, ::-1]
|
np_image_bgr = np_image[:, :, ::-1]
|
||||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
|
||||||
|
only_center_face=False,
|
||||||
|
paste_back=True)
|
||||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||||
|
|
||||||
return np_image
|
return np_image
|
||||||
|
@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
|||||||
from fonts.ttf import Roboto
|
from fonts.ttf import Roboto
|
||||||
import string
|
import string
|
||||||
|
|
||||||
import modules.shared
|
|
||||||
from modules import sd_samplers, shared
|
from modules import sd_samplers, shared
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.shared import opts, cmd_opts
|
||||||
|
|
||||||
@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||||||
cols = math.ceil((w - overlap) / non_overlap_width)
|
cols = math.ceil((w - overlap) / non_overlap_width)
|
||||||
rows = math.ceil((h - overlap) / non_overlap_height)
|
rows = math.ceil((h - overlap) / non_overlap_height)
|
||||||
|
|
||||||
dx = (w - tile_w) / (cols-1) if cols > 1 else 0
|
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
||||||
dy = (h - tile_h) / (rows-1) if rows > 1 else 0
|
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
||||||
|
|
||||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||||
for row in range(rows):
|
for row in range(rows):
|
||||||
@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
|||||||
for col in range(cols):
|
for col in range(cols):
|
||||||
x = int(col * dx)
|
x = int(col * dx)
|
||||||
|
|
||||||
if x+tile_w >= w:
|
if x + tile_w >= w:
|
||||||
x = w - tile_w
|
x = w - tile_w
|
||||||
|
|
||||||
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
||||||
@ -85,8 +84,10 @@ def combine_grid(grid):
|
|||||||
r = r.astype(np.uint8)
|
r = r.astype(np.uint8)
|
||||||
return Image.fromarray(r, 'L')
|
return Image.fromarray(r, 'L')
|
||||||
|
|
||||||
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
mask_w = make_mask_image(
|
||||||
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||||
|
mask_h = make_mask_image(
|
||||||
|
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||||
|
|
||||||
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
||||||
for y, h, row in grid.tiles:
|
for y, h, row in grid.tiles:
|
||||||
@ -129,10 +130,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||||||
|
|
||||||
def draw_texts(drawing, draw_x, draw_y, lines):
|
def draw_texts(drawing, draw_x, draw_y, lines):
|
||||||
for i, line in enumerate(lines):
|
for i, line in enumerate(lines):
|
||||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
|
||||||
|
fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||||
|
|
||||||
if not line.is_active:
|
if not line.is_active:
|
||||||
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
|
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
|
||||||
|
draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||||
|
|
||||||
draw_y += line.size[1] + line_spacing
|
draw_y += line.size[1] + line_spacing
|
||||||
|
|
||||||
@ -171,7 +174,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||||||
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||||
|
|
||||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||||
|
ver_texts]
|
||||||
|
|
||||||
pad_top = max(hor_text_heights) + line_spacing * 2
|
pad_top = max(hor_text_heights) + line_spacing * 2
|
||||||
|
|
||||||
@ -202,8 +206,10 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
|||||||
prompts_horiz = prompts[:boundary]
|
prompts_horiz = prompts[:boundary]
|
||||||
prompts_vert = prompts[boundary:]
|
prompts_vert = prompts[boundary:]
|
||||||
|
|
||||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
|
||||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
range(1 << len(prompts_horiz))]
|
||||||
|
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
|
||||||
|
range(1 << len(prompts_vert))]
|
||||||
|
|
||||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||||
|
|
||||||
@ -214,7 +220,8 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
return im.resize((w, h), resample=LANCZOS)
|
return im.resize((w, h), resample=LANCZOS)
|
||||||
|
|
||||||
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
|
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
|
||||||
return upscaler.upscale(im, w, h)
|
scale = w / im.width
|
||||||
|
return upscaler.scaler.upscale(im, scale)
|
||||||
|
|
||||||
if resize_mode == 0:
|
if resize_mode == 0:
|
||||||
res = resize(im, width, height)
|
res = resize(im, width, height)
|
||||||
@ -244,11 +251,13 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
if ratio < src_ratio:
|
if ratio < src_ratio:
|
||||||
fill_height = height // 2 - src_h // 2
|
fill_height = height // 2 - src_h // 2
|
||||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
||||||
|
box=(0, fill_height + src_h))
|
||||||
elif ratio > src_ratio:
|
elif ratio > src_ratio:
|
||||||
fill_width = width // 2 - src_w // 2
|
fill_width = width // 2 - src_w // 2
|
||||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
||||||
|
box=(fill_width + src_w, 0))
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -256,7 +265,7 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||||
invalid_filename_prefix = ' '
|
invalid_filename_prefix = ' '
|
||||||
invalid_filename_postfix = ' .'
|
invalid_filename_postfix = ' .'
|
||||||
re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
|
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||||
max_filename_part_length = 128
|
max_filename_part_length = 128
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +292,8 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||||||
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
||||||
if len(words) == 0:
|
if len(words) == 0:
|
||||||
words = ["empty"]
|
words = ["empty"]
|
||||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
x = x.replace("[prompt_words]",
|
||||||
|
sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||||
|
|
||||||
if p is not None:
|
if p is not None:
|
||||||
x = x.replace("[steps]", str(p.steps))
|
x = x.replace("[steps]", str(p.steps))
|
||||||
@ -291,7 +301,8 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||||||
x = x.replace("[width]", str(p.width))
|
x = x.replace("[width]", str(p.width))
|
||||||
x = x.replace("[height]", str(p.height))
|
x = x.replace("[height]", str(p.height))
|
||||||
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
|
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
|
||||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
x = x.replace("[sampler]",
|
||||||
|
sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||||
|
|
||||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||||
@ -303,6 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def get_next_sequence_number(path, basename):
|
def get_next_sequence_number(path, basename):
|
||||||
"""
|
"""
|
||||||
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
||||||
@ -316,7 +328,8 @@ def get_next_sequence_number(path, basename):
|
|||||||
prefix_length = len(basename)
|
prefix_length = len(basename)
|
||||||
for p in os.listdir(path):
|
for p in os.listdir(path):
|
||||||
if p.startswith(basename):
|
if p.startswith(basename):
|
||||||
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
l = os.path.splitext(p[prefix_length:])[0].split(
|
||||||
|
'-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||||
try:
|
try:
|
||||||
result = max(int(l[0]), result)
|
result = max(int(l[0]), result)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -324,7 +337,10 @@ def get_next_sequence_number(path, basename):
|
|||||||
|
|
||||||
return result + 1
|
return result + 1
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
|
|
||||||
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
|
||||||
|
no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
|
||||||
|
forced_filename=None, suffix=""):
|
||||||
if short_filename or prompt is None or seed is None:
|
if short_filename or prompt is None or seed is None:
|
||||||
file_decoration = ""
|
file_decoration = ""
|
||||||
elif opts.save_to_dirs:
|
elif opts.save_to_dirs:
|
||||||
@ -361,7 +377,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
fullfn = "a.png"
|
fullfn = "a.png"
|
||||||
fullfn_without_extension = "a"
|
fullfn_without_extension = "a"
|
||||||
for i in range(500):
|
for i in range(500):
|
||||||
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
|
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||||
if not os.path.exists(fullfn):
|
if not os.path.exists(fullfn):
|
||||||
@ -403,31 +419,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
file.write(info + "\n")
|
file.write(info + "\n")
|
||||||
|
|
||||||
|
|
||||||
class Upscaler:
|
|
||||||
name = "Lanczos"
|
|
||||||
|
|
||||||
def do_upscale(self, img):
|
|
||||||
return img
|
|
||||||
|
|
||||||
def upscale(self, img, w, h):
|
|
||||||
for i in range(3):
|
|
||||||
if img.width >= w and img.height >= h:
|
|
||||||
break
|
|
||||||
|
|
||||||
img = self.do_upscale(img)
|
|
||||||
|
|
||||||
if img.width != w or img.height != h:
|
|
||||||
img = img.resize((int(w), int(h)), resample=LANCZOS)
|
|
||||||
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerNone(Upscaler):
|
|
||||||
name = "None"
|
|
||||||
|
|
||||||
def upscale(self, img, w, h):
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
modules.shared.sd_upscalers.append(UpscalerNone())
|
|
||||||
modules.shared.sd_upscalers.append(Upscaler())
|
|
||||||
|
@ -1,67 +1,45 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.images
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.ldsr_model_arch import LDSR
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths import script_path
|
from modules.paths import models_path
|
||||||
|
|
||||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
|
||||||
|
|
||||||
ldsr_models = []
|
|
||||||
have_ldsr = False
|
|
||||||
LDSR_obj = None
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(modules.images.Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
def __init__(self, steps):
|
def __init__(self, user_path):
|
||||||
self.steps = steps
|
|
||||||
self.name = "LDSR"
|
self.name = "LDSR"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.user_path = user_path
|
||||||
|
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||||
|
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||||
|
super().__init__()
|
||||||
|
scaler_data = UpscalerData("LDSR", None, self)
|
||||||
|
self.scalers = [scaler_data]
|
||||||
|
|
||||||
def do_upscale(self, img):
|
def load_model(self, path: str):
|
||||||
return upscale_with_ldsr(img)
|
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||||
|
file_name="model.pth", progress=True)
|
||||||
|
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||||
|
file_name="project.yaml", progress=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return LDSR(model, yaml)
|
||||||
|
|
||||||
def add_lsdr():
|
except Exception:
|
||||||
modules.shared.sd_upscalers.append(UpscalerLDSR(100))
|
print("Error importing LDSR:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def do_upscale(self, img, path):
|
||||||
def setup_ldsr():
|
ldsr = self.load_model(path)
|
||||||
path = modules.paths.paths.get("LDSR", None)
|
if ldsr is None:
|
||||||
if path is None:
|
print("NO LDSR!")
|
||||||
return
|
return img
|
||||||
global have_ldsr
|
ddim_steps = shared.opts.ldsr_steps
|
||||||
global LDSR_obj
|
pre_scale = shared.opts.ldsr_pre_down
|
||||||
try:
|
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||||
from LDSR import LDSR
|
|
||||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
|
||||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
|
||||||
repo_path = 'latent-diffusion/experiments/pretrained_models/'
|
|
||||||
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
|
|
||||||
progress=True, file_name="model.chkpt")
|
|
||||||
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
|
|
||||||
progress=True, file_name="project.yaml")
|
|
||||||
have_ldsr = True
|
|
||||||
LDSR_obj = LDSR(model_path, yaml_path)
|
|
||||||
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
print("Error importing LDSR:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
have_ldsr = False
|
|
||||||
|
|
||||||
|
|
||||||
def upscale_with_ldsr(image):
|
|
||||||
setup_ldsr()
|
|
||||||
if not have_ldsr or LDSR_obj is None:
|
|
||||||
return image
|
|
||||||
|
|
||||||
ddim_steps = shared.opts.ldsr_steps
|
|
||||||
pre_scale = shared.opts.ldsr_pre_down
|
|
||||||
post_scale = shared.opts.ldsr_post_down
|
|
||||||
|
|
||||||
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
|
|
||||||
return image
|
|
||||||
|
225
modules/ldsr_model_arch.py
Normal file
225
modules/ldsr_model_arch.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import gc
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from PIL import Image
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.util import instantiate_from_config, ismap
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
|
||||||
|
# Create LDSR Class
|
||||||
|
class LDSR:
|
||||||
|
def load_model_from_config(self, half_attention):
|
||||||
|
print(f"Loading model from {self.modelPath}")
|
||||||
|
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
config = OmegaConf.load(self.yamlPath)
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
model.cuda()
|
||||||
|
if half_attention:
|
||||||
|
model = model.half()
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
return {"model": model}
|
||||||
|
|
||||||
|
def __init__(self, model_path, yaml_path):
|
||||||
|
self.modelPath = model_path
|
||||||
|
self.yamlPath = yaml_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(model, selected_path, custom_steps, eta):
|
||||||
|
example = get_cond(selected_path)
|
||||||
|
|
||||||
|
n_runs = 1
|
||||||
|
guider = None
|
||||||
|
ckwargs = None
|
||||||
|
ddim_use_x0_pred = False
|
||||||
|
temperature = 1.
|
||||||
|
eta = eta
|
||||||
|
custom_shape = None
|
||||||
|
|
||||||
|
height, width = example["image"].shape[1:3]
|
||||||
|
split_input = height >= 128 and width >= 128
|
||||||
|
|
||||||
|
if split_input:
|
||||||
|
ks = 128
|
||||||
|
stride = 64
|
||||||
|
vqf = 4 #
|
||||||
|
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
||||||
|
"vqf": vqf,
|
||||||
|
"patch_distributed_vq": True,
|
||||||
|
"tie_braker": False,
|
||||||
|
"clip_max_weight": 0.5,
|
||||||
|
"clip_min_weight": 0.01,
|
||||||
|
"clip_max_tie_weight": 0.5,
|
||||||
|
"clip_min_tie_weight": 0.01}
|
||||||
|
else:
|
||||||
|
if hasattr(model, "split_input_params"):
|
||||||
|
delattr(model, "split_input_params")
|
||||||
|
|
||||||
|
x_t = None
|
||||||
|
logs = None
|
||||||
|
for n in range(n_runs):
|
||||||
|
if custom_shape is not None:
|
||||||
|
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||||
|
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||||
|
|
||||||
|
logs = make_convolutional_sample(example, model,
|
||||||
|
custom_steps=custom_steps,
|
||||||
|
eta=eta, quantize_x0=False,
|
||||||
|
custom_shape=custom_shape,
|
||||||
|
temperature=temperature, noise_dropout=0.,
|
||||||
|
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
||||||
|
ddim_use_x0_pred=ddim_use_x0_pred
|
||||||
|
)
|
||||||
|
return logs
|
||||||
|
|
||||||
|
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
||||||
|
model = self.load_model_from_config(half_attention)
|
||||||
|
|
||||||
|
# Run settings
|
||||||
|
diffusion_steps = int(steps)
|
||||||
|
eta = 1.0
|
||||||
|
|
||||||
|
down_sample_method = 'Lanczos'
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
im_og = image
|
||||||
|
width_og, height_og = im_og.size
|
||||||
|
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||||
|
print("Foo")
|
||||||
|
down_sample_rate = target_scale / 4
|
||||||
|
print(f"Downsample rate is {down_sample_rate}")
|
||||||
|
wd = width_og * down_sample_rate
|
||||||
|
hd = height_og * down_sample_rate
|
||||||
|
width_downsampled_pre = int(wd)
|
||||||
|
height_downsampled_pre = int(hd)
|
||||||
|
|
||||||
|
if down_sample_rate != 1:
|
||||||
|
print(
|
||||||
|
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||||
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
|
else:
|
||||||
|
print(f"Down sample rate is 1 from {target_scale} / 4")
|
||||||
|
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||||
|
|
||||||
|
sample = logs["sample"]
|
||||||
|
sample = sample.detach().cpu()
|
||||||
|
sample = torch.clamp(sample, -1., 1.)
|
||||||
|
sample = (sample + 1.) / 2. * 255
|
||||||
|
sample = sample.numpy().astype(np.uint8)
|
||||||
|
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||||
|
a = Image.fromarray(sample[0])
|
||||||
|
|
||||||
|
del model
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print(f'Processing finished!')
|
||||||
|
return a
|
||||||
|
|
||||||
|
|
||||||
|
def get_cond(selected_path):
|
||||||
|
example = dict()
|
||||||
|
up_f = 4
|
||||||
|
c = selected_path.convert('RGB')
|
||||||
|
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||||
|
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
||||||
|
antialias=True)
|
||||||
|
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
||||||
|
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||||
|
c = 2. * c - 1.
|
||||||
|
|
||||||
|
c = c.to(torch.device("cuda"))
|
||||||
|
example["LR_image"] = c
|
||||||
|
example["image"] = c_up
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
||||||
|
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
||||||
|
corrector_kwargs=None, x_t=None
|
||||||
|
):
|
||||||
|
ddim = DDIMSampler(model)
|
||||||
|
bs = shape[0]
|
||||||
|
shape = shape[1:]
|
||||||
|
print(f"Sampling with eta = {eta}; steps: {steps}")
|
||||||
|
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
||||||
|
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
||||||
|
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
||||||
|
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||||
|
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||||
|
log = dict()
|
||||||
|
|
||||||
|
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||||
|
return_first_stage_outputs=True,
|
||||||
|
force_c_encode=not (hasattr(model, 'split_input_params')
|
||||||
|
and model.cond_stage_key == 'coordinates_bbox'),
|
||||||
|
return_original_cond=True)
|
||||||
|
|
||||||
|
if custom_shape is not None:
|
||||||
|
z = torch.randn(custom_shape)
|
||||||
|
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
||||||
|
|
||||||
|
z0 = None
|
||||||
|
|
||||||
|
log["input"] = x
|
||||||
|
log["reconstruction"] = xrec
|
||||||
|
|
||||||
|
if ismap(xc):
|
||||||
|
log["original_conditioning"] = model.to_rgb(xc)
|
||||||
|
if hasattr(model, 'cond_stage_key'):
|
||||||
|
log[model.cond_stage_key] = model.to_rgb(xc)
|
||||||
|
|
||||||
|
else:
|
||||||
|
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
||||||
|
if model.cond_stage_model:
|
||||||
|
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
||||||
|
if model.cond_stage_key == 'class_label':
|
||||||
|
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
||||||
|
|
||||||
|
with model.ema_scope("Plotting"):
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
||||||
|
eta=eta,
|
||||||
|
quantize_x0=quantize_x0, mask=None, x0=z0,
|
||||||
|
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
||||||
|
x_t=x_T)
|
||||||
|
t1 = time.time()
|
||||||
|
|
||||||
|
if ddim_use_x0_pred:
|
||||||
|
sample = intermediates['pred_x0'][-1]
|
||||||
|
|
||||||
|
x_sample = model.decode_first_stage(sample)
|
||||||
|
|
||||||
|
try:
|
||||||
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
|
log["sample_noquant"] = x_sample_noquant
|
||||||
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
log["sample"] = x_sample
|
||||||
|
log["time"] = t1 - t0
|
||||||
|
|
||||||
|
return log
|
133
modules/modelloader.py
Normal file
133
modules/modelloader.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import importlib
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.upscaler import Upscaler
|
||||||
|
from modules.paths import script_path, models_path
|
||||||
|
|
||||||
|
|
||||||
|
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
||||||
|
"""
|
||||||
|
A one-and done loader to try finding the desired models in specified directories.
|
||||||
|
|
||||||
|
@param download_name: Specify to download from model_url immediately.
|
||||||
|
@param model_url: If no other models are found, this will be downloaded on upscale.
|
||||||
|
@param model_path: The location to store/find models in.
|
||||||
|
@param command_path: A command-line argument to search for models in first.
|
||||||
|
@param ext_filter: An optional list of filename extensions to filter by
|
||||||
|
@return: A list of paths containing the desired model(s)
|
||||||
|
"""
|
||||||
|
output = []
|
||||||
|
|
||||||
|
if ext_filter is None:
|
||||||
|
ext_filter = []
|
||||||
|
try:
|
||||||
|
places = []
|
||||||
|
if command_path is not None and command_path != model_path:
|
||||||
|
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||||
|
if os.path.exists(pretrained_path):
|
||||||
|
print(f"Appending path: {pretrained_path}")
|
||||||
|
places.append(pretrained_path)
|
||||||
|
elif os.path.exists(command_path):
|
||||||
|
places.append(command_path)
|
||||||
|
places.append(model_path)
|
||||||
|
for place in places:
|
||||||
|
if os.path.exists(place):
|
||||||
|
for file in os.listdir(place):
|
||||||
|
full_path = os.path.join(place, file)
|
||||||
|
if os.path.isdir(full_path):
|
||||||
|
continue
|
||||||
|
if len(ext_filter) != 0:
|
||||||
|
model_name, extension = os.path.splitext(file)
|
||||||
|
if extension not in ext_filter:
|
||||||
|
continue
|
||||||
|
if file not in output:
|
||||||
|
output.append(full_path)
|
||||||
|
if model_url is not None and len(output) == 0:
|
||||||
|
if download_name is not None:
|
||||||
|
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||||
|
output.append(dl)
|
||||||
|
else:
|
||||||
|
output.append(model_url)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def friendly_name(file: str):
|
||||||
|
if "http" in file:
|
||||||
|
file = urlparse(file).path
|
||||||
|
|
||||||
|
file = os.path.basename(file)
|
||||||
|
model_name, extension = os.path.splitext(file)
|
||||||
|
model_name = model_name.replace("_", " ").title()
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_models():
|
||||||
|
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
|
||||||
|
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
|
||||||
|
# somehow auto-register and just do these things...
|
||||||
|
root_path = script_path
|
||||||
|
src_path = models_path
|
||||||
|
dest_path = os.path.join(models_path, "Stable-diffusion")
|
||||||
|
move_files(src_path, dest_path, ".ckpt")
|
||||||
|
src_path = os.path.join(root_path, "ESRGAN")
|
||||||
|
dest_path = os.path.join(models_path, "ESRGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "gfpgan")
|
||||||
|
dest_path = os.path.join(models_path, "GFPGAN")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "SwinIR")
|
||||||
|
dest_path = os.path.join(models_path, "SwinIR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
|
||||||
|
dest_path = os.path.join(models_path, "LDSR")
|
||||||
|
move_files(src_path, dest_path)
|
||||||
|
|
||||||
|
|
||||||
|
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||||
|
try:
|
||||||
|
if not os.path.exists(dest_path):
|
||||||
|
os.makedirs(dest_path)
|
||||||
|
if os.path.exists(src_path):
|
||||||
|
for file in os.listdir(src_path):
|
||||||
|
fullpath = os.path.join(src_path, file)
|
||||||
|
if os.path.isfile(fullpath):
|
||||||
|
if ext_filter is not None:
|
||||||
|
if ext_filter not in file:
|
||||||
|
continue
|
||||||
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
|
try:
|
||||||
|
shutil.move(fullpath, dest_path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
if len(os.listdir(src_path)) == 0:
|
||||||
|
print(f"Removing empty folder: {src_path}")
|
||||||
|
shutil.rmtree(src_path, True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def load_upscalers():
|
||||||
|
datas = []
|
||||||
|
for cls in Upscaler.__subclasses__():
|
||||||
|
name = cls.__name__
|
||||||
|
module_name = cls.__module__
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
class_ = getattr(module, name)
|
||||||
|
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
||||||
|
opt_string = None
|
||||||
|
try:
|
||||||
|
opt_string = shared.opts.__getattr__(cmd_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
scaler = class_(opt_string)
|
||||||
|
for child in scaler.scalers:
|
||||||
|
datas.append(child)
|
||||||
|
|
||||||
|
shared.sd_upscalers = datas
|
@ -3,9 +3,10 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
models_path = os.path.join(script_path, "models")
|
||||||
sys.path.insert(0, script_path)
|
sys.path.insert(0, script_path)
|
||||||
|
|
||||||
# search for directory of stable diffsuion in following palces
|
# search for directory of stable diffusion in following places
|
||||||
sd_path = None
|
sd_path = None
|
||||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||||
for possible_sd_path in possible_sd_paths:
|
for possible_sd_path in possible_sd_paths:
|
||||||
|
@ -1,119 +1,139 @@
|
|||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
import modules.images
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
from modules.paths import models_path
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
|
|
||||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
|
||||||
realesrgan_models = []
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
have_realesrgan = False
|
def __init__(self, path):
|
||||||
|
self.name = "RealESRGAN"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.user_path = path
|
||||||
|
super().__init__()
|
||||||
|
try:
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
self.enable = True
|
||||||
|
self.scalers = []
|
||||||
|
scalers = self.load_models(path)
|
||||||
|
for scaler in scalers:
|
||||||
|
if scaler.name in opts.realesrgan_enabled_models:
|
||||||
|
self.scalers.append(scaler)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
self.enable = False
|
||||||
|
self.scalers = []
|
||||||
|
|
||||||
|
def do_upscale(self, img, path):
|
||||||
|
if not self.enable:
|
||||||
|
return img
|
||||||
|
|
||||||
|
info = self.load_model(path)
|
||||||
|
if not os.path.exists(info.data_path):
|
||||||
|
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||||
|
return img
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=info.scale,
|
||||||
|
model_path=info.data_path,
|
||||||
|
model=info.model(),
|
||||||
|
half=not cmd_opts.no_half,
|
||||||
|
tile=opts.ESRGAN_tile,
|
||||||
|
tile_pad=opts.ESRGAN_tile_overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||||
|
|
||||||
|
image = Image.fromarray(upsampled)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def load_model(self, path):
|
||||||
|
try:
|
||||||
|
info = None
|
||||||
|
for scaler in self.scalers:
|
||||||
|
if scaler.data_path == path:
|
||||||
|
info = scaler
|
||||||
|
|
||||||
|
if info is None:
|
||||||
|
print(f"Unable to find model info: {path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||||
|
info.data_path = model_file
|
||||||
|
return info
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_models(self, _):
|
||||||
|
return get_realesrgan_models(self)
|
||||||
|
|
||||||
|
|
||||||
def get_realesrgan_models():
|
def get_realesrgan_models(scaler):
|
||||||
try:
|
try:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
models = [
|
models = [
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN General x4x3",
|
name="R-ESRGAN General 4xV3",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
|
||||||
netscale=4,
|
".pth",
|
||||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
scale=4,
|
||||||
|
upscaler=scaler,
|
||||||
|
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||||
|
act_type='prelu')
|
||||||
),
|
),
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN General WDN x4x3",
|
name="R-ESRGAN General WDN 4xV3",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||||
netscale=4,
|
scale=4,
|
||||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
upscaler=scaler,
|
||||||
|
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||||
|
act_type='prelu')
|
||||||
),
|
),
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN AnimeVideo",
|
name="R-ESRGAN AnimeVideo",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||||
netscale=4,
|
scale=4,
|
||||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
upscaler=scaler,
|
||||||
|
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
|
||||||
|
act_type='prelu')
|
||||||
),
|
),
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN 4x plus",
|
name="R-ESRGAN 4x+",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
netscale=4,
|
scale=4,
|
||||||
|
upscaler=scaler,
|
||||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||||
),
|
),
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN 4x plus anime 6B",
|
name="R-ESRGAN 4x+ Anime6B",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
netscale=4,
|
scale=4,
|
||||||
|
upscaler=scaler,
|
||||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||||
),
|
),
|
||||||
RealesrganModelInfo(
|
UpscalerData(
|
||||||
name="Real-ESRGAN 2x plus",
|
name="R-ESRGAN 2x+",
|
||||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
netscale=2,
|
scale=2,
|
||||||
|
upscaler=scaler,
|
||||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
|
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
class UpscalerRealESRGAN(modules.images.Upscaler):
|
|
||||||
def __init__(self, upscaling, model_index):
|
|
||||||
self.upscaling = upscaling
|
|
||||||
self.model_index = model_index
|
|
||||||
self.name = realesrgan_models[model_index].name
|
|
||||||
|
|
||||||
def do_upscale(self, img):
|
|
||||||
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
|
|
||||||
|
|
||||||
|
|
||||||
def setup_realesrgan():
|
|
||||||
global realesrgan_models
|
|
||||||
global have_realesrgan
|
|
||||||
|
|
||||||
try:
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
|
||||||
|
|
||||||
realesrgan_models = get_realesrgan_models()
|
|
||||||
have_realesrgan = True
|
|
||||||
|
|
||||||
for i, model in enumerate(realesrgan_models):
|
|
||||||
if model.name in opts.realesrgan_enabled_models:
|
|
||||||
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
|
|
||||||
have_realesrgan = False
|
|
||||||
|
|
||||||
|
|
||||||
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
|
||||||
if not have_realesrgan:
|
|
||||||
return image
|
|
||||||
|
|
||||||
info = realesrgan_models[RealESRGAN_model_index]
|
|
||||||
|
|
||||||
model = info.model()
|
|
||||||
upsampler = RealESRGANer(
|
|
||||||
scale=info.netscale,
|
|
||||||
model_path=info.location,
|
|
||||||
model=model,
|
|
||||||
half=not cmd_opts.no_half,
|
|
||||||
tile=opts.ESRGAN_tile,
|
|
||||||
tile_pad=opts.ESRGAN_tile_overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
|
||||||
|
|
||||||
image = Image.fromarray(upsampled)
|
|
||||||
return image
|
|
||||||
|
@ -8,7 +8,14 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared, modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
model_dir = "Stable-diffusion"
|
||||||
|
model_path = os.path.join(models_path, model_dir)
|
||||||
|
model_name = "sd-v1-4.ckpt"
|
||||||
|
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
|
||||||
|
user_dir = None
|
||||||
|
|
||||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
|
||||||
checkpoints_list = {}
|
checkpoints_list = {}
|
||||||
@ -23,14 +30,47 @@ except Exception:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def modeltitle(path, h):
|
||||||
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
if abspath.startswith(model_dir):
|
||||||
|
name = abspath.replace(model_dir, '')
|
||||||
|
else:
|
||||||
|
name = os.path.basename(path)
|
||||||
|
|
||||||
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
|
name = name[1:]
|
||||||
|
|
||||||
|
return f'{name} [{h}]'
|
||||||
|
|
||||||
|
|
||||||
|
def setup_model(dirname):
|
||||||
|
global model_path
|
||||||
|
global model_name
|
||||||
|
global model_url
|
||||||
|
global user_dir
|
||||||
|
global model_list
|
||||||
|
user_dir = dirname
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
os.makedirs(model_path)
|
||||||
|
checkpoints_list.clear()
|
||||||
|
list_models()
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_tiles():
|
def checkpoint_tiles():
|
||||||
return sorted([x.title for x in checkpoints_list.values()])
|
return sorted([x.title for x in checkpoints_list.values()])
|
||||||
|
|
||||||
|
|
||||||
def list_models():
|
def list_models():
|
||||||
|
global model_path
|
||||||
|
global model_url
|
||||||
|
global model_name
|
||||||
|
global user_dir
|
||||||
checkpoints_list.clear()
|
checkpoints_list.clear()
|
||||||
|
model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir,
|
||||||
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
|
ext_filter=[".ckpt"], download_name=model_name)
|
||||||
|
print(f"Model list: {model_list}")
|
||||||
|
model_dir = os.path.abspath(model_path)
|
||||||
|
|
||||||
def modeltitle(path, h):
|
def modeltitle(path, h):
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
@ -53,13 +93,11 @@ def list_models():
|
|||||||
title, model_name = modeltitle(cmd_ckpt, h)
|
title, model_name = modeltitle(cmd_ckpt, h)
|
||||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
|
||||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
for filename in model_list:
|
||||||
if os.path.exists(model_dir):
|
h = model_hash(filename)
|
||||||
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
title = modeltitle(filename, h)
|
||||||
h = model_hash(filename)
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
||||||
title, model_name = modeltitle(filename, h)
|
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closet_checkpoint_match(searchString):
|
||||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||||
@ -69,6 +107,7 @@ def get_closet_checkpoint_match(searchString):
|
|||||||
|
|
||||||
def model_hash(filename):
|
def model_hash(filename):
|
||||||
try:
|
try:
|
||||||
|
print(f"Opening: {filename}")
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
import hashlib
|
import hashlib
|
||||||
m = hashlib.sha256()
|
m = hashlib.sha256()
|
||||||
@ -89,7 +128,7 @@ def select_checkpoint():
|
|||||||
if len(checkpoints_list) == 0:
|
if len(checkpoints_list) == 0:
|
||||||
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
|
||||||
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
|
||||||
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
|
print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
|
||||||
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
|
@ -1,26 +1,28 @@
|
|||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tqdm
|
import tqdm
|
||||||
import datetime
|
|
||||||
|
|
||||||
import modules.artists
|
import modules.artists
|
||||||
from modules.paths import script_path, sd_path
|
|
||||||
from modules.devices import get_optimal_device
|
|
||||||
import modules.styles
|
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
|
import modules.styles
|
||||||
|
from modules.devices import get_optimal_device
|
||||||
|
from modules.paths import script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
|
model_path = os.path.join(script_path, 'models')
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
# This should be deprecated, but we'll leave it for a few iterations
|
||||||
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", )
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
@ -34,8 +36,14 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
|
|||||||
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
|
||||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||||
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
|
||||||
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
|
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
|
||||||
|
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
||||||
|
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
||||||
|
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
||||||
|
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
|
||||||
|
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
||||||
|
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
|
||||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
|
||||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||||
@ -53,7 +61,10 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
|||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
|
|
||||||
cmd_opts = parser.parse_args()
|
cmd_opts = parser.parse_args()
|
||||||
|
if cmd_opts.ckpt_dir is not None:
|
||||||
|
print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be "
|
||||||
|
"removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.")
|
||||||
|
cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir)
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
|
||||||
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
|
||||||
@ -61,6 +72,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
|
|||||||
|
|
||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
|
||||||
|
|
||||||
class State:
|
class State:
|
||||||
interrupted = False
|
interrupted = False
|
||||||
job = ""
|
job = ""
|
||||||
@ -95,13 +107,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
|
|||||||
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
# This was moved to webui.py with the other model "setup" calls.
|
||||||
modules.sd_models.list_models()
|
# modules.sd_models.list_models()
|
||||||
|
|
||||||
|
|
||||||
def realesrgan_models_names():
|
def realesrgan_models_names():
|
||||||
import modules.realesrgan_model
|
import modules.realesrgan_model
|
||||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
|
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||||
|
|
||||||
|
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
@ -167,13 +179,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||||
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
|
||||||
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
|
||||||
|
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
@ -1,123 +0,0 @@
|
|||||||
import sys
|
|
||||||
import traceback
|
|
||||||
import cv2
|
|
||||||
import os
|
|
||||||
import contextlib
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
import modules.images
|
|
||||||
from modules.shared import cmd_opts, opts, device
|
|
||||||
from modules.swinir_arch import SwinIR as net
|
|
||||||
|
|
||||||
precision_scope = (
|
|
||||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(filename, scale=4):
|
|
||||||
model = net(
|
|
||||||
upscale=scale,
|
|
||||||
in_chans=3,
|
|
||||||
img_size=64,
|
|
||||||
window_size=8,
|
|
||||||
img_range=1.0,
|
|
||||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
|
||||||
embed_dim=240,
|
|
||||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
|
||||||
mlp_ratio=2,
|
|
||||||
upsampler="nearest+conv",
|
|
||||||
resi_connection="3conv",
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained_model = torch.load(filename)
|
|
||||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
|
||||||
if not cmd_opts.no_half:
|
|
||||||
model = model.half()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(dirname):
|
|
||||||
for file in os.listdir(dirname):
|
|
||||||
path = os.path.join(dirname, file)
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
|
|
||||||
if extension != ".pt" and extension != ".pth":
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
|
|
||||||
except Exception:
|
|
||||||
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def upscale(
|
|
||||||
img,
|
|
||||||
model,
|
|
||||||
tile=opts.SWIN_tile,
|
|
||||||
tile_overlap=opts.SWIN_tile_overlap,
|
|
||||||
window_size=8,
|
|
||||||
scale=4,
|
|
||||||
):
|
|
||||||
img = np.array(img)
|
|
||||||
img = img[:, :, ::-1]
|
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
|
||||||
img = torch.from_numpy(img).float()
|
|
||||||
img = img.unsqueeze(0).to(device)
|
|
||||||
with torch.no_grad(), precision_scope("cuda"):
|
|
||||||
_, _, h_old, w_old = img.size()
|
|
||||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
|
||||||
w_pad = (w_old // window_size + 1) * window_size - w_old
|
|
||||||
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
|
||||||
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
|
||||||
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
|
||||||
output = output[..., : h_old * scale, : w_old * scale]
|
|
||||||
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
||||||
if output.ndim == 3:
|
|
||||||
output = np.transpose(
|
|
||||||
output[[2, 1, 0], :, :], (1, 2, 0)
|
|
||||||
) # CHW-RGB to HCW-BGR
|
|
||||||
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
|
||||||
return Image.fromarray(output, "RGB")
|
|
||||||
|
|
||||||
|
|
||||||
def inference(img, model, tile, tile_overlap, window_size, scale):
|
|
||||||
# test the image tile by tile
|
|
||||||
b, c, h, w = img.size()
|
|
||||||
tile = min(tile, h, w)
|
|
||||||
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
|
||||||
sf = scale
|
|
||||||
|
|
||||||
stride = tile - tile_overlap
|
|
||||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
|
||||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
|
||||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
|
||||||
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
|
||||||
|
|
||||||
for h_idx in h_idx_list:
|
|
||||||
for w_idx in w_idx_list:
|
|
||||||
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
|
|
||||||
out_patch = model(in_patch)
|
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
|
||||||
|
|
||||||
E[
|
|
||||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
|
||||||
].add_(out_patch)
|
|
||||||
W[
|
|
||||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
|
||||||
].add_(out_patch_mask)
|
|
||||||
output = E.div_(W)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class UpscalerSwin(modules.images.Upscaler):
|
|
||||||
def __init__(self, filename, title):
|
|
||||||
self.name = title
|
|
||||||
self.model = load_model(filename)
|
|
||||||
|
|
||||||
def do_upscale(self, img):
|
|
||||||
model = self.model.to(device)
|
|
||||||
img = upscale(img, model)
|
|
||||||
return img
|
|
139
modules/swinir_model.py
Normal file
139
modules/swinir_model.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
|
from modules import modelloader
|
||||||
|
from modules.paths import models_path
|
||||||
|
from modules.shared import cmd_opts, opts, device
|
||||||
|
from modules.swinir_model_arch import SwinIR as net
|
||||||
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
|
|
||||||
|
precision_scope = (
|
||||||
|
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerSwinIR(Upscaler):
|
||||||
|
def __init__(self, dirname):
|
||||||
|
self.name = "SwinIR"
|
||||||
|
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
||||||
|
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
||||||
|
"-L_x4_GAN.pth "
|
||||||
|
self.model_name = "SwinIR 4x"
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
self.user_path = dirname
|
||||||
|
super().__init__()
|
||||||
|
scalers = []
|
||||||
|
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||||
|
for model in model_files:
|
||||||
|
if "http" in model:
|
||||||
|
name = self.model_name
|
||||||
|
else:
|
||||||
|
name = modelloader.friendly_name(model)
|
||||||
|
model_data = UpscalerData(name, model, self)
|
||||||
|
scalers.append(model_data)
|
||||||
|
self.scalers = scalers
|
||||||
|
|
||||||
|
def do_upscale(self, img, model_file):
|
||||||
|
model = self.load_model(model_file)
|
||||||
|
if model is None:
|
||||||
|
return img
|
||||||
|
model = model.to(device)
|
||||||
|
img = upscale(img, model)
|
||||||
|
try:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_model(self, path, scale=4):
|
||||||
|
if "http" in path:
|
||||||
|
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
|
||||||
|
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
||||||
|
else:
|
||||||
|
filename = path
|
||||||
|
if filename is None or not os.path.exists(filename):
|
||||||
|
return None
|
||||||
|
model = net(
|
||||||
|
upscale=scale,
|
||||||
|
in_chans=3,
|
||||||
|
img_size=64,
|
||||||
|
window_size=8,
|
||||||
|
img_range=1.0,
|
||||||
|
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||||
|
embed_dim=240,
|
||||||
|
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||||
|
mlp_ratio=2,
|
||||||
|
upsampler="nearest+conv",
|
||||||
|
resi_connection="3conv",
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained_model = torch.load(filename)
|
||||||
|
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||||
|
if not cmd_opts.no_half:
|
||||||
|
model = model.half()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def upscale(
|
||||||
|
img,
|
||||||
|
model,
|
||||||
|
tile=opts.SWIN_tile,
|
||||||
|
tile_overlap=opts.SWIN_tile_overlap,
|
||||||
|
window_size=8,
|
||||||
|
scale=4,
|
||||||
|
):
|
||||||
|
img = np.array(img)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = np.moveaxis(img, 2, 0) / 255
|
||||||
|
img = torch.from_numpy(img).float()
|
||||||
|
img = img.unsqueeze(0).to(device)
|
||||||
|
with torch.no_grad(), precision_scope("cuda"):
|
||||||
|
_, _, h_old, w_old = img.size()
|
||||||
|
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||||
|
w_pad = (w_old // window_size + 1) * window_size - w_old
|
||||||
|
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
|
||||||
|
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
|
||||||
|
output = inference(img, model, tile, tile_overlap, window_size, scale)
|
||||||
|
output = output[..., : h_old * scale, : w_old * scale]
|
||||||
|
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||||
|
if output.ndim == 3:
|
||||||
|
output = np.transpose(
|
||||||
|
output[[2, 1, 0], :, :], (1, 2, 0)
|
||||||
|
) # CHW-RGB to HCW-BGR
|
||||||
|
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||||
|
return Image.fromarray(output, "RGB")
|
||||||
|
|
||||||
|
|
||||||
|
def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||||
|
# test the image tile by tile
|
||||||
|
b, c, h, w = img.size()
|
||||||
|
tile = min(tile, h, w)
|
||||||
|
assert tile % window_size == 0, "tile size should be a multiple of window_size"
|
||||||
|
sf = scale
|
||||||
|
|
||||||
|
stride = tile - tile_overlap
|
||||||
|
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||||
|
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||||
|
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
|
||||||
|
W = torch.zeros_like(E, dtype=torch.half, device=device)
|
||||||
|
|
||||||
|
for h_idx in h_idx_list:
|
||||||
|
for w_idx in w_idx_list:
|
||||||
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
|
out_patch = model(in_patch)
|
||||||
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
||||||
|
E[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch)
|
||||||
|
W[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch_mask)
|
||||||
|
output = E.div_(W)
|
||||||
|
|
||||||
|
return output
|
121
modules/upscaler.py
Normal file
121
modules/upscaler.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import modules.shared
|
||||||
|
from modules import modelloader, shared
|
||||||
|
|
||||||
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
from modules.paths import models_path
|
||||||
|
|
||||||
|
|
||||||
|
class Upscaler:
|
||||||
|
name = None
|
||||||
|
model_path = None
|
||||||
|
model_name = None
|
||||||
|
model_url = None
|
||||||
|
enable = True
|
||||||
|
filter = None
|
||||||
|
model = None
|
||||||
|
user_path = None
|
||||||
|
scalers: []
|
||||||
|
tile = True
|
||||||
|
|
||||||
|
def __init__(self, create_dirs=False):
|
||||||
|
self.mod_pad_h = None
|
||||||
|
self.tile_size = modules.shared.opts.ESRGAN_tile
|
||||||
|
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
||||||
|
self.device = modules.shared.device
|
||||||
|
self.img = None
|
||||||
|
self.output = None
|
||||||
|
self.scale = 1
|
||||||
|
self.half = not modules.shared.cmd_opts.no_half
|
||||||
|
self.pre_pad = 0
|
||||||
|
self.mod_scale = None
|
||||||
|
if self.name is not None and create_dirs:
|
||||||
|
self.model_path = os.path.join(models_path, self.name)
|
||||||
|
if not os.path.exists(self.model_path):
|
||||||
|
os.makedirs(self.model_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
self.can_tile = True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||||
|
return img
|
||||||
|
|
||||||
|
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||||
|
self.scale = scale
|
||||||
|
dest_w = img.width * scale
|
||||||
|
dest_h = img.height * scale
|
||||||
|
for i in range(3):
|
||||||
|
if img.width >= dest_w and img.height >= dest_h:
|
||||||
|
break
|
||||||
|
img = self.do_upscale(img, selected_model)
|
||||||
|
if img.width != dest_w or img.height != dest_h:
|
||||||
|
img = img.resize(dest_w, dest_h, resample=LANCZOS)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_model(self, path: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def find_models(self, ext_filter=None) -> list:
|
||||||
|
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
||||||
|
|
||||||
|
def update_status(self, prompt):
|
||||||
|
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerData:
|
||||||
|
name = None
|
||||||
|
data_path = None
|
||||||
|
scale: int = 4
|
||||||
|
scaler: Upscaler = None
|
||||||
|
model: None
|
||||||
|
|
||||||
|
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||||
|
self.name = name
|
||||||
|
self.data_path = path
|
||||||
|
self.scaler = upscaler
|
||||||
|
self.scale = scale
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerNone(Upscaler):
|
||||||
|
name = "None"
|
||||||
|
scalers = []
|
||||||
|
|
||||||
|
def load_model(self, path):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_upscale(self, img, selected_model=None):
|
||||||
|
return img
|
||||||
|
|
||||||
|
def __init__(self, dirname=None):
|
||||||
|
super().__init__(False)
|
||||||
|
self.scalers = [UpscalerData("None", None, self)]
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerLanczos(Upscaler):
|
||||||
|
scalers = []
|
||||||
|
|
||||||
|
def do_upscale(self, img, selected_model=None):
|
||||||
|
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
||||||
|
|
||||||
|
def load_model(self, _):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, dirname=None):
|
||||||
|
super().__init__(False)
|
||||||
|
self.name = "Lanczos"
|
||||||
|
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||||
|
|
48
webui.py
48
webui.py
@ -3,36 +3,36 @@ import threading
|
|||||||
|
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
|
import threading
|
||||||
from modules.shared import opts, cmd_opts, state
|
import modules.paths
|
||||||
import modules.shared as shared
|
import modules.codeformer_model as codeformer
|
||||||
import modules.ui
|
import modules.esrgan_model as esrgan
|
||||||
|
import modules.bsrgan_model as bsrgan
|
||||||
|
import modules.extras
|
||||||
|
import modules.face_restoration
|
||||||
|
import modules.gfpgan_model as gfpgan
|
||||||
|
import modules.img2img
|
||||||
|
import modules.ldsr_model as ldsr
|
||||||
|
import modules.lowvram
|
||||||
|
import modules.realesrgan_model as realesrgan
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
import modules.codeformer_model
|
|
||||||
import modules.gfpgan_model
|
|
||||||
import modules.face_restoration
|
|
||||||
import modules.realesrgan_model as realesrgan
|
|
||||||
import modules.esrgan_model as esrgan
|
|
||||||
import modules.ldsr_model as ldsr
|
|
||||||
import modules.extras
|
|
||||||
import modules.lowvram
|
|
||||||
import modules.txt2img
|
|
||||||
import modules.img2img
|
|
||||||
import modules.swinir as swinir
|
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
|
import modules.shared as shared
|
||||||
|
import modules.swinir_model as swinir
|
||||||
|
import modules.txt2img
|
||||||
|
import modules.ui
|
||||||
|
from modules import modelloader
|
||||||
|
from modules.paths import script_path
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
|
||||||
|
modelloader.cleanup_models()
|
||||||
modules.codeformer_model.setup_codeformer()
|
modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path)
|
||||||
modules.gfpgan_model.setup_gfpgan()
|
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||||
|
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
|
modelloader.load_upscalers()
|
||||||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
|
||||||
swinir.load_models(cmd_opts.swinir_models_path)
|
|
||||||
realesrgan.setup_realesrgan()
|
|
||||||
ldsr.add_lsdr()
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user