1 changed files with 102 additions and 0 deletions
@ -0,0 +1,102 @@
|
||||
from typing import List |
||||
|
||||
from tungstenkit import BaseIO, Field, Image, Option, define_model |
||||
|
||||
from clip_interrogator import Config, Interrogator |
||||
|
||||
CLIP_MODEL_NAMES = [ |
||||
"ViT-L-14/openai", |
||||
"ViT-H-14/laion2b_s32b_b79k", |
||||
"ViT-bigG-14/laion2b_s39b_b160k", |
||||
] |
||||
|
||||
|
||||
class Input(BaseIO): |
||||
input_image: Image = Field(description="Input image") |
||||
clip_model_name: str = Option( |
||||
default="ViT-L-14/openai", |
||||
choices=[ |
||||
"ViT-L-14/openai", |
||||
"ViT-H-14/laion2b_s32b_b79k", |
||||
"ViT-bigG-14/laion2b_s39b_b160k", |
||||
], |
||||
description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.", |
||||
) |
||||
mode: str = Option( |
||||
default="best", |
||||
choices=["best", "classic", "fast", "negative"], |
||||
description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).", |
||||
) |
||||
|
||||
|
||||
class Output(BaseIO): |
||||
interrogated: str |
||||
|
||||
|
||||
@define_model( |
||||
input=Input, |
||||
output=Output, |
||||
gpu=True, |
||||
cuda_version="11.8", |
||||
python_version="3.10", |
||||
system_packages=["libgl1-mesa-glx", "libglib2.0-0"], |
||||
python_packages=[ |
||||
"safetensors==0.3.3", |
||||
"tqdm==4.66.1", |
||||
"open_clip_torch==2.20.0", |
||||
"accelerate==0.22.0", |
||||
"transformers==4.33.1", |
||||
], |
||||
batch_size=1, |
||||
) |
||||
class CLIPInterrogator: |
||||
@staticmethod |
||||
def post_build(): |
||||
"""Download weights""" |
||||
ci = Interrogator( |
||||
Config( |
||||
clip_model_name="ViT-L-14/openai", |
||||
clip_model_path="cache", |
||||
device="cpu", |
||||
) |
||||
) |
||||
for clip_model_name in CLIP_MODEL_NAMES: |
||||
ci.config.clip_model_name = clip_model_name |
||||
ci.load_clip_model() |
||||
|
||||
def setup(self): |
||||
"""Load weights""" |
||||
self.ci = Interrogator( |
||||
Config( |
||||
clip_model_name="ViT-L-14/openai", |
||||
clip_model_path="cache", |
||||
device="cuda:0", |
||||
) |
||||
) |
||||
|
||||
def predict(self, inputs: List[Input]) -> str: |
||||
"""Run a single prediction on the model""" |
||||
input = inputs[0] |
||||
image = input.input_image |
||||
clip_model_name = input.clip_model_name |
||||
mode = input.mode |
||||
|
||||
image = image.to_pil_image() |
||||
self.switch_model(clip_model_name) |
||||
if mode == "best": |
||||
ret = self.ci.interrogate(image) |
||||
elif mode == "classic": |
||||
ret = self.ci.interrogate_classic(image) |
||||
elif mode == "fast": |
||||
ret = self.ci.interrogate_fast(image) |
||||
elif mode == "negative": |
||||
ret = self.ci.interrogate_negative(image) |
||||
else: |
||||
raise RuntimeError(f"Unknown mode: {ret}") |
||||
|
||||
return [Output(interrogated=ret)] |
||||
|
||||
def switch_model(self, clip_model_name: str): |
||||
if clip_model_name != self.ci.config.clip_model_name: |
||||
self.ci.config.clip_model_name = clip_model_name |
||||
self.ci.load_clip_model() |
Loading…
Reference in new issue