
Update cog

Branch: main
pharmapsychotic · 1 year ago · commit bc07ce62c1
2 changed files:
  cog.yaml (20 changed lines)
  predict.py (19 changed lines)

cog.yaml

@@ -1,20 +1,16 @@
 build:
   gpu: true
-  cuda: "11.6"
-  python_version: "3.8"
+  cuda: "11.8"
+  python_version: "3.10"
   system_packages:
     - "libgl1-mesa-glx"
     - "libglib2.0-0"
   python_packages:
-    - "ipython==8.4.0"
-    - "fairscale==0.4.12"
-    - "transformers==4.21.2"
-    - "ftfy==6.1.1"
-    - "torch==1.13.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
-    - "torchvision==0.14.0 --extra-index-url=https://download.pytorch.org/whl/cu116"
-    - "open_clip_torch==2.7.0"
-    - "timm==0.4.12"
-    - "pycocoevalcap==1.2"
-    - "git+https://github.com/pharmapsychotic/BLIP.git"
+    - "Pillow==10.0.0"
+    - "safetensors==0.3.3"
+    - "tqdm==4.66.1"
+    - "open_clip_torch==2.20.0"
+    - "accelerate==0.22.0"
+    - "transformers==4.33.1"

 predict: "predict.py:Predictor"
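For a quick sanity check that a rebuilt image actually carries these pins, a small snippet can be run inside the container (e.g. after `cog build`). This is a sketch, not part of the commit; it only assumes the packages pinned above are importable:

import PIL, accelerate, open_clip, safetensors, tqdm, transformers

# Each version should match its pin in cog.yaml above.
print("Pillow          ", PIL.__version__)           # expect 10.0.0
print("safetensors     ", safetensors.__version__)   # expect 0.3.3
print("tqdm            ", tqdm.__version__)          # expect 4.66.1
print("open_clip_torch ", open_clip.__version__)     # expect 2.20.0
print("accelerate      ", accelerate.__version__)    # expect 0.22.0
print("transformers    ", transformers.__version__)  # expect 4.33.1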

predict.py
@@ -2,13 +2,12 @@ import sys
 from PIL import Image
 from cog import BasePredictor, Input, Path
-from clip_interrogator import Interrogator, Config
+from clip_interrogator import Config, Interrogator


 class Predictor(BasePredictor):
     def setup(self):
         self.ci = Interrogator(Config(
-            blip_model_url='cache/model_large_caption.pth',
             clip_model_name="ViT-L-14/openai",
             clip_model_path='cache',
             device='cuda:0',
@@ -19,23 +18,27 @@ class Predictor(BasePredictor):
         image: Path = Input(description="Input image"),
         clip_model_name: str = Input(
             default="ViT-L-14/openai",
-            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
-            description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
+            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b_b160k"],
+            description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
         ),
         mode: str = Input(
             default="best",
-            choices=["best", "fast"],
+            choices=["best", "classic", "fast", "negative"],
             description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
         ),
     ) -> str:
         """Run a single prediction on the model"""
         image = Image.open(str(image)).convert("RGB")
         self.switch_model(clip_model_name)
-        if mode == "best":
+        if mode == 'best':
             return self.ci.interrogate(image)
-        else:
+        elif mode == 'classic':
+            return self.ci.interrogate_classic(image)
+        elif mode == 'fast':
             return self.ci.interrogate_fast(image)
+        elif mode == 'negative':
+            return self.ci.interrogate_negative(image)

     def switch_model(self, clip_model_name: str):
         if clip_model_name != self.ci.config.clip_model_name:
             self.ci.config.clip_model_name = clip_model_name
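The expanded if/elif chain maps one-to-one onto the clip_interrogator API the predictor already uses. As a minimal sketch of exercising the same four modes outside of Cog (the filename photo.jpg and the CPU fallback are illustrative assumptions, not part of this commit):

import torch
from PIL import Image
from clip_interrogator import Config, Interrogator

# Same Config fields the predictor's setup() passes; cache paths omitted here.
ci = Interrogator(Config(
    clip_model_name="ViT-L-14/openai",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
))

image = Image.open("photo.jpg").convert("RGB")

# One call per mode, mirroring the dispatch in predict() above:
print(ci.interrogate(image))           # mode == 'best'
print(ci.interrogate_classic(image))   # mode == 'classic'
print(ci.interrogate_fast(image))      # mode == 'fast'
print(ci.interrogate_negative(image))  # mode == 'negative'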
