load_spandrel_model: always return a model descriptor

This commit is contained in:
Aarni Koskela 2023-12-31 00:04:47 +02:00
parent 3be9074031
commit c0ca6348e8

View File

@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
import importlib
import logging import logging
import os import os
import importlib from typing import TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
import torch import torch
@ -10,6 +11,8 @@ import torch
from modules import shared from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
if TYPE_CHECKING:
import spandrel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -142,17 +145,17 @@ def load_spandrel_model(
half: bool = False, half: bool = False,
dtype: str | None = None, dtype: str | None = None,
expected_architecture: str | None = None, expected_architecture: str | None = None,
): ) -> spandrel.ModelDescriptor:
import spandrel import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path) model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture: if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning( logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
) )
if half: if half:
model = model.model.half() model_descriptor.model.half()
if dtype: if dtype:
model = model.model.to(dtype=dtype) model_descriptor.model.to(dtype=dtype)
model.eval() model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
return model return model_descriptor