stable-diffusion-webui/modules/codeformer_model.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

64 lines
1.8 KiB
Python
Raw Normal View History

from __future__ import annotations
import logging
2022-09-07 17:32:28 +08:00
import torch
from modules import (
devices,
errors,
face_restoration,
face_restoration_utils,
modelloader,
shared,
)
logger: logging.Logger = logging.getLogger(__name__)

# Release checkpoint fetched from the upstream CodeFormer repository, and the
# local filename handed to modelloader as its download name.
model_url: str = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_download_name: str = 'codeformer-v0.1.0.pth'

# Shared restorer instance: stays None until setup_model() succeeds.
# Other modules (e.g. postprocessing_codeformer.py) read it from here.
codeformer: face_restoration.FaceRestoration | None = None
2022-09-07 17:32:28 +08:00
2022-09-26 22:29:50 +08:00
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
    """Face restorer backed by the CodeFormer network.

    Shared plumbing (``model_path``, ``net``, ``restore_with_helper``) comes
    from ``CommonFaceRestoration``. This subclass only supplies the display
    name, how to load the network, which device to use, and the per-face
    forward call.
    """

    def name(self):
        """Return the display name of this restorer."""
        return "CodeFormer"

    def load_net(self) -> torch.nn.Module:
        """Load the first available CodeFormer checkpoint and return its module.

        ``modelloader.load_models`` searches ``self.model_path`` for ``.pth``
        files. It is given ``model_url``/``model_download_name`` so it can
        fetch the checkpoint (presumably only when none exists locally —
        TODO confirm against modelloader). The first path it yields is loaded
        through spandrel onto ``devices.device_codeformer``. The underlying
        torch module (``.model``) is returned.

        Raises:
            ValueError: if ``load_models`` yields no model path.
        """
        for model_path in modelloader.load_models(
            model_path=self.model_path,
            model_url=model_url,
            command_path=self.model_path,
            download_name=model_download_name,
            ext_filter=['.pth'],
        ):
            # Only the first candidate is used; any further matches are ignored.
            return modelloader.load_spandrel_model(
                model_path,
                device=devices.device_codeformer,
            ).model
        raise ValueError("No codeformer model found")

    def get_device(self):
        """Return the torch device CodeFormer should run on."""
        return devices.device_codeformer

    def restore(self, np_image, w: float | None = None):
        """Restore faces in ``np_image`` with CodeFormer.

        Args:
            np_image: image to process, in whatever form ``restore_with_helper``
                accepts (presumably an HxWxC numpy array — TODO confirm).
            w: fidelity weight forwarded to the network as ``w``. If None,
                ``shared.opts.code_former_weight`` is used, or 0.5 when that
                option does not exist.

        Returns:
            Whatever ``restore_with_helper`` returns for this image.
        """
        if w is None:
            w = getattr(shared.opts, "code_former_weight", 0.5)

        def restore_face(cropped_face_t):
            # restore_with_helper is expected to have loaded self.net before
            # invoking this callback; the assert documents that invariant.
            assert self.net is not None
            # The network's output is indexed; element 0 is taken as the
            # restored face (presumably the other entries are auxiliary
            # outputs — TODO confirm).
            return self.net(cropped_face_t, w=w, adain=True)[0]

        return self.restore_with_helper(np_image, restore_face)
2022-09-07 17:32:28 +08:00
def setup_model(dirname: str) -> None:
    """Create the CodeFormer restorer for ``dirname`` and register it.

    On success the new instance is published in the module-level
    ``codeformer`` and appended to ``shared.face_restorers``. Any exception
    raised while doing so is passed to ``errors.report`` instead of
    propagating to the caller.
    """
    global codeformer
    try:
        restorer = FaceRestorerCodeFormer(dirname)
        codeformer = restorer
        shared.face_restorers.append(restorer)
    except Exception:
        errors.report("Error setting up CodeFormer", exc_info=True)