diff --git a/modules/interrogate.py b/modules/interrogate.py index 3a09b3666..0068b81c8 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,4 +1,3 @@ -import contextlib import os import sys import traceback @@ -11,12 +10,9 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared as shared -from modules import devices, paths, lowvram +from modules import devices, paths, lowvram, modelloader blip_image_eval_size = 384 -blip_local_dir = os.path.join('models', 'Interrogator') -blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth') -blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth' clip_model_name = 'ViT-L/14' Category = namedtuple("Category", ["name", "topn", "items"]) @@ -49,16 +45,14 @@ class InterrogateModels: def load_blip_model(self): import models.blip - if not os.path.isfile(blip_local_file): - if not os.path.isdir(blip_local_dir): - os.mkdir(blip_local_dir) + files = modelloader.load_models( + model_path=os.path.join(paths.models_path, "BLIP"), + model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth', + ext_filter=[".pth"], + download_name='model_base_caption_capfilt_large.pth', + ) - print("Downloading BLIP...") - from requests import get as reqget - open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content) - print("BLIP downloaded to", blip_local_file + '.') - - blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) + blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json")) blip_model.eval() return blip_model