1 changed files with 102 additions and 0 deletions
@ -0,0 +1,102 @@ |
|||||||
|
from typing import List |
||||||
|
|
||||||
|
from tungstenkit import BaseIO, Field, Image, Option, define_model |
||||||
|
|
||||||
|
from clip_interrogator import Config, Interrogator |
||||||
|
|
||||||
|
CLIP_MODEL_NAMES = [ |
||||||
|
"ViT-L-14/openai", |
||||||
|
"ViT-H-14/laion2b_s32b_b79k", |
||||||
|
"ViT-bigG-14/laion2b_s39b_b160k", |
||||||
|
] |
||||||
|
|
||||||
|
|
||||||
|
class Input(BaseIO): |
||||||
|
input_image: Image = Field(description="Input image") |
||||||
|
clip_model_name: str = Option( |
||||||
|
default="ViT-L-14/openai", |
||||||
|
choices=[ |
||||||
|
"ViT-L-14/openai", |
||||||
|
"ViT-H-14/laion2b_s32b_b79k", |
||||||
|
"ViT-bigG-14/laion2b_s39b_b160k", |
||||||
|
], |
||||||
|
description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.", |
||||||
|
) |
||||||
|
mode: str = Option( |
||||||
|
default="best", |
||||||
|
choices=["best", "classic", "fast", "negative"], |
||||||
|
description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).", |
||||||
|
) |
||||||
|
|
||||||
|
|
||||||
|
class Output(BaseIO): |
||||||
|
interrogated: str |
||||||
|
|
||||||
|
|
||||||
|
@define_model( |
||||||
|
input=Input, |
||||||
|
output=Output, |
||||||
|
gpu=True, |
||||||
|
cuda_version="11.8", |
||||||
|
python_version="3.10", |
||||||
|
system_packages=["libgl1-mesa-glx", "libglib2.0-0"], |
||||||
|
python_packages=[ |
||||||
|
"safetensors==0.3.3", |
||||||
|
"tqdm==4.66.1", |
||||||
|
"open_clip_torch==2.20.0", |
||||||
|
"accelerate==0.22.0", |
||||||
|
"transformers==4.33.1", |
||||||
|
], |
||||||
|
batch_size=1, |
||||||
|
) |
||||||
|
class CLIPInterrogator: |
||||||
|
@staticmethod |
||||||
|
def post_build(): |
||||||
|
"""Download weights""" |
||||||
|
ci = Interrogator( |
||||||
|
Config( |
||||||
|
clip_model_name="ViT-L-14/openai", |
||||||
|
clip_model_path="cache", |
||||||
|
device="cpu", |
||||||
|
) |
||||||
|
) |
||||||
|
for clip_model_name in CLIP_MODEL_NAMES: |
||||||
|
ci.config.clip_model_name = clip_model_name |
||||||
|
ci.load_clip_model() |
||||||
|
|
||||||
|
def setup(self): |
||||||
|
"""Load weights""" |
||||||
|
self.ci = Interrogator( |
||||||
|
Config( |
||||||
|
clip_model_name="ViT-L-14/openai", |
||||||
|
clip_model_path="cache", |
||||||
|
device="cuda:0", |
||||||
|
) |
||||||
|
) |
||||||
|
|
||||||
|
def predict(self, inputs: List[Input]) -> str: |
||||||
|
"""Run a single prediction on the model""" |
||||||
|
input = inputs[0] |
||||||
|
image = input.input_image |
||||||
|
clip_model_name = input.clip_model_name |
||||||
|
mode = input.mode |
||||||
|
|
||||||
|
image = image.to_pil_image() |
||||||
|
self.switch_model(clip_model_name) |
||||||
|
if mode == "best": |
||||||
|
ret = self.ci.interrogate(image) |
||||||
|
elif mode == "classic": |
||||||
|
ret = self.ci.interrogate_classic(image) |
||||||
|
elif mode == "fast": |
||||||
|
ret = self.ci.interrogate_fast(image) |
||||||
|
elif mode == "negative": |
||||||
|
ret = self.ci.interrogate_negative(image) |
||||||
|
else: |
||||||
|
raise RuntimeError(f"Unknown mode: {ret}") |
||||||
|
|
||||||
|
return [Output(interrogated=ret)] |
||||||
|
|
||||||
|
def switch_model(self, clip_model_name: str): |
||||||
|
if clip_model_name != self.ci.config.clip_model_name: |
||||||
|
self.ci.config.clip_model_name = clip_model_name |
||||||
|
self.ci.load_clip_model() |
Loading…
Reference in new issue