diff --git a/.gitignore b/.gitignore
index 6e6cd38..18cc6bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@
 .vscode/
 bench/
 cache/
+ci_env/
 clip-interrogator/
 clip_interrogator.egg-info/
 dist/
diff --git a/README.md b/README.md
index 947c475..7e029c1 100644
--- a/README.md
+++ b/README.md
@@ -26,11 +26,16 @@ The **CLIP Interrogator** is a prompt engineering tool that combines OpenAI's [C
 Create and activate a Python virtual environment
 ```bash
 python3 -m venv ci_env
-source ci_env/bin/activate
+source ci_env/bin/activate      # Linux
+.\ci_env\Scripts\activate       # Windows
 ```
 
 Install with PIP
 ```
+# install torch with GPU support, for example:
+pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
+
+# install blip and clip-interrogator
 pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
 pip install clip-interrogator
 ```
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index 5633303..6470b3b 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
 
-__version__ = '0.3.0'
+__version__ = '0.3.1'
 __author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index e7d0ca1..67fe5b8 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -81,12 +81,12 @@ class Interrogator():
             self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                 clip_model_name,
                 pretrained=clip_model_pretrained_name,
-                precision='fp16',
+                precision='fp16' if config.device == 'cuda' else 'fp32',
                 device=config.device,
                 jit=False,
                 cache_dir=config.clip_model_path
             )
-            self.clip_model.half().to(config.device).eval()
+            self.clip_model.to(config.device).eval()
         else:
             self.clip_model = config.clip_model
             self.clip_preprocess = config.clip_preprocess
@@ -256,6 +256,8 @@ class LabelTable():
                         if data.get('hash') == hash:
                             self.labels = data['labels']
                             self.embeds = data['embeds']
+                            if self.device == 'cpu':
+                                self.embeds = [e.astype(np.float32) for e in self.embeds]
                 except Exception as e:
                     print(f"Error loading cached table {desc}: {e}")
 
diff --git a/run_cli.py b/run_cli.py
index efdd18b..8d02f69 100755
--- a/run_cli.py
+++ b/run_cli.py
@@ -40,8 +40,12 @@ def main():
         print(f" available models: {models}")
         exit(1)
 
-    # generate a nice prompt
+    # select device
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    if not torch.cuda.is_available():
+        print("CUDA is not available, using CPU. Warning: this will be very slow!")
+
+    # generate a nice prompt
     config = Config(device=device, clip_model_name=args.clip)
     ci = Interrogator(config)
 
diff --git a/run_gradio.py b/run_gradio.py
index cada8ce..9fc685f 100755
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -2,12 +2,16 @@
 import argparse
 import gradio as gr
 import open_clip
+import torch
 
 from clip_interrogator import Interrogator, Config
 
 parser = argparse.ArgumentParser()
 parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
 args = parser.parse_args()
 
+if not torch.cuda.is_available():
+    print("CUDA is not available, using CPU. Warning: this will be very slow!")
+
 ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))
 def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
diff --git a/setup.py b/setup.py
index 9e82312..12f2d02 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 
 setup(
     name="clip-interrogator",
-    version="0.3.0",
+    version="0.3.1",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
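
Note: the following is a minimal usage sketch, not part of the change set, showing how the CPU fallback introduced above plays out through the public `Config`/`Interrogator` API. The bare `Config(device=device)` call and the `example.jpg` path are placeholders for illustration, not values taken from this diff.

```python
# Minimal sketch of the CPU fallback path (placeholder image path).
import torch
from PIL import Image
from clip_interrogator import Config, Interrogator

device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print("CUDA is not available, using CPU. Warning: this will be very slow!")

# With device='cpu', the CLIP model now loads with precision='fp32' instead
# of 'fp16', and cached label embeddings are cast to float32 (see the
# clip_interrogator.py hunks above).
ci = Interrogator(Config(device=device))

image = Image.open('example.jpg').convert('RGB')  # placeholder path
print(ci.interrogate(image))
```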