diff --git a/README.md b/README.md index 01da4ed..c44bb59 100644 --- a/README.md +++ b/README.md @@ -84,3 +84,37 @@ table = LabelTable(load_list('terms.txt'), 'terms', ci) best_match = table.rank(ci.image_to_features(image), top_count=1)[0] print(best_match) ``` + +## Deploying as Cloud Service (using Baseten) +This repo contains a [`truss`]("./truss"), which packages the model for cloud deployment using the [truss open-source library](https://github.com/basetenlabs/truss) by Baseten. Using this truss, you can easily deploy your own scalable cloud service of this model by following these steps. + +1. Clone the repo: `git clone https://github.com/pharmapsychotic/clip-interrogator.git` +2. `cd clip-interrogator` +3. Setup virtualenv with baseten and truss deps (make sure to upgrade) +``` +python3 -m venv .env +source .env/bin/activate +pip install --upgrade pip +pip install --upgrade baseten truss +``` +4. [Grab API key from your Baseten account](https://docs.baseten.co/settings/api-keys) +5. Deploy using this command +``` +BASETEN_API_KEY=API_KEY_COPIED_FROM_BASETEN python deploy_baseten.py +``` +6. You'll get an email once your model is ready and you can call it using the instructions from the UI. +Below is a sample invocation. +``` +import baseten, os +baseten.login(os.environ["BASETEN_API_KEY"]) + +img_str = 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABgAAAAYCAYAAADgdz34AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAApgAAAKYB3X3/OAAAABl0RVh0U29mdHdhcmUAd3d3Lmlua3NjYXBlLm9yZ5vuPBoAAANCSURBVEiJtZZPbBtFFMZ/M7ubXdtdb1xSFyeilBapySVU8h8OoFaooFSqiihIVIpQBKci6KEg9Q6H9kovIHoCIVQJJCKE1ENFjnAgcaSGC6rEnxBwA04Tx43t2FnvDAfjkNibxgHxnWb2e/u992bee7tCa00YFsffekFY+nUzFtjW0LrvjRXrCDIAaPLlW0nHL0SsZtVoaF98mLrx3pdhOqLtYPHChahZcYYO7KvPFxvRl5XPp1sN3adWiD1ZAqD6XYK1b/dvE5IWryTt2udLFedwc1+9kLp+vbbpoDh+6TklxBeAi9TL0taeWpdmZzQDry0AcO+jQ12RyohqqoYoo8RDwJrU+qXkjWtfi8Xxt58BdQuwQs9qC/afLwCw8tnQbqYAPsgxE1S6F3EAIXux2oQFKm0ihMsOF71dHYx+f3NND68ghCu1YIoePPQN1pGRABkJ6Bus96CutRZMydTl+TvuiRW1m3n0eDl0vRPcEysqdXn+jsQPsrHMquGeXEaY4Yk4wxWcY5V/9scqOMOVUFthatyTy8QyqwZ+kDURKoMWxNKr2EeqVKcTNOajqKoBgOE28U4tdQl5p5bwCw7BWquaZSzAPlwjlithJtp3pTImSqQRrb2Z8PHGigD4RZuNX6JYj6wj7O4TFLbCO/Mn/m8R+h6rYSUb3ekokRY6f/YukArN979jcW+V/S8g0eT/N3VN3kTqWbQ428m9/8k0P/1aIhF36PccEl6EhOcAUCrXKZXXWS3XKd2vc/TRBG9O5ELC17MmWubD2nKhUKZa26Ba2+D3P+4/MNCFwg59oWVeYhkzgN/JDR8deKBoD7Y+ljEjGZ0sosXVTvbc6RHirr2reNy1OXd6pJsQ+gqjk8VWFYmHrwBzW/n+uMPFiRwHB2I7ih8ciHFxIkd/3Omk5tCDV1t+2nNu5sxxpDFNx+huNhVT3/zMDz8usXC3ddaHBj1GHj/As08fwTS7Kt1HBTmyN29vdwAw+/wbwLVOJ3uAD1wi/dUH7Qei66PfyuRj4Ik9is+hglfbkbfR3cnZm7chlUWLdwmprtCohX4HUtlOcQjLYCu+fzGJH2QRKvP3UNz8bWk1qMxjGTOMThZ3kvgLI5AzFfo379UAAAAASUVORK5CYII=' + +model = baseten.deployed_model_id("MODEL_ID_FROM_ACCOUNT") +model.predict({ + "image": img_str, + "format": "PNG", + "mode": "fast", + "clip_model_name": "ViT-L-14/openai" +}) +``` \ No newline at end of file diff --git a/deploy_baseten.py b/deploy_baseten.py new file mode 100644 index 0000000..c2bc1e5 --- /dev/null +++ b/deploy_baseten.py @@ -0,0 +1,8 @@ +import os + +import baseten +import truss + +model = truss.load("./truss") +baseten.login(os.environ["BASETEN_API_KEY"]) +baseten.deploy(model, model_name="CLIP Interrogator", publish=True) diff --git a/truss/config.yaml b/truss/config.yaml new file mode 100644 index 0000000..082b765 --- /dev/null +++ b/truss/config.yaml @@ -0,0 +1,37 @@ +bundled_packages_dir: packages +data_dir: data +description: null +environment_variables: {} +examples_filename: examples.yaml +external_package_dirs: + - ../clip_interrogator +input_type: Any +live_reload: false +model_class_filename: model.py +model_class_name: Model +model_framework: custom +model_metadata: {} +model_module_dir: model +model_name: null +model_type: custom +python_version: py39 +requirements: + - torch + - torchvision + - Pillow + - requests + - safetensors + - tqdm + - open_clip_torch + - accelerate + - transformers>=4.27.1 +resources: + cpu: 7500m + memory: 15Gi + use_gpu: true + accelerator: A10G +secrets: {} +spec_version: "2.0" +system_packages: + - libgl1-mesa-glx + - libglib2.0-0 diff --git a/truss/model/__init__.py b/truss/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/truss/model/model.py b/truss/model/model.py new file mode 100644 index 0000000..4ac873a --- /dev/null +++ b/truss/model/model.py @@ -0,0 +1,51 @@ +from typing import Dict, List + +import torch +from b64_utils import b64_to_pil +from clip_interrogator import Config, Interrogator + +DEFAULT_MODEL_NAME = "ViT-L-14/openai" + + +class Model: + def __init__(self, **kwargs) -> None: + self._data_dir = kwargs["data_dir"] + self._config = kwargs["config"] + self._secrets = kwargs["secrets"] + self.ci = None + + def load(self): + self.ci = Interrogator( + Config( + clip_model_name=DEFAULT_MODEL_NAME, + clip_model_path="cache", + device="cuda" if torch.cuda.is_available() else "cpu", + ) + ) + + def switch_model(self, clip_model_name: str): + if clip_model_name != self.ci.config.clip_model_name: + self.ci.config.clip_model_name = clip_model_name + self.ci.load_clip_model() + + def inference(self, image, mode) -> str: + image = image.convert("RGB") + if mode == "best": + return self.ci.interrogate(image) + elif mode == "classic": + return self.ci.interrogate_classic(image) + elif mode == "fast": + return self.ci.interrogate_fast(image) + elif mode == "negative": + return self.ci.interrogate_negative(image) + raise ValueError(f"unsupported mode: {mode}") + + def predict(self, request: Dict) -> Dict[str, List]: + image_b64 = request.pop("image") + image_fmt = request.get("format", "PNG") + image = b64_to_pil(image_b64, format=image_fmt) + mode = request.get("mode", "fast") + clip_model_name = request.get("clip_model_name", DEFAULT_MODEL_NAME) + self.switch_model(clip_model_name) + + return {"caption": self.inference(image, mode)} diff --git a/truss/packages/b64_utils.py b/truss/packages/b64_utils.py new file mode 100644 index 0000000..fe94304 --- /dev/null +++ b/truss/packages/b64_utils.py @@ -0,0 +1,19 @@ +import base64 +from io import BytesIO + +from PIL import Image + +get_preamble = lambda fmt: f"data:image/{fmt.lower()};base64," + + +def pil_to_b64(pil_img, format="PNG"): + buffered = BytesIO() + pil_img.save(buffered, format=format) + img_str = base64.b64encode(buffered.getvalue()) + return get_preamble(format) + str(img_str)[2:-1] + + +def b64_to_pil(b64_str, format="PNG"): + return Image.open( + BytesIO(base64.b64decode(b64_str.replace(get_preamble(format), ""))) + )