diff --git a/README.md b/README.md
index d9df433..0ffda8a 100644
--- a/README.md
+++ b/README.md
@@ -35,15 +35,14 @@ Install with PIP
 # install torch with GPU support for example:
 pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117
 
-# install clip-interrogator and blip
-pip install clip-interrogator==0.3.3
-pip install git+https://github.com/pharmapsychotic/BLIP.git
+# install clip-interrogator
+pip install clip-interrogator==0.3.5
 ```
 
 You can then use it in your script
 ```python
 from PIL import Image
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Config, Interrogator
 image = Image.open(image_path).convert('RGB')
 ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
 print(ci.interrogate(image))
diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py
index 925f76a..7e92186 100644
--- a/clip_interrogator/__init__.py
+++ b/clip_interrogator/__init__.py
@@ -1,4 +1,4 @@
 from .clip_interrogator import Interrogator, Config
 
-__version__ = '0.3.3'
+__version__ = '0.3.5'
 __author__ = 'pharmapsychotic'
\ No newline at end of file
diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 3315885..0944dc8 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -16,6 +16,10 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from typing import List
 
+BLIP_MODELS = {
+    'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
+    'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+}
 
 @dataclass
 class Config:
@@ -27,7 +31,7 @@ class Config:
     # blip settings
     blip_image_eval_size: int = 384
     blip_max_length: int = 32
-    blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
+    blip_model_type: str = 'large' # choose between 'base' or 'large'
     blip_num_beams: int = 8
     blip_offload: bool = False
 
@@ -39,11 +43,10 @@ class Config:
     cache_path: str = 'cache'
     chunk_size: int = 2048
     data_path: str = os.path.join(os.path.dirname(__file__), 'data')
-    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device: str = ("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
     flavor_intermediate_count: int = 2048
     quiet: bool = False # when quiet progress bars are not shown
 
-
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -56,9 +59,9 @@ class Interrogator():
             configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
             med_config = os.path.join(configs_path, 'med_config.json')
             blip_model = blip_decoder(
-                pretrained=config.blip_model_url,
+                pretrained=BLIP_MODELS[config.blip_model_type],
                 image_size=config.blip_image_eval_size,
-                vit='large',
+                vit=config.blip_model_type,
                 med_config=med_config
             )
             blip_model.eval()
diff --git a/requirements.txt b/requirements.txt
index 735e90b..ced3937 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ torchvision
 Pillow
 requests
 tqdm
-open_clip_torch
\ No newline at end of file
+open_clip_torch
+blip-vit
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 8201892..efe11d7 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@ from setuptools import setup, find_packages
 setup(
     name="clip-interrogator",
-    version="0.3.3",
+    version="0.3.5",
     license='MIT',
     author='pharmapsychotic',
     author_email='me@pharmapsychotic.com',
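
For reference, a minimal usage sketch of the new `blip_model_type` setting added in this diff, assuming the 0.3.5 API shown above (`Config`, `Interrogator`, and `clip_model_name` come from the README snippet; the image path `example.jpg` is a hypothetical placeholder):

```python
from PIL import Image

from clip_interrogator import Config, Interrogator

# blip_model_type selects which BLIP captioning checkpoint is loaded:
# 'base' (smaller) or 'large' (the default), per the BLIP_MODELS mapping above.
ci = Interrogator(Config(clip_model_name="ViT-L-14/openai", blip_model_type="base"))

image = Image.open("example.jpg").convert("RGB")  # hypothetical image path
print(ci.interrogate(image))
```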