import sys

from PIL import Image
from cog import BasePredictor, Input, Path
from clip_interrogator import Config, Interrogator


class Predictor(BasePredictor):
    """Cog predictor that produces a text prompt for an image using CLIP Interrogator."""

    # Maps each supported ``mode`` value to the Interrogator method that handles it.
    # Also serves as the single source of truth for validating ``mode``.
    _MODE_METHODS = {
        "best": "interrogate",
        "classic": "interrogate_classic",
        "fast": "interrogate_fast",
        "negative": "interrogate_negative",
    }

    def setup(self):
        """Create the Interrogator once, preloaded with the default ViT-L-14 CLIP model."""
        self.ci = Interrogator(Config(
            clip_model_name="ViT-L-14/openai",
            clip_model_path='cache',
            device='cuda:0',
        ))

    def predict(
        self,
        image: Path = Input(description="Input image"),
        clip_model_name: str = Input(
            default="ViT-L-14/openai",
            choices=["ViT-L-14/openai", "ViT-H-14/laion2b_s32b_b79k", "ViT-bigG-14/laion2b_s39b_b160k"],
            description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
        ),
        mode: str = Input(
            default="best",
            choices=["best", "classic", "fast", "negative"],
            description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
        ),
    ) -> str:
        """Run a single prediction on the model.

        Args:
            image: Path to the input image; it is converted to RGB before use.
            clip_model_name: CLIP model to use; switched in lazily if it differs
                from the currently loaded one.
            mode: One of "best", "classic", "fast" or "negative"; selects which
                Interrogator method produces the prompt.

        Returns:
            The prompt string produced by the selected Interrogator method.

        Raises:
            ValueError: If ``mode`` is not a supported value. Cog enforces
                ``choices`` for HTTP requests, but direct callers do not get
                that check, and previously an unknown mode returned ``None``.
        """
        # Validate before any expensive work (image decode, model reload).
        method_name = self._MODE_METHODS.get(mode)
        if method_name is None:
            raise ValueError(
                f"Unsupported mode {mode!r}; expected one of {sorted(self._MODE_METHODS)}"
            )

        # Use a context manager so the underlying file handle is closed promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(str(image)) as source:
            rgb_image = source.convert("RGB")

        self.switch_model(clip_model_name)
        return getattr(self.ci, method_name)(rgb_image)

    def switch_model(self, clip_model_name: str):
        """Load ``clip_model_name`` into the Interrogator if it is not already the active model.

        Updates ``self.ci.config.clip_model_name`` and calls ``load_clip_model()``
        only when the requested name differs from the configured one, so repeated
        requests for the same model skip the reload.
        """
        if clip_model_name != self.ci.config.clip_model_name:
            self.ci.config.clip_model_name = clip_model_name
            self.ci.load_clip_model()