Drop dependency on basicsr

This commit is contained in:
Aarni Koskela 2023-12-30 17:45:26 +02:00
parent f476649c02
commit c9174253fb
3 changed files with 26 additions and 11 deletions

View File

@ -17,6 +17,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:
"""Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor."""
assert img.shape[2] == 3, "image must be RGB"
if img.dtype == "float64":
img = img.astype("float32")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return torch.from_numpy(img.transpose(2, 0, 1)).float()
def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:
"""
Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.
"""
tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
assert tensor.dim() == 3, "tensor must be RGB"
img_np = tensor.numpy().transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image, no RGB/BGR required
return np.squeeze(img_np, axis=2)
return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
def create_face_helper(device) -> FaceRestoreHelper: def create_face_helper(device) -> FaceRestoreHelper:
from facexlib.detection import retinaface from facexlib.detection import retinaface
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -43,7 +65,6 @@ def restore_with_face_helper(
`restore_face` should take a cropped face image and return a restored face image. `restore_face` should take a cropped face image and return a restored face image.
""" """
from basicsr.utils import img2tensor, tensor2img
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
np_image = np_image[:, :, ::-1] np_image = np_image[:, :, ::-1]
original_resolution = np_image.shape[0:2] original_resolution = np_image.shape[0:2]
@ -56,23 +77,19 @@ def restore_with_face_helper(
face_helper.align_warp_face() face_helper.align_warp_face()
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
for cropped_face in face_helper.cropped_faces: for cropped_face in face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
try: try:
with torch.no_grad(): with torch.no_grad():
restored_face = tensor2img( cropped_face_t = restore_face(cropped_face_t)
restore_face(cropped_face_t),
rgb2bgr=True,
min_max=(-1, 1),
)
devices.torch_gc() devices.torch_gc()
except Exception: except Exception:
errors.report('Failed face-restoration inference', exc_info=True) errors.report('Failed face-restoration inference', exc_info=True)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8') restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))
restored_face = (restored_face * 255.0).astype('uint8')
face_helper.add_restored_face(restored_face) face_helper.add_restored_face(restored_face)
logger.debug("Merging restored faces into image") logger.debug("Merging restored faces into image")

View File

@ -2,7 +2,6 @@ GitPython
Pillow Pillow
accelerate accelerate
basicsr
blendmodes blendmodes
clean-fid clean-fid
einops einops

View File

@ -1,7 +1,6 @@
GitPython==3.1.32 GitPython==3.1.32
Pillow==9.5.0 Pillow==9.5.0
accelerate==0.21.0 accelerate==0.21.0
basicsr==1.4.2
blendmodes==2022 blendmodes==2022
clean-fid==0.1.35 clean-fid==0.1.35
einops==0.4.1 einops==0.4.1