diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index b8b24af..31094ba 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -5,6 +5,7 @@ import numpy as np import open_clip import os import pickle +import time import torch from dataclasses import dataclass @@ -32,6 +33,7 @@ class Config: # clip settings clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k' + clip_model_path: str = None # interrogator settings cache_path: str = 'cache' @@ -65,12 +67,25 @@ class Interrogator(): else: self.blip_model = config.blip_model + self.load_clip_model() + + def load_clip_model(self): + start_time = time.time() + config = self.config + if config.clip_model is None: if not config.quiet: print("Loading CLIP model...") clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) - self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name) + self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms( + clip_model_name, + pretrained=clip_model_pretrained_name, + precision='fp16', + device=config.device, + jit=False, + cache_dir=config.clip_model_path + ) self.clip_model.half().to(config.device).eval() else: self.clip_model = config.clip_model @@ -93,6 +108,10 @@ class Interrogator(): self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config) self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config) + end_time = time.time() + if not config.quiet: + print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.") + def generate_caption(self, pil_image: Image) -> str: if self.config.blip_offload: self.blip_model = self.blip_model.to(self.device) diff --git a/cog.yaml b/cog.yaml index a9f9a8e..20c411a 100644 --- a/cog.yaml +++ b/cog.yaml @@ -1,6 +1,6 @@ build: gpu: true - cuda: "11.3" + cuda: "11.6" python_version: "3.8" system_packages: - "libgl1-mesa-glx" @@ -10,11 +10,12 @@ build: - "fairscale==0.4.12" - "transformers==4.21.2" - "ftfy==6.1.1" - - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" - - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116" + - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116" - "open_clip_torch==2.7.0" + - "timm==0.4.12" + - "pycocoevalcap==1.2" run: - - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip - - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt" + - git clone https://github.com/salesforce/BLIP /root/blip predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py index 4452b90..a8bc923 100644 --- a/predict.py +++ b/predict.py @@ -2,17 +2,38 @@ import sys from PIL import Image from cog import BasePredictor, Input, Path -sys.path.extend(["src/clip", "src/blip"]) +sys.path.append('/root/blip') from clip_interrogator import Interrogator, Config class Predictor(BasePredictor): def setup(self): - config = Config(device="cuda:0", clip_model_name='ViT-L-14/openai') - self.ci = Interrogator(config) + self.ci = Interrogator(Config( + blip_model_url='cache/model_large_caption.pth', + clip_model_name="ViT-L-14/openai", + clip_model_path='cache', + device='cuda:0', + )) - def predict(self, image: Path = Input(description="Input image")) -> str: + def predict( + self, + image: Path = Input(description="Input image"), + clip_model_name: str = Input( + default="ViT-L-14/openai", + choices=[ + "ViT-L-14/openai", + "ViT-H-14/laion2b_s32b_b79k", + ], + description="Choose a clip model.", + ), + ) -> str: """Run a single prediction on the model""" image = Image.open(str(image)).convert("RGB") + self.switch_model(clip_model_name) return self.ci.interrogate(image) + + def switch_model(self, clip_model_name: str): + if clip_model_name != self.ci.config.clip_model_name: + self.ci.config.clip_model_name = clip_model_name + self.ci.load_clip_model() diff --git a/run_gradio.py b/run_gradio.py index a93ea5e..cada8ce 100755 --- a/run_gradio.py +++ b/run_gradio.py @@ -8,12 +8,12 @@ parser = argparse.ArgumentParser() parser.add_argument('-s', '--share', action='store_true', help='Create a public link') args = parser.parse_args() -ci = Interrogator(Config()) +ci = Interrogator(Config(cache_path="cache", clip_model_path="cache")) def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams): - global ci if clip_model_name != ci.config.clip_model_name: - ci = Interrogator(Config(clip_model_name=clip_model_name)) + ci.config.clip_model_name = clip_model_name + ci.load_clip_model() ci.config.blip_max_length = int(blip_max_length) ci.config.blip_num_beams = int(blip_num_beams) @@ -30,7 +30,7 @@ models = ['/'.join(x) for x in open_clip.list_pretrained()] inputs = [ gr.inputs.Image(type='pil'), gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'), - gr.Dropdown(models, value='ViT-H-14/laion2b_s32b_b79k', label='CLIP Model'), + gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'), gr.Number(value=32, label='Caption Max Length'), gr.Number(value=64, label='Caption Num Beams'), ]