From ed2bee0399aaadd2a99d14c2b4af74c8dec0cf2c Mon Sep 17 00:00:00 2001 From: Kirill Korikov Date: Tue, 7 Feb 2023 09:36:40 -0800 Subject: [PATCH] add fastapi --- .gitignore | 1 + requirements.txt | 3 ++- run_fastapi.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 run_fastapi.py diff --git a/.gitignore b/.gitignore index 18cc6bb..c0d995d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ clip-interrogator/ clip_interrogator.egg-info/ dist/ venv/ +images/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ced3937..3415d9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ Pillow requests tqdm open_clip_torch -blip-vit \ No newline at end of file +blip-vit +fastapi[all] \ No newline at end of file diff --git a/run_fastapi.py b/run_fastapi.py new file mode 100644 index 0000000..30bfc39 --- /dev/null +++ b/run_fastapi.py @@ -0,0 +1,42 @@ +from clip_interrogator import Interrogator, Config +from fastapi import FastAPI +from PIL import Image +import requests +import torch + +app = FastAPI() +images = {} + + +def inference(ci: Interrogator, image: Image, mode: str): + image = image.convert('RGB') + if mode == 'best': + return ci.interrogate(image) + elif mode == 'classic': + return ci.interrogate_classic(image) + else: + return ci.interrogate_fast(image) + + +@app.on_event("startup") +async def startup_event(): + if not torch.cuda.is_available(): + print("CUDA is not available, using CPU. Warning: this will be very slow!") + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + config = Config(device=device, clip_model_name='ViT-L-14/openai') + ci = Interrogator(config) + + +@app.get("/image2prompt/{image_path}") +async def image2prompt(image_path: str): + print(image_path) + if str(image_path).startswith('http://') or str(image_path).startswith('https://'): + image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB') + else: + image = Image.open(image_path).convert('RGB') + if not image: + print(f'Error opening image {image_path}') + exit(1) + print(inference(ci, image, args.mode)) +