mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-01 12:25:06 +08:00
Merge pull request #14500 from akx/spandrel-prefer-half
Spandrel: "prefer half" instead of "force half"
This commit is contained in:
commit
7c3ab416ad
@ -139,23 +139,31 @@ def load_upscalers():
|
||||
|
||||
|
||||
def load_spandrel_model(
|
||||
path: str,
|
||||
path: str | os.PathLike,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
half: bool = False,
|
||||
prefer_half: bool = False,
|
||||
dtype: str | torch.dtype | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
) -> spandrel.ModelDescriptor:
|
||||
import spandrel
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
||||
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
||||
logger.warning(
|
||||
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
||||
)
|
||||
if half:
|
||||
model_descriptor.model.half()
|
||||
half = False
|
||||
if prefer_half:
|
||||
if model_descriptor.supports_half:
|
||||
model_descriptor.model.half()
|
||||
half = True
|
||||
else:
|
||||
logger.info("Model %s does not support half precision, ignoring --half", path)
|
||||
if dtype:
|
||||
model_descriptor.model.to(dtype=dtype)
|
||||
model_descriptor.model.eval()
|
||||
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
|
||||
logger.debug(
|
||||
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
||||
model_descriptor, path, device, half, dtype,
|
||||
)
|
||||
return model_descriptor
|
||||
|
@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
model_descriptor = modelloader.load_spandrel_model(
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
|
||||
)
|
||||
return upscale_with_model(
|
||||
|
Loading…
Reference in New Issue
Block a user