3 changed files with 45 additions and 1 deletions
@ -0,0 +1,42 @@ |
|||||||
|
from clip_interrogator import Interrogator, Config |
||||||
|
from fastapi import FastAPI |
||||||
|
from PIL import Image |
||||||
|
import requests |
||||||
|
import torch |
||||||
|
|
||||||
|
app = FastAPI() |
||||||
|
images = {} |
||||||
|
|
||||||
|
|
||||||
|
def inference(ci: Interrogator, image: Image, mode: str): |
||||||
|
image = image.convert('RGB') |
||||||
|
if mode == 'best': |
||||||
|
return ci.interrogate(image) |
||||||
|
elif mode == 'classic': |
||||||
|
return ci.interrogate_classic(image) |
||||||
|
else: |
||||||
|
return ci.interrogate_fast(image) |
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup") |
||||||
|
async def startup_event(): |
||||||
|
if not torch.cuda.is_available(): |
||||||
|
print("CUDA is not available, using CPU. Warning: this will be very slow!") |
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||||
|
|
||||||
|
config = Config(device=device, clip_model_name='ViT-L-14/openai') |
||||||
|
ci = Interrogator(config) |
||||||
|
|
||||||
|
|
||||||
|
@app.get("/image2prompt/{image_path}") |
||||||
|
async def image2prompt(image_path: str): |
||||||
|
print(image_path) |
||||||
|
if str(image_path).startswith('http://') or str(image_path).startswith('https://'): |
||||||
|
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB') |
||||||
|
else: |
||||||
|
image = Image.open(image_path).convert('RGB') |
||||||
|
if not image: |
||||||
|
print(f'Error opening image {image_path}') |
||||||
|
exit(1) |
||||||
|
print(inference(ci, image, args.mode)) |
||||||
|
|
Loading…
Reference in new issue