Image to prompt with BLIP and CLIP
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

102 lines
2.9 KiB

from typing import List
from tungstenkit import BaseIO, Field, Image, Option, define_model
from clip_interrogator import Config, Interrogator
# Every CLIP model this service can switch to; post_build pre-loads each one.
# Per the Input description: ViT-L pairs with Stable Diffusion 1,
# ViT-H with Stable Diffusion 2 and ViT-bigG with Stable Diffusion XL.
CLIP_MODEL_NAMES = [
    "ViT-L-14/openai",  # Stable Diffusion 1
    "ViT-H-14/laion2b_s32b_b79k",  # Stable Diffusion 2
    "ViT-bigG-14/laion2b_s39b_b160k",  # Stable Diffusion XL
]
class Input(BaseIO):
    """Request schema: the image to describe, the CLIP model and the prompt mode."""

    # Image that will be turned into a text prompt.
    input_image: Image = Field(description="Input image")
    # CLIP model to interrogate with. The choices are derived from
    # CLIP_MODEL_NAMES (copied, so Option cannot alias the module list) to keep
    # them in sync with the models post_build pre-loads.
    clip_model_name: str = Option(
        default="ViT-L-14/openai",
        choices=list(CLIP_MODEL_NAMES),
        description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
    )
    # Interrogation strategy; must match one of the modes handled in predict.
    mode: str = Option(
        default="best",
        choices=["best", "classic", "fast", "negative"],
        description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
    )
class Output(BaseIO):
    """Response schema carrying the prompt generated for one input image."""

    # Prompt text returned by the selected interrogation mode.
    interrogated: str
@define_model(
    input=Input,
    output=Output,
    gpu=True,
    cuda_version="11.8",
    python_version="3.10",
    system_packages=["libgl1-mesa-glx", "libglib2.0-0"],
    python_packages=[
        "safetensors==0.3.3",
        "tqdm==4.66.1",
        "open_clip_torch==2.20.0",
        "accelerate==0.22.0",
        "transformers==4.33.1",
    ],
    batch_size=1,
)
class CLIPInterrogator:
    """Turn an input image into a text prompt with CLIP Interrogator.

    ``post_build`` pre-loads every model in ``CLIP_MODEL_NAMES`` on the CPU,
    ``setup`` starts on ViT-L-14 on ``cuda:0``, and ``predict`` switches CLIP
    models on demand per request.
    """

    # Maps each accepted ``Input.mode`` to the Interrogator method it runs.
    _MODE_METHODS = {
        "best": "interrogate",
        "classic": "interrogate_classic",
        "fast": "interrogate_fast",
        "negative": "interrogate_negative",
    }

    @staticmethod
    def post_build():
        """Download weights by loading every model in ``CLIP_MODEL_NAMES`` on the CPU.

        Models are loaded with ``clip_model_path="cache"``, the same path
        ``setup`` uses.
        """
        ci = Interrogator(
            Config(
                clip_model_name="ViT-L-14/openai",
                clip_model_path="cache",
                device="cpu",
            )
        )
        for clip_model_name in CLIP_MODEL_NAMES:
            # Like switch_model, treat the model already named in the config as
            # loaded by the Interrogator constructor and don't load it twice.
            if clip_model_name == ci.config.clip_model_name:
                continue
            ci.config.clip_model_name = clip_model_name
            ci.load_clip_model()

    def setup(self):
        """Load weights: build the Interrogator on ``cuda:0`` starting with ViT-L-14."""
        self.ci = Interrogator(
            Config(
                clip_model_name="ViT-L-14/openai",
                clip_model_path="cache",
                device="cuda:0",
            )
        )

    def predict(self, inputs: List[Input]) -> List[Output]:
        """Interrogate every input image and return one ``Output`` per input.

        Args:
            inputs: Requests to process. With ``batch_size=1`` a single item is
                expected, but each item in the list is handled in order.

        Returns:
            ``Output`` objects in the same order as ``inputs``.

        Raises:
            RuntimeError: if an input's ``mode`` is not one of
                "best", "classic", "fast" or "negative".
        """
        return [self._interrogate_one(request) for request in inputs]

    def _interrogate_one(self, request: Input) -> Output:
        """Run the requested interrogation mode with the requested CLIP model on one input."""
        method_name = self._MODE_METHODS.get(request.mode)
        if method_name is None:
            # Reject bad modes before paying for a possible model switch.
            raise RuntimeError(f"Unknown mode: {request.mode}")
        image = request.input_image.to_pil_image()
        self.switch_model(request.clip_model_name)
        prompt = getattr(self.ci, method_name)(image)
        return Output(interrogated=prompt)

    def switch_model(self, clip_model_name: str):
        """Make ``clip_model_name`` the active CLIP model, reloading only when it changes.

        Raises:
            Whatever ``load_clip_model`` raises; the recorded model name is
            cleared first so the next call retries the load.
        """
        if clip_model_name == self.ci.config.clip_model_name:
            return
        self.ci.config.clip_model_name = clip_model_name
        try:
            self.ci.load_clip_model()
        except Exception:
            # The interrogator may now hold a stale or partially loaded model;
            # forget the name so a later request never skips the reload on a
            # false match.
            self.ci.config.clip_model_name = None
            raise