Update cog

branch: main
pharmapsychotic · 1 year ago · commit bc07ce62c1

2 changed files:
  1. cog.yaml (20 lines changed)
  2. predict.py (19 lines changed)
cog.yaml
@@ -1,20 +1,16 @@
 build:
   gpu: true
-  cuda: "11.6"
-  python_version: "3.8"
+  cuda: "11.8"
+  python_version: "3.10"
   system_packages:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "ipython==8.4.0"
-    - "fairscale==0.4.12"
-    - "transformers==4.21.2"
-    - "ftfy==6.1.1"
-    - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
-    - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
-    - "open_clip_torch==2.7.0"
-    - "timm==0.4.12"
-    - "pycocoevalcap==1.2"
-    - "git+https://github.com/pharmapsychotic/BLIP.git"
+    - "Pillow==10.0.0"
+    - "safetensors==0.3.3"
+    - "tqdm==4.66.1"
+    - "open_clip_torch==2.20.0"
+    - "accelerate==0.22.0"
+    - "transformers==4.33.1"
 
 predict: "predict.py:Predictor"
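
As a minimal usage sketch (not part of this commit): once an image is built from this cog.yaml, the predictor it points at can be run locally through Cog's CLI. The file name photo.jpg is just a placeholder input:

    cog predict -i image=@photo.jpg -i clip_model_name="ViT-L-14/openai" -i mode="fast"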

predict.py
@@ -2,13 +2,12 @@ import sys
 
 from PIL import Image
 from cog import BasePredictor, Input, Path
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Config, Interrogator
 
 
 class Predictor(BasePredictor):
     def setup(self):
         self.ci = Interrogator(Config(
-            blip_model_url='cache/model_large_caption.pth',
             clip_model_name="ViT-L-14/openai",
             clip_model_path='cache',
             device='cuda:0',
@@ -19,23 +18,27 @@ class Predictor(BasePredictor):
         image: Path = Input(description="Input image"),
         clip_model_name: str = Input(
             default="ViT-L-14/openai",
-            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
-            description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
+            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b_b160k"],
+            description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
         ),
         mode: str = Input(
             default="best",
-            choices=["best", "fast"],
+            choices=["best", "classic", "fast", "negative"],
             description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
         ),
     ) -> str:
         """Run a single prediction on the model"""
         image = Image.open(str(image)).convert("RGB")
         self.switch_model(clip_model_name)
-        if mode == "best":
+        if mode == 'best':
             return self.ci.interrogate(image)
-        else:
+        elif mode == 'classic':
+            return self.ci.interrogate_classic(image)
+        elif mode == 'fast':
             return self.ci.interrogate_fast(image)
+        elif mode == 'negative':
+            return self.ci.interrogate_negative(image)
 
     def switch_model(self, clip_model_name: str):
         if clip_model_name != self.ci.config.clip_model_name:
             self.ci.config.clip_model_name = clip_model_name
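
For reference, a sketch of how the four modes map onto the clip_interrogator API when used directly, outside the Cog predictor (illustrative only; "photo.jpg" is a placeholder input):

    from PIL import Image
    from clip_interrogator import Config, Interrogator

    # Load the interrogator once; clip_model_name matches the predictor's default.
    ci = Interrogator(Config(clip_model_name="ViT-L-14/openai"))
    image = Image.open("photo.jpg").convert("RGB")

    print(ci.interrogate(image))           # "best": highest quality, slowest
    print(ci.interrogate_classic(image))   # "classic": original prompt style
    print(ci.interrogate_fast(image))      # "fast": quickest result
    print(ci.interrogate_negative(image))  # "negative": negative-prompt text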
