mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-17 03:40:14 +08:00
interrogate: Fix CLIP-interrogation on CPU
Currently, trying to perform CLIP interrogation on a CPU fails, saying: ``` RuntimeError: "slow_conv2d_cpu" not implemented for 'Half' ``` This merge request fixes this issue by detecting whether the target device is CPU and, if so, force-enabling `--no-half` and passing `device="cpu"` to `clip.load()` (which then does some extra tricks to ensure it works correctly on CPU).
This commit is contained in:
parent
5f4fec307c
commit
7157e5d064
@ -28,9 +28,11 @@ class InterrogateModels:
|
||||
clip_preprocess = None
|
||||
categories = None
|
||||
dtype = None
|
||||
running_on_cpu = None
|
||||
|
||||
def __init__(self, content_dir):
|
||||
self.categories = []
|
||||
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
|
||||
|
||||
if os.path.exists(content_dir):
|
||||
for filename in os.listdir(content_dir):
|
||||
@ -53,7 +55,11 @@ class InterrogateModels:
|
||||
def load_clip_model(self):
|
||||
import clip
|
||||
|
||||
if self.running_on_cpu:
|
||||
model, preprocess = clip.load(clip_model_name, device="cpu")
|
||||
else:
|
||||
model, preprocess = clip.load(clip_model_name)
|
||||
|
||||
model.eval()
|
||||
model = model.to(devices.device_interrogate)
|
||||
|
||||
@ -62,14 +68,14 @@ class InterrogateModels:
|
||||
def load(self):
|
||||
if self.blip_model is None:
|
||||
self.blip_model = self.load_blip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.blip_model = self.blip_model.half()
|
||||
|
||||
self.blip_model = self.blip_model.to(devices.device_interrogate)
|
||||
|
||||
if self.clip_model is None:
|
||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||
if not shared.cmd_opts.no_half:
|
||||
if not shared.cmd_opts.no_half and not self.running_on_cpu:
|
||||
self.clip_model = self.clip_model.half()
|
||||
|
||||
self.clip_model = self.clip_model.to(devices.device_interrogate)
|
||||
|
Loading…
Reference in New Issue
Block a user