Bola Malek
2 years ago
5 changed files with 115 additions and 0 deletions
@ -0,0 +1,8 @@ |
|||||||
|
import os |
||||||
|
|
||||||
|
import baseten |
||||||
|
import truss |
||||||
|
|
||||||
|
model = truss.load("./truss") |
||||||
|
baseten.login(os.environ["BASETEN_API_KEY"]) |
||||||
|
baseten.deploy(model, model_name="CLIP Interrogator", publish=True) |
@ -0,0 +1,37 @@ |
|||||||
|
bundled_packages_dir: packages |
||||||
|
data_dir: data |
||||||
|
description: null |
||||||
|
environment_variables: {} |
||||||
|
examples_filename: examples.yaml |
||||||
|
external_package_dirs: |
||||||
|
- ../clip_interrogator |
||||||
|
input_type: Any |
||||||
|
live_reload: false |
||||||
|
model_class_filename: model.py |
||||||
|
model_class_name: Model |
||||||
|
model_framework: custom |
||||||
|
model_metadata: {} |
||||||
|
model_module_dir: model |
||||||
|
model_name: null |
||||||
|
model_type: custom |
||||||
|
python_version: py39 |
||||||
|
requirements: |
||||||
|
- torch |
||||||
|
- torchvision |
||||||
|
- Pillow |
||||||
|
- requests |
||||||
|
- safetensors |
||||||
|
- tqdm |
||||||
|
- open_clip_torch |
||||||
|
- accelerate |
||||||
|
- transformers>=4.27.1 |
||||||
|
resources: |
||||||
|
cpu: 7500m |
||||||
|
memory: 15Gi |
||||||
|
use_gpu: true |
||||||
|
accelerator: A10G |
||||||
|
secrets: {} |
||||||
|
spec_version: "2.0" |
||||||
|
system_packages: |
||||||
|
- libgl1-mesa-glx |
||||||
|
- libglib2.0-0 |
@ -0,0 +1,51 @@ |
|||||||
|
from typing import Dict, List |
||||||
|
|
||||||
|
import torch |
||||||
|
from b64_utils import b64_to_pil |
||||||
|
from clip_interrogator import Config, Interrogator |
||||||
|
|
||||||
|
DEFAULT_MODEL_NAME = "ViT-L-14/openai" |
||||||
|
|
||||||
|
|
||||||
|
class Model: |
||||||
|
def __init__(self, **kwargs) -> None: |
||||||
|
self._data_dir = kwargs["data_dir"] |
||||||
|
self._config = kwargs["config"] |
||||||
|
self._secrets = kwargs["secrets"] |
||||||
|
self.ci = None |
||||||
|
|
||||||
|
def load(self): |
||||||
|
self.ci = Interrogator( |
||||||
|
Config( |
||||||
|
clip_model_name=DEFAULT_MODEL_NAME, |
||||||
|
clip_model_path="cache", |
||||||
|
device="cuda" if torch.cuda.is_available() else "cpu", |
||||||
|
) |
||||||
|
) |
||||||
|
|
||||||
|
def switch_model(self, clip_model_name: str): |
||||||
|
if clip_model_name != self.ci.config.clip_model_name: |
||||||
|
self.ci.config.clip_model_name = clip_model_name |
||||||
|
self.ci.load_clip_model() |
||||||
|
|
||||||
|
def inference(self, image, mode) -> str: |
||||||
|
image = image.convert("RGB") |
||||||
|
if mode == "best": |
||||||
|
return self.ci.interrogate(image) |
||||||
|
elif mode == "classic": |
||||||
|
return self.ci.interrogate_classic(image) |
||||||
|
elif mode == "fast": |
||||||
|
return self.ci.interrogate_fast(image) |
||||||
|
elif mode == "negative": |
||||||
|
return self.ci.interrogate_negative(image) |
||||||
|
raise ValueError(f"unsupported mode: {mode}") |
||||||
|
|
||||||
|
def predict(self, request: Dict) -> Dict[str, List]: |
||||||
|
image_b64 = request.pop("image") |
||||||
|
image_fmt = request.get("format", "PNG") |
||||||
|
image = b64_to_pil(image_b64, format=image_fmt) |
||||||
|
mode = request.get("mode", "fast") |
||||||
|
clip_model_name = request.get("clip_model_name", DEFAULT_MODEL_NAME) |
||||||
|
self.switch_model(clip_model_name) |
||||||
|
|
||||||
|
return {"caption": self.inference(image, mode)} |
@ -0,0 +1,19 @@ |
|||||||
|
import base64 |
||||||
|
from io import BytesIO |
||||||
|
|
||||||
|
from PIL import Image |
||||||
|
|
||||||
|
get_preamble = lambda fmt: f"data:image/{fmt.lower()};base64," |
||||||
|
|
||||||
|
|
||||||
|
def pil_to_b64(pil_img, format="PNG"): |
||||||
|
buffered = BytesIO() |
||||||
|
pil_img.save(buffered, format=format) |
||||||
|
img_str = base64.b64encode(buffered.getvalue()) |
||||||
|
return get_preamble(format) + str(img_str)[2:-1] |
||||||
|
|
||||||
|
|
||||||
|
def b64_to_pil(b64_str, format="PNG"): |
||||||
|
return Image.open( |
||||||
|
BytesIO(base64.b64decode(b64_str.replace(get_preamble(format), ""))) |
||||||
|
) |
Loading…
Reference in new issue