
Add ability to swap CLIP models (about 5s for ViT-L, 10s for ViT-H) and update the Replicate cog

pharmapsychotic · 2 years ago · commit 5aed16b011 · pull/18/head
Changed files:
  1. clip_interrogator/clip_interrogator.py (21 lines changed)
  2. cog.yaml (11 lines changed)
  3. predict.py (29 lines changed)
  4. run_gradio.py (8 lines changed)

clip_interrogator/clip_interrogator.py (21 lines changed)

@@ -5,6 +5,7 @@ import numpy as np
 import open_clip
 import os
 import pickle
+import time
 import torch

 from dataclasses import dataclass
@@ -32,6 +33,7 @@ class Config:
     # clip settings
     clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
+    clip_model_path: str = None

     # interrogator settings
     cache_path: str = 'cache'
@@ -65,12 +67,25 @@ class Interrogator():
         else:
             self.blip_model = config.blip_model

+        self.load_clip_model()
+
+    def load_clip_model(self):
+        start_time = time.time()
+        config = self.config
+
         if config.clip_model is None:
             if not config.quiet:
                 print("Loading CLIP model...")

             clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
-            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name)
+            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
+                clip_model_name,
+                pretrained=clip_model_pretrained_name,
+                precision='fp16',
+                device=config.device,
+                jit=False,
+                cache_dir=config.clip_model_path
+            )
             self.clip_model.half().to(config.device).eval()
         else:
             self.clip_model = config.clip_model
@@ -93,6 +108,10 @@ class Interrogator():
         self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)

+        end_time = time.time()
+        if not config.quiet:
+            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
+
     def generate_caption(self, pil_image: Image) -> str:
         if self.config.blip_offload:
             self.blip_model = self.blip_model.to(self.device)
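With the new load_clip_model(), swapping CLIP models becomes a matter of updating the config and reloading in place rather than constructing a new Interrogator. A minimal usage sketch of that pattern (the same one predict.py and run_gradio.py adopt below); the image path is hypothetical:

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

# build once with a starting CLIP model
ci = Interrogator(Config(clip_model_name='ViT-L-14/openai'))

def interrogate_with(image_path: str, clip_model_name: str) -> str:
    # swap CLIP weights in place instead of rebuilding the Interrogator;
    # load_clip_model() reloads the CLIP model and its label tables,
    # while the BLIP captioner stays in memory
    if clip_model_name != ci.config.clip_model_name:
        ci.config.clip_model_name = clip_model_name
        ci.load_clip_model()  # ~5s for ViT-L, ~10s for ViT-H per the commit message
    return ci.interrogate(Image.open(image_path).convert('RGB'))

print(interrogate_with('example.jpg', 'ViT-H-14/laion2b_s32b_b79k'))  # hypothetical image
```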

cog.yaml (11 lines changed)

@@ -1,6 +1,6 @@
 build:
   gpu: true
-  cuda: "11.3"
+  cuda: "11.6"
   python_version: "3.8"
   system_packages:
     - "libgl1-mesa-glx"
@@ -10,11 +10,12 @@ build:
     - "fairscale==0.4.12"
     - "transformers==4.21.2"
     - "ftfy==6.1.1"
-    - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
-    - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+    - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
+    - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
     - "open_clip_torch==2.7.0"
+    - "timm==0.4.12"
+    - "pycocoevalcap==1.2"
   run:
-    - pip install -e git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip
+    - git clone https://github.com/salesforce/BLIP /root/blip
+    - mkdir -p /root/.cache/clip && wget --output-document "/root/.cache/clip/ViT-L-14.pt" "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"

 predict: "predict.py:Predictor"
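Not part of the commit, but a quick way to confirm the bumped torch/CUDA pins took effect inside the built image; the expected version strings are just what cog.yaml requests:

```python
import torch

# sanity-check the environment pinned in cog.yaml (torch 1.13.0 built against CUDA 11.6)
print(torch.__version__)          # expected to start with "1.13.0"
print(torch.version.cuda)         # expected "11.6"
print(torch.cuda.is_available())  # True on the GPU build
```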

predict.py (29 lines changed)

@@ -2,17 +2,38 @@ import sys

 from PIL import Image
 from cog import BasePredictor, Input, Path

-sys.path.extend(["src/clip", "src/blip"])
+sys.path.append('/root/blip')

 from clip_interrogator import Interrogator, Config


 class Predictor(BasePredictor):
     def setup(self):
-        config = Config(device="cuda:0", clip_model_name='ViT-L-14/openai')
-        self.ci = Interrogator(config)
+        self.ci = Interrogator(Config(
+            blip_model_url='cache/model_large_caption.pth',
+            clip_model_name="ViT-L-14/openai",
+            clip_model_path='cache',
+            device='cuda:0',
+        ))

-    def predict(self, image: Path = Input(description="Input image")) -> str:
+    def predict(
+        self,
+        image: Path = Input(description="Input image"),
+        clip_model_name: str = Input(
+            default="ViT-L-14/openai",
+            choices=[
+                "ViT-L-14/openai",
+                "ViT-H-14/laion2b_s32b_b79k",
+            ],
+            description="Choose a clip model.",
+        ),
+    ) -> str:
         """Run a single prediction on the model"""
         image = Image.open(str(image)).convert("RGB")
+        self.switch_model(clip_model_name)
         return self.ci.interrogate(image)
+
+    def switch_model(self, clip_model_name: str):
+        if clip_model_name != self.ci.config.clip_model_name:
+            self.ci.config.clip_model_name = clip_model_name
+            self.ci.load_clip_model()
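A hypothetical local smoke test of the updated Predictor (not part of the commit): it calls the cog class directly in Python, assuming the environment from cog.yaml above, the cached BLIP weights at cache/model_large_caption.pth, and an example.jpg next to predict.py:

```python
from pathlib import Path
from predict import Predictor

p = Predictor()
p.setup()  # loads BLIP and the default ViT-L-14/openai CLIP model

# passing a non-default clip_model_name exercises switch_model(),
# which swaps CLIP weights in place before interrogating
print(p.predict(image=Path("example.jpg"),
                clip_model_name="ViT-H-14/laion2b_s32b_b79k"))
```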

run_gradio.py (8 lines changed)

@@ -8,12 +8,12 @@ parser = argparse.ArgumentParser()
 parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
 args = parser.parse_args()

-ci = Interrogator(Config())
+ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))

 def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
-    global ci
     if clip_model_name != ci.config.clip_model_name:
-        ci = Interrogator(Config(clip_model_name=clip_model_name))
+        ci.config.clip_model_name = clip_model_name
+        ci.load_clip_model()
     ci.config.blip_max_length = int(blip_max_length)
     ci.config.blip_num_beams = int(blip_num_beams)
@@ -30,7 +30,7 @@ models = ['/'.join(x) for x in open_clip.list_pretrained()]
 inputs = [
     gr.inputs.Image(type='pil'),
     gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
-    gr.Dropdown(models, value='ViT-H-14/laion2b_s32b_b79k', label='CLIP Model'),
+    gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
     gr.Number(value=32, label='Caption Max Length'),
     gr.Number(value=64, label='Caption Num Beams'),
 ]
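A rough timing sketch for the swap itself, using the same in-place reload the gradio app now performs; the ~5s (ViT-L) and ~10s (ViT-H) figures from the commit message will vary with hardware and whether the label-table embeddings are already cached:

```python
import time
from clip_interrogator import Config, Interrogator

ci = Interrogator(Config(clip_model_name='ViT-L-14/openai',
                         cache_path='cache', clip_model_path='cache'))

for name in ['ViT-H-14/laion2b_s32b_b79k', 'ViT-L-14/openai']:
    t0 = time.time()
    ci.config.clip_model_name = name
    ci.load_clip_model()  # also rebuilds the label tables for the new model
    print(f"{name}: swapped in {time.time() - t0:.1f}s")
```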
