You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
1.7 KiB
52 lines
1.7 KiB
2 years ago
|
from typing import Dict, List
|
||
|
|
||
|
import torch
|
||
|
from b64_utils import b64_to_pil
|
||
|
from clip_interrogator import Config, Interrogator
|
||
|
|
||
|
DEFAULT_MODEL_NAME = "ViT-L-14/openai"
|
||
|
|
||
|
|
||
|
class Model:
|
||
|
def __init__(self, **kwargs) -> None:
|
||
|
self._data_dir = kwargs["data_dir"]
|
||
|
self._config = kwargs["config"]
|
||
|
self._secrets = kwargs["secrets"]
|
||
|
self.ci = None
|
||
|
|
||
|
def load(self):
|
||
|
self.ci = Interrogator(
|
||
|
Config(
|
||
|
clip_model_name=DEFAULT_MODEL_NAME,
|
||
|
clip_model_path="cache",
|
||
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
||
|
)
|
||
|
)
|
||
|
|
||
|
def switch_model(self, clip_model_name: str):
|
||
|
if clip_model_name != self.ci.config.clip_model_name:
|
||
|
self.ci.config.clip_model_name = clip_model_name
|
||
|
self.ci.load_clip_model()
|
||
|
|
||
|
def inference(self, image, mode) -> str:
|
||
|
image = image.convert("RGB")
|
||
|
if mode == "best":
|
||
|
return self.ci.interrogate(image)
|
||
|
elif mode == "classic":
|
||
|
return self.ci.interrogate_classic(image)
|
||
|
elif mode == "fast":
|
||
|
return self.ci.interrogate_fast(image)
|
||
|
elif mode == "negative":
|
||
|
return self.ci.interrogate_negative(image)
|
||
|
raise ValueError(f"unsupported mode: {mode}")
|
||
|
|
||
|
def predict(self, request: Dict) -> Dict[str, List]:
|
||
|
image_b64 = request.pop("image")
|
||
|
image_fmt = request.get("format", "PNG")
|
||
|
image = b64_to_pil(image_b64, format=image_fmt)
|
||
|
mode = request.get("mode", "fast")
|
||
|
clip_model_name = request.get("clip_model_name", DEFAULT_MODEL_NAME)
|
||
|
self.switch_model(clip_model_name)
|
||
|
|
||
|
return {"caption": self.inference(image, mode)}
|