Bola Malek
2 years ago
5 changed files with 115 additions and 0 deletions
@ -0,0 +1,8 @@
|
||||
import os |
||||
|
||||
import baseten |
||||
import truss |
||||
|
||||
model = truss.load("./truss") |
||||
baseten.login(os.environ["BASETEN_API_KEY"]) |
||||
baseten.deploy(model, model_name="CLIP Interrogator", publish=True) |
@ -0,0 +1,37 @@
|
||||
bundled_packages_dir: packages |
||||
data_dir: data |
||||
description: null |
||||
environment_variables: {} |
||||
examples_filename: examples.yaml |
||||
external_package_dirs: |
||||
- ../clip_interrogator |
||||
input_type: Any |
||||
live_reload: false |
||||
model_class_filename: model.py |
||||
model_class_name: Model |
||||
model_framework: custom |
||||
model_metadata: {} |
||||
model_module_dir: model |
||||
model_name: null |
||||
model_type: custom |
||||
python_version: py39 |
||||
requirements: |
||||
- torch |
||||
- torchvision |
||||
- Pillow |
||||
- requests |
||||
- safetensors |
||||
- tqdm |
||||
- open_clip_torch |
||||
- accelerate |
||||
- transformers>=4.27.1 |
||||
resources: |
||||
cpu: 7500m |
||||
memory: 15Gi |
||||
use_gpu: true |
||||
accelerator: A10G |
||||
secrets: {} |
||||
spec_version: "2.0" |
||||
system_packages: |
||||
- libgl1-mesa-glx |
||||
- libglib2.0-0 |
@ -0,0 +1,51 @@
|
||||
from typing import Dict, List |
||||
|
||||
import torch |
||||
from b64_utils import b64_to_pil |
||||
from clip_interrogator import Config, Interrogator |
||||
|
||||
DEFAULT_MODEL_NAME = "ViT-L-14/openai" |
||||
|
||||
|
||||
class Model: |
||||
def __init__(self, **kwargs) -> None: |
||||
self._data_dir = kwargs["data_dir"] |
||||
self._config = kwargs["config"] |
||||
self._secrets = kwargs["secrets"] |
||||
self.ci = None |
||||
|
||||
def load(self): |
||||
self.ci = Interrogator( |
||||
Config( |
||||
clip_model_name=DEFAULT_MODEL_NAME, |
||||
clip_model_path="cache", |
||||
device="cuda" if torch.cuda.is_available() else "cpu", |
||||
) |
||||
) |
||||
|
||||
def switch_model(self, clip_model_name: str): |
||||
if clip_model_name != self.ci.config.clip_model_name: |
||||
self.ci.config.clip_model_name = clip_model_name |
||||
self.ci.load_clip_model() |
||||
|
||||
def inference(self, image, mode) -> str: |
||||
image = image.convert("RGB") |
||||
if mode == "best": |
||||
return self.ci.interrogate(image) |
||||
elif mode == "classic": |
||||
return self.ci.interrogate_classic(image) |
||||
elif mode == "fast": |
||||
return self.ci.interrogate_fast(image) |
||||
elif mode == "negative": |
||||
return self.ci.interrogate_negative(image) |
||||
raise ValueError(f"unsupported mode: {mode}") |
||||
|
||||
def predict(self, request: Dict) -> Dict[str, List]: |
||||
image_b64 = request.pop("image") |
||||
image_fmt = request.get("format", "PNG") |
||||
image = b64_to_pil(image_b64, format=image_fmt) |
||||
mode = request.get("mode", "fast") |
||||
clip_model_name = request.get("clip_model_name", DEFAULT_MODEL_NAME) |
||||
self.switch_model(clip_model_name) |
||||
|
||||
return {"caption": self.inference(image, mode)} |
@ -0,0 +1,19 @@
|
||||
import base64 |
||||
from io import BytesIO |
||||
|
||||
from PIL import Image |
||||
|
||||
get_preamble = lambda fmt: f"data:image/{fmt.lower()};base64," |
||||
|
||||
|
||||
def pil_to_b64(pil_img, format="PNG"): |
||||
buffered = BytesIO() |
||||
pil_img.save(buffered, format=format) |
||||
img_str = base64.b64encode(buffered.getvalue()) |
||||
return get_preamble(format) + str(img_str)[2:-1] |
||||
|
||||
|
||||
def b64_to_pil(b64_str, format="PNG"): |
||||
return Image.open( |
||||
BytesIO(base64.b64decode(b64_str.replace(get_preamble(format), ""))) |
||||
) |
Loading…
Reference in new issue