You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
2.9 KiB
102 lines
2.9 KiB
from typing import List |
|
|
|
from tungstenkit import BaseIO, Field, Image, Option, define_model |
|
|
|
from clip_interrogator import Config, Interrogator |
|
|
|
# CLIP model names this service can switch between. post_build loads each one
# once (with clip_model_path="cache") so the weights are present at runtime.
CLIP_MODEL_NAMES = [
    "ViT-L-14/openai",
    "ViT-H-14/laion2b_s32b_b79k",
    "ViT-bigG-14/laion2b_s39b_b160k",
]
|
|
|
|
|
class Input(BaseIO):
    # Request schema for CLIPInterrogator.predict. Field docs are kept in `#`
    # comments (not a class docstring) so the generated API schema text is
    # unchanged.

    # Image to describe; predict converts it with `to_pil_image()`.
    input_image: Image = Field(description="Input image")

    # CLIP model to interrogate with. The allowed values come from
    # CLIP_MODEL_NAMES (copied, so the constant cannot be mutated through the
    # option) so they always match the weights post_build pre-loads.
    clip_model_name: str = Option(
        default="ViT-L-14/openai",
        choices=list(CLIP_MODEL_NAMES),
        description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
    )

    # Selects which Interrogator method predict calls.
    mode: str = Option(
        default="best",
        choices=["best", "classic", "fast", "negative"],
        description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
    )
|
|
|
|
|
class Output(BaseIO):
    # Response schema for CLIPInterrogator.predict.

    # Text returned by the Interrogator method selected by Input.mode.
    interrogated: str
|
|
|
|
|
@define_model(
    input=Input,
    output=Output,
    gpu=True,
    cuda_version="11.8",
    python_version="3.10",
    system_packages=["libgl1-mesa-glx", "libglib2.0-0"],
    python_packages=[
        "safetensors==0.3.3",
        "tqdm==4.66.1",
        "open_clip_torch==2.20.0",
        "accelerate==0.22.0",
        "transformers==4.33.1",
    ],
    batch_size=1,
)
class CLIPInterrogator:
    """Tungstenkit model that turns an image into a prompt via clip-interrogator.

    A single ``Interrogator`` is created in ``setup``; its CLIP model is
    swapped in place by ``switch_model`` when a request asks for a different
    ``clip_model_name``.
    """

    @staticmethod
    def post_build():
        """Download weights for every entry in CLIP_MODEL_NAMES into ``cache``.

        Runs on CPU. The initial ``ViT-L-14/openai`` model is assumed to be
        loaded by the ``Interrogator`` constructor (the same assumption
        ``switch_model`` makes), so it is not loaded a second time.
        """
        ci = Interrogator(
            Config(
                clip_model_name="ViT-L-14/openai",
                clip_model_path="cache",
                device="cpu",
            )
        )
        for clip_model_name in CLIP_MODEL_NAMES:
            if clip_model_name == ci.config.clip_model_name:
                # Already loaded when the Interrogator was constructed.
                continue
            ci.config.clip_model_name = clip_model_name
            ci.load_clip_model()

    def setup(self):
        """Load weights on ``cuda:0`` with ``ViT-L-14/openai`` as the initial model."""
        self.ci = Interrogator(
            Config(
                clip_model_name="ViT-L-14/openai",
                clip_model_path="cache",
                device="cuda:0",
            )
        )

    def predict(self, inputs: List[Input]) -> List[Output]:
        """Interrogate each request's image and return one ``Output`` per request.

        Args:
            inputs: Batch of requests. The decorator sets ``batch_size=1``, so
                this normally holds a single element.

        Returns:
            ``Output`` objects in the same order as ``inputs``.

        Raises:
            RuntimeError: if a request's ``mode`` is not a known prompt mode.
        """
        outputs = []
        for request in inputs:
            # Resolve the mode before switching models so an invalid request
            # fails without triggering a (slow) CLIP model load.
            interrogate = self._interrogate_fn(request.mode)
            image = request.input_image.to_pil_image()
            self.switch_model(request.clip_model_name)
            outputs.append(Output(interrogated=interrogate(image)))
        return outputs

    def _interrogate_fn(self, mode: str):
        """Return the ``Interrogator`` method that implements prompt ``mode``."""
        dispatch = {
            "best": self.ci.interrogate,
            "classic": self.ci.interrogate_classic,
            "fast": self.ci.interrogate_fast,
            "negative": self.ci.interrogate_negative,
        }
        try:
            return dispatch[mode]
        except KeyError:
            raise RuntimeError(f"Unknown mode: {mode}") from None

    def switch_model(self, clip_model_name: str):
        """Make ``clip_model_name`` the active CLIP model, reloading only on change.

        Whether a reload is needed is decided from ``self.ci.config.clip_model_name``.
        That name is rolled back if loading fails. Otherwise a later request for
        the same model would skip the reload and run against whatever model
        was left in place.
        """
        previous = self.ci.config.clip_model_name
        if clip_model_name == previous:
            return
        self.ci.config.clip_model_name = clip_model_name
        try:
            self.ci.load_clip_model()
        except Exception:
            self.ci.config.clip_model_name = previous
            raise
|
|
|