mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-06 15:15:05 +08:00
load_spandrel_model: always return a model descriptor
This commit is contained in:
parent
3be9074031
commit
c0ca6348e8
@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import importlib
|
from typing import TYPE_CHECKING
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -10,6 +11,8 @@ import torch
|
|||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import spandrel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -142,17 +145,17 @@ def load_spandrel_model(
|
|||||||
half: bool = False,
|
half: bool = False,
|
||||||
dtype: str | None = None,
|
dtype: str | None = None,
|
||||||
expected_architecture: str | None = None,
|
expected_architecture: str | None = None,
|
||||||
):
|
) -> spandrel.ModelDescriptor:
|
||||||
import spandrel
|
import spandrel
|
||||||
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||||
if expected_architecture and model.architecture != expected_architecture:
|
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
|
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||||
)
|
)
|
||||||
if half:
|
if half:
|
||||||
model = model.model.half()
|
model_descriptor.model.half()
|
||||||
if dtype:
|
if dtype:
|
||||||
model = model.model.to(dtype=dtype)
|
model_descriptor.model.to(dtype=dtype)
|
||||||
model.eval()
|
model_descriptor.model.eval()
|
||||||
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
|
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
|
||||||
return model
|
return model_descriptor
|
||||||
|
Loading…
Reference in New Issue
Block a user