import sys

from PIL import Image
from cog import BasePredictor, Input, Path
from clip_interrogator import Interrogator, Config


class Predictor(BasePredictor):
    """Cog predictor that produces a text prompt for an input image.

    A single ``Interrogator`` is created in ``setup`` and reused by every
    call to ``predict``; the CLIP model it holds is swapped on demand.
    """

    def setup(self) -> None:
        """Create the interrogator that all predictions share.

        The configuration points at files under the local ``cache``
        directory and places the model on ``cuda:0``.
        """
        self.ci = Interrogator(Config(
            blip_model_url='cache/model_large_caption.pth',
            clip_model_name="ViT-L-14/openai",
            clip_model_path='cache',
            device='cuda:0',
        ))

    def predict(
        self,
        image: Path = Input(description="Input image"),
        clip_model_name: str = Input(
            default="ViT-L-14/openai",
            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k"],
            description="Choose ViT-L for Stable Diffusion 1, and ViT-H for Stable Diffusion 2",
        ),
        mode: str = Input(
            default="best",
            choices=["best", "fast"],
            description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
        ),
    ) -> str:
        """Run a single prediction on the model.

        Args:
            image: Path of the image to describe.
            clip_model_name: CLIP model to interrogate with; the loaded
                model is switched first when it differs from the current one.
            mode: ``"best"`` calls ``interrogate``; any other value calls
                ``interrogate_fast``.

        Returns:
            The prompt string produced by the interrogator.
        """
        # Open inside a context manager so the file handle is released
        # deterministically; Pillow only auto-closes single-frame files on
        # load, so multi-frame inputs would otherwise stay open until GC.
        # ``convert`` returns a separate image that outlives the source.
        with Image.open(str(image)) as source_image:
            rgb_image = source_image.convert("RGB")

        self.switch_model(clip_model_name)

        if mode == "best":
            return self.ci.interrogate(rgb_image)
        return self.ci.interrogate_fast(rgb_image)

    def switch_model(self, clip_model_name: str) -> None:
        """Load ``clip_model_name`` unless the config already names it.

        The config is updated before loading because ``load_clip_model()``
        is called without arguments. If loading raises, the previous name
        is restored and the exception is re-raised.

        Args:
            clip_model_name: Name of the CLIP model to make active.
        """
        config = self.ci.config
        previous_name = config.clip_model_name
        if clip_model_name == previous_name:
            return

        config.clip_model_name = clip_model_name
        try:
            self.ci.load_clip_model()
        except Exception:
            # Without the rollback, the config would claim the new model is
            # active and later requests for it would skip the load entirely.
            config.clip_model_name = previous_name
            raise