@@ -9,7 +9,7 @@ import time
 import torch
 
 from dataclasses import dataclass
-from models.blip import blip_decoder
+from models.blip import blip_decoder, BLIP_Decoder
 from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
@@ -20,7 +20,7 @@ from typing import List
 @dataclass
 class Config:
     # models can optionally be passed in directly
-    blip_model = None
+    blip_model: BLIP_Decoder = None
     clip_model = None
     clip_preprocess = None
 
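These first two hunks work together: `BLIP_Decoder` is imported alongside the `blip_decoder` factory so that `Config.blip_model` can carry a concrete type annotation instead of an untyped `None` default. A minimal usage sketch of what the annotation buys a caller; the factory arguments are assumptions based on the BLIP repo's `blip_decoder(pretrained='', **kwargs)` signature and are not part of this diff:

```python
# Hypothetical usage sketch: preload a BLIP model and hand it to Config,
# which now documents and type-checks the field as BLIP_Decoder.
from models.blip import blip_decoder, BLIP_Decoder

blip_model: BLIP_Decoder = blip_decoder(
    pretrained='',   # assumption: empty string skips checkpoint loading
    image_size=384,  # assumption: matches the eval size used elsewhere
    vit='base',
)
config = Config(blip_model=blip_model)
```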
@@ -256,8 +256,6 @@ class LabelTable():
                         if data.get('hash') == hash:
                             self.labels = data['labels']
                             self.embeds = data['embeds']
-                            if self.device == 'cpu':
-                                self.embeds = [e.astype(np.float32) for e in self.embeds]
                     except Exception as e:
                         print(f"Error loading cached table {desc}: {e}")
 
@@ -282,6 +280,9 @@ class LabelTable():
                         "model": config.clip_model_name
                     }, f)
 
+        if self.device == 'cpu' or self.device == torch.device('cpu'):
+            self.embeds = [e.astype(np.float32) for e in self.embeds]
+
     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
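The last two hunks relocate the CPU float32 cast: previously it only ran when embeddings were loaded from the pickle cache, so freshly computed float16 embeddings slipped through uncast. The check now also accepts `self.device` as a `torch.device` rather than only the string `'cpu'`. The cast matters because the embeddings are stored as float16 numpy arrays, and older PyTorch builds reject half-precision matmul on CPU, which is roughly what `_rank` performs. A standalone sketch of the failure path, not project code:

```python
import numpy as np
import torch

# Embeds are kept as float16 numpy arrays, as LabelTable caches them.
embeds = [np.random.rand(512).astype(np.float16) for _ in range(4)]

device = torch.device('cpu')
# Mirrors the added check: device may be the string 'cpu' or a torch.device.
if device == 'cpu' or device == torch.device('cpu'):
    embeds = [e.astype(np.float32) for e in embeds]

# Roughly what _rank does next: stack the label embeddings and score them
# against image features; float32 matmul is always supported on CPU.
text_embeds = torch.stack([torch.from_numpy(t) for t in embeds]).to(device)
image_features = torch.rand(1, 512)
similarity = image_features @ text_embeds.T
print(similarity.shape)  # torch.Size([1, 4])
```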