
Fix for running on CPU

pull/34/head
pharmapsychotic, 2 years ago
parent commit 6f17fb09af
5 changed files:

1. README.md (2 changes)
2. clip_interrogator/__init__.py (2 changes)
3. clip_interrogator/clip_interrogator.py (9 changes)
4. run_cli.py (10 changes)
5. setup.py (2 changes)

README.md (2 changes)

@@ -37,7 +37,7 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt
 # install blip and clip-interrogator
 pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
-pip install clip-interrogator
+pip install clip-interrogator==0.3.2
 ```
 You can then use it in your script
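The pinned install above pairs with the README's "use it in your script" line. As a minimal sketch of that usage (the image filename is a hypothetical placeholder; `Interrogator` and `Config` are the names the package exports, as seen in `clip_interrogator/__init__.py` below):

```python
# Minimal usage sketch for clip-interrogator (assumed typical API usage;
# "example.jpg" is a placeholder, not a file from the repo).
from PIL import Image
from clip_interrogator import Config, Interrogator

image = Image.open("example.jpg").convert("RGB")
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
print(ci.interrogate(image))  # prints the generated prompt text
```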

clip_interrogator/__init__.py (2 changes)

@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
-__version__ = '0.3.1'
+__version__ = '0.3.2'
 __author__ = 'pharmapsychotic'

clip_interrogator/clip_interrogator.py (9 changes)

@@ -9,7 +9,7 @@ import time
 import torch
 from dataclasses import dataclass
-from models.blip import blip_decoder
+from models.blip import blip_decoder, BLIP_Decoder
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

@@ -20,7 +20,7 @@ from typing import List
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model = None
+    blip_model: BLIP_Decoder = None
     clip_model = None
     clip_preprocess = None

@@ -256,8 +256,6 @@ class LabelTable():
                 if data.get('hash') == hash:
                     self.labels = data['labels']
                     self.embeds = data['embeds']
-                    if self.device == 'cpu':
-                        self.embeds = [e.astype(np.float32) for e in self.embeds]
         except Exception as e:
             print(f"Error loading cached table {desc}: {e}")

@@ -281,6 +279,9 @@ class LabelTable():
                 "hash": hash,
                 "model": config.clip_model_name
             }, f)
+
+        if self.device == 'cpu' or self.device == torch.device('cpu'):
+            self.embeds = [e.astype(np.float32) for e in self.embeds]

     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
         top_count = min(top_count, len(text_embeds))
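These last two hunks carry the actual CPU fix: previously the float16-to-float32 cast ran only when a cached table was loaded from disk, and only when `self.device` was the literal string `'cpu'`. Tables built fresh, or callers passing a `torch.device`, kept float16 embeddings that CPU inference cannot handle well. Moving the cast to after the build/load block, and also comparing against `torch.device('cpu')`, covers both paths. A standalone sketch of the behavior (array shapes and counts are made up for illustration):

```python
# Standalone illustration of the fix, not code from the repo: cached CLIP
# text embeddings are float16 numpy arrays, which must be widened to
# float32 before they are used for similarity math on the CPU.
import numpy as np
import torch

embeds = [np.random.rand(768).astype(np.float16) for _ in range(3)]  # as cached

device = torch.device('cpu')
# device may be a plain string or a torch.device, so both are checked
if device == 'cpu' or device == torch.device('cpu'):
    embeds = [e.astype(np.float32) for e in embeds]

text_embeds = torch.stack([torch.from_numpy(e) for e in embeds])
print(text_embeds.dtype)  # torch.float32, safe for CPU ranking
```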

run_cli.py (10 changes)

@@ -20,6 +20,7 @@ def inference(ci, image, mode):
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-c', '--clip', default='ViT-L-14/openai', help='name of CLIP model to use')
+    parser.add_argument('-d', '--device', default='auto', help='device to use (auto, cuda or cpu)')
     parser.add_argument('-f', '--folder', help='path to folder of images')
     parser.add_argument('-i', '--image', help='image file or url')
     parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')

@@ -41,9 +42,12 @@ def main():
         exit(1)

     # select device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    if not torch.cuda.is_available():
-        print("CUDA is not available, using CPU. Warning: this will be very slow!")
+    if args.device == 'auto':
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if not torch.cuda.is_available():
+            print("CUDA is not available, using CPU. Warning: this will be very slow!")
+    else:
+        device = torch.device(args.device)

     # generate a nice prompt
     config = Config(device=device, clip_model_name=args.clip)
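With the new flag, an invocation like `python run_cli.py -i photo.jpg -d cpu` (filename hypothetical) forces CPU even when CUDA is available, while the default `auto` keeps the old behavior. Library users can presumably get the same effect by building the `Config` themselves, the way the last context line does; a minimal sketch:

```python
# Sketch: forcing CPU from a script instead of the CLI (mirrors what
# run_cli.py does with args.device; model name matches the CLI default).
import torch
from clip_interrogator import Config, Interrogator

device = torch.device('cpu')  # what run_cli.py builds from --device cpu
config = Config(device=device, clip_model_name='ViT-L-14/openai')
ci = Interrogator(config)
```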

setup.py (2 changes)

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 setup(
     name="clip-interrogator",
-    version="0.3.1",
+    version="0.3.2",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
