Added truss to deploy to baseten

Branch: pull/70/head
Author: Bola Malek, 2 years ago
Commit: db2545b737
5 changed files, 115 additions:

  1. deploy_baseten.py (+8)
  2. truss/config.yaml (+37)
  3. truss/model/__init__.py (+0)
  4. truss/model/model.py (+51)
  5. truss/packages/b64_utils.py (+19)

deploy_baseten.py (+8)

@@ -0,0 +1,8 @@
import os

import baseten
import truss

# Load the packaged model from the local truss directory
model = truss.load("./truss")

# Authenticate with Baseten and publish the model
baseten.login(os.environ["BASETEN_API_KEY"])
baseten.deploy(model, model_name="CLIP Interrogator", publish=True)
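As a sketch of how the deployed model might be invoked once Baseten assigns it an ID (the endpoint URL, the Api-Key header format, and the model ID below are assumptions, not taken from this commit; the request shape mirrors predict() in truss/model/model.py):

import base64
import os

import requests

# Hypothetical endpoint; substitute the URL Baseten reports after deployment.
MODEL_URL = "https://app.baseten.co/models/YOUR_MODEL_ID/predict"

with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    MODEL_URL,
    headers={"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"},
    json={"image": image_b64, "mode": "fast"},
)
print(resp.json())  # expected to contain a "caption" key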

truss/config.yaml (+37)

@@ -0,0 +1,37 @@
bundled_packages_dir: packages
data_dir: data
description: null
environment_variables: {}
examples_filename: examples.yaml
external_package_dirs:
- ../clip_interrogator
input_type: Any
live_reload: false
model_class_filename: model.py
model_class_name: Model
model_framework: custom
model_metadata: {}
model_module_dir: model
model_name: null
model_type: custom
python_version: py39
requirements:
- torch
- torchvision
- Pillow
- requests
- safetensors
- tqdm
- open_clip_torch
- accelerate
- transformers>=4.27.1
resources:
  cpu: 7500m
  memory: 15Gi
  use_gpu: true
  accelerator: A10G
secrets: {}
spec_version: "2.0"
system_packages:
- libgl1-mesa-glx
- libglib2.0-0

truss/model/__init__.py (+0)

truss/model/model.py (+51)

@@ -0,0 +1,51 @@
from typing import Dict

import torch
from b64_utils import b64_to_pil
from clip_interrogator import Config, Interrogator

DEFAULT_MODEL_NAME = "ViT-L-14/openai"


class Model:
    def __init__(self, **kwargs) -> None:
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._secrets = kwargs["secrets"]
        self.ci = None

    def load(self):
        # Instantiate the interrogator once at startup; fall back to CPU
        # when no GPU is available.
        self.ci = Interrogator(
            Config(
                clip_model_name=DEFAULT_MODEL_NAME,
                clip_model_path="cache",
                device="cuda" if torch.cuda.is_available() else "cpu",
            )
        )

    def switch_model(self, clip_model_name: str):
        # Only reload CLIP weights when the requested model differs from
        # the one currently loaded.
        if clip_model_name != self.ci.config.clip_model_name:
            self.ci.config.clip_model_name = clip_model_name
            self.ci.load_clip_model()

    def inference(self, image, mode) -> str:
        image = image.convert("RGB")
        if mode == "best":
            return self.ci.interrogate(image)
        elif mode == "classic":
            return self.ci.interrogate_classic(image)
        elif mode == "fast":
            return self.ci.interrogate_fast(image)
        elif mode == "negative":
            return self.ci.interrogate_negative(image)
        raise ValueError(f"unsupported mode: {mode}")

    def predict(self, request: Dict) -> Dict[str, str]:
        # Decode the base64-encoded image, then run the requested
        # interrogation mode with the requested CLIP model.
        image_b64 = request.pop("image")
        image_fmt = request.get("format", "PNG")
        image = b64_to_pil(image_b64, format=image_fmt)
        mode = request.get("mode", "fast")
        clip_model_name = request.get("clip_model_name", DEFAULT_MODEL_NAME)
        self.switch_model(clip_model_name)
        return {"caption": self.inference(image, mode)}

truss/packages/b64_utils.py (+19)

@@ -0,0 +1,19 @@
import base64
from io import BytesIO

from PIL import Image


def get_preamble(fmt):
    # Data-URL prefix, e.g. "data:image/png;base64,"
    return f"data:image/{fmt.lower()};base64,"


def pil_to_b64(pil_img, format="PNG"):
    buffered = BytesIO()
    pil_img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return get_preamble(format) + img_str


def b64_to_pil(b64_str, format="PNG"):
    # Accepts both raw base64 strings and full data URLs.
    return Image.open(
        BytesIO(base64.b64decode(b64_str.replace(get_preamble(format), "")))
    )
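
A quick round-trip check of these helpers (a sketch assuming truss/packages is on the import path; the image is generated in-process so there are no external inputs):

from PIL import Image

from b64_utils import b64_to_pil, pil_to_b64

original = Image.new("RGB", (64, 64), color="red")
encoded = pil_to_b64(original, format="PNG")   # data URL string
decoded = b64_to_pil(encoded, format="PNG")    # back to a PIL image
assert decoded.size == original.size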