diff --git a/README.md b/README.md
index 01da4ed..00d07cc 100644
--- a/README.md
+++ b/README.md
@@ -8,9 +8,9 @@
-Run Version 2 on Colab, HuggingFace, and Replicate!
+Run Version 2 on Colab, HuggingFace, Replicate, and Tungsten!
-[Open in Colab](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [Hugging Face Spaces](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [Replicate](https://replicate.com/pharmapsychotic/clip-interrogator) [Lambda Labs](https://cloud.lambdalabs.com/demos/ml/CLIP-Interrogator)
+[Open in Colab](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb) [Hugging Face Spaces](https://huggingface.co/spaces/pharma/CLIP-Interrogator) [Replicate](https://replicate.com/pharmapsychotic/clip-interrogator) [Tungsten](https://tungsten.run/mjpyeon/clip-interrogator) [Lambda Labs](https://cloud.lambdalabs.com/demos/ml/CLIP-Interrogator)
diff --git a/tungsten_model.py b/tungsten_model.py
new file mode 100644
index 0000000..bb6e526
--- /dev/null
+++ b/tungsten_model.py
@@ -0,0 +1,104 @@
+from typing import List
+
+from tungstenkit import BaseIO, Field, Image, Option, define_model
+
+from clip_interrogator import Config, Interrogator
+
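+# CLIP backbones supported by this model; all are downloaded and cached at build time in post_build().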
+CLIP_MODEL_NAMES = [
+ "ViT-L-14/openai",
+ "ViT-H-14/laion2b_s32b_b79k",
+ "ViT-bigG-14/laion2b_s39b_b160k",
+]
+
+
+class Input(BaseIO):
+ input_image: Image = Field(description="Input image")
+ clip_model_name: str = Option(
+ default="ViT-L-14/openai",
+        choices=CLIP_MODEL_NAMES,
+ description="Choose ViT-L for Stable Diffusion 1, ViT-H for Stable Diffusion 2, or ViT-bigG for Stable Diffusion XL.",
+ )
+ mode: str = Option(
+ default="best",
+ choices=["best", "classic", "fast", "negative"],
+ description="Prompt mode (best takes 10-20 seconds, fast takes 1-2 seconds).",
+ )
+
+
+class Output(BaseIO):
+ interrogated: str
+
+
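+# Tungsten model definition: GPU container (CUDA 11.8, Python 3.10) with pinned Python dependencies.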
+@define_model(
+ input=Input,
+ output=Output,
+ gpu=True,
+ cuda_version="11.8",
+ python_version="3.10",
+ system_packages=["libgl1-mesa-glx", "libglib2.0-0"],
+ python_packages=[
+ "safetensors==0.3.3",
+ "tqdm==4.66.1",
+ "open_clip_torch==2.20.0",
+ "accelerate==0.22.0",
+ "transformers==4.33.1",
+ ],
+ batch_size=1,
+)
+class CLIPInterrogator:
+ @staticmethod
+ def post_build():
+ """Download weights"""
+ ci = Interrogator(
+ Config(
+ clip_model_name="ViT-L-14/openai",
+ clip_model_path="cache",
+ device="cpu",
+ )
+ )
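+        # Cache weights for every supported CLIP model so the built image ships with them.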
+ for clip_model_name in CLIP_MODEL_NAMES:
+ ci.config.clip_model_name = clip_model_name
+ ci.load_clip_model()
+
+ def setup(self):
+ """Load weights"""
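+        # Load the default ViT-L model on GPU; switch_model() swaps backbones on demand.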
+ self.ci = Interrogator(
+ Config(
+ clip_model_name="ViT-L-14/openai",
+ clip_model_path="cache",
+ device="cuda:0",
+ )
+ )
+
+    def predict(self, inputs: List[Input]) -> List[Output]:
+        """Run a single prediction on the model (batch size is fixed to 1)"""
+        inp = inputs[0]
+        image = inp.input_image
+        clip_model_name = inp.clip_model_name
+        mode = inp.mode
+
+ image = image.to_pil_image()
+ self.switch_model(clip_model_name)
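+        # Dispatch to the interrogation routine matching the requested mode.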
+ if mode == "best":
+ ret = self.ci.interrogate(image)
+ elif mode == "classic":
+ ret = self.ci.interrogate_classic(image)
+ elif mode == "fast":
+ ret = self.ci.interrogate_fast(image)
+ elif mode == "negative":
+ ret = self.ci.interrogate_negative(image)
+ else:
+            raise RuntimeError(f"Unknown mode: {mode}")
+
+ return [Output(interrogated=ret)]
+
+ def switch_model(self, clip_model_name: str):
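+        # Reload CLIP weights only when the requested backbone differs from the one currently loaded.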
+ if clip_model_name != self.ci.config.clip_model_name:
+ self.ci.config.clip_model_name = clip_model_name
+ self.ci.load_clip_model()