Browse Source

0.3.1 fix for running on cpu, update readme usage instructions

pull/34/head
pharmapsychotic 2 years ago
parent
commit
152d5f551f
  1. 1
      .gitignore
  2. 7
      README.md
  3. 2
      clip_interrogator/__init__.py
  4. 6
      clip_interrogator/clip_interrogator.py
  5. 6
      run_cli.py
  6. 4
      run_gradio.py
  7. 2
      setup.py

1
.gitignore vendored

@@ -3,6 +3,7 @@
.vscode/ .vscode/
bench/ bench/
cache/ cache/
ci_env/
clip-interrogator/ clip-interrogator/
clip_interrogator.egg-info/ clip_interrogator.egg-info/
dist/ dist/

7
README.md

@@ -26,11 +26,16 @@ The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [C
Create and activate a Python virtual environment Create and activate a Python virtual environment
```bash ```bash
python3 -m venv ci_env python3 -m venv ci_env
source ci_env/bin/activate (for linux ) source ci_env/bin/activate
(for windows) .\ci_env\Scripts\activate
``` ```
Install with PIP Install with PIP
``` ```
# install torch with GPU support for example:
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
# install blip and clip-interrogator
pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
pip install clip-interrogator pip install clip-interrogator
``` ```

2
clip_interrogator/__init__.py

@@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config from .clip_interrogator import Interrogator, Config
__version__ = '0.2.0' __version__ = '0.3.1'
__author__ = 'pharmapsychotic' __author__ = 'pharmapsychotic'

6
clip_interrogator/clip_interrogator.py

@@ -81,12 +81,12 @@ class Interrogator():
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
clip_model_name, clip_model_name,
pretrained=clip_model_pretrained_name, pretrained=clip_model_pretrained_name,
precision='fp16', precision='fp16' if config.device == 'cuda' else 'fp32',
device=config.device, device=config.device,
jit=False, jit=False,
cache_dir=config.clip_model_path cache_dir=config.clip_model_path
) )
self.clip_model.half().to(config.device).eval() self.clip_model.to(config.device).eval()
else: else:
self.clip_model = config.clip_model self.clip_model = config.clip_model
self.clip_preprocess = config.clip_preprocess self.clip_preprocess = config.clip_preprocess
@@ -256,6 +256,8 @@ class LabelTable():
if data.get('hash') == hash: if data.get('hash') == hash:
self.labels = data['labels'] self.labels = data['labels']
self.embeds = data['embeds'] self.embeds = data['embeds']
if self.device == 'cpu':
self.embeds = [e.astype(np.float32) for e in self.embeds]
except Exception as e: except Exception as e:
print(f"Error loading cached table {desc}: {e}") print(f"Error loading cached table {desc}: {e}")

6
run_cli.py

@@ -40,8 +40,12 @@ def main():
print(f" available models: {models}") print(f" available models: {models}")
exit(1) exit(1)
# generate a nice prompt # select device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if not torch.cuda.is_available():
print("CUDA is not available, using CPU. Warning: this will be very slow!")
# generate a nice prompt
config = Config(device=device, clip_model_name=args.clip) config = Config(device=device, clip_model_name=args.clip)
ci = Interrogator(config) ci = Interrogator(config)

4
run_gradio.py

@@ -2,12 +2,16 @@
import argparse import argparse
import gradio as gr import gradio as gr
import open_clip import open_clip
import torch
from clip_interrogator import Interrogator, Config from clip_interrogator import Interrogator, Config
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link') parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args() args = parser.parse_args()
if not torch.cuda.is_available():
print("CUDA is not available, using CPU. Warning: this will be very slow!")
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):

2
setup.py

@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup( setup(
name="clip-interrogator", name="clip-interrogator",
version="0.2.0", version="0.3.1",
license='MIT', license='MIT',
author='pharmapsychotic', author='pharmapsychotic',
author_email='me@pharmapsychotic.com', author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save