diff --git a/modules/face_restoration_utils.py b/modules/face_restoration_utils.py index 85cb30570..1cbac2364 100644 --- a/modules/face_restoration_utils.py +++ b/modules/face_restoration_utils.py @@ -17,6 +17,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: + """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" + assert img.shape[2] == 3, "image must be RGB" + if img.dtype == "float64": + img = img.astype("float32") + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return torch.from_numpy(img.transpose(2, 0, 1)).float() + + +def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: + """ + Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. + """ + tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) + assert tensor.dim() == 3, "tensor must be RGB" + img_np = tensor.numpy().transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image, no RGB/BGR required + return np.squeeze(img_np, axis=2) + return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) + + def create_face_helper(device) -> FaceRestoreHelper: from facexlib.detection import retinaface from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -43,7 +65,6 @@ def restore_with_face_helper( `restore_face` should take a cropped face image and return a restored face image. """ - from basicsr.utils import img2tensor, tensor2img from torchvision.transforms.functional import normalize np_image = np_image[:, :, ::-1] original_resolution = np_image.shape[0:2] @@ -56,23 +77,19 @@ def restore_with_face_helper( face_helper.align_warp_face() logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) for cropped_face in face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) try: with torch.no_grad(): - restored_face = tensor2img( - restore_face(cropped_face_t), - rgb2bgr=True, - min_max=(-1, 1), - ) + cropped_face_t = restore_face(cropped_face_t) devices.torch_gc() except Exception: errors.report('Failed face-restoration inference', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - restored_face = restored_face.astype('uint8') + restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) + restored_face = (restored_face * 255.0).astype('uint8') face_helper.add_restored_face(restored_face) logger.debug("Merging restored faces into image") diff --git a/requirements.txt b/requirements.txt index b1329c9e3..731a1be7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ GitPython Pillow accelerate -basicsr blendmodes clean-fid einops diff --git a/requirements_versions.txt b/requirements_versions.txt index edbb6db9e..1e0ccafa7 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,7 +1,6 @@ GitPython==3.1.32 Pillow==9.5.0 accelerate==0.21.0 -basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1