|
|
|
@ -71,7 +71,7 @@ class Interrogator():
|
|
|
|
|
|
|
|
|
|
clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2) |
|
|
|
|
self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name) |
|
|
|
|
self.clip_model.to(config.device).eval() |
|
|
|
|
self.clip_model.half().to(config.device).eval() |
|
|
|
|
else: |
|
|
|
|
self.clip_model = config.clip_model |
|
|
|
|
self.clip_preprocess = config.clip_preprocess |
|
|
|
@ -117,9 +117,9 @@ class Interrogator():
|
|
|
|
|
|
|
|
|
|
def image_to_features(self, image: Image) -> torch.Tensor: |
|
|
|
|
images = self.clip_preprocess(image).unsqueeze(0).to(self.device) |
|
|
|
|
with torch.no_grad(): |
|
|
|
|
image_features = self.clip_model.encode_image(images).float() |
|
|
|
|
image_features /= image_features.norm(dim=-1, keepdim=True) |
|
|
|
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
|
|
|
image_features = self.clip_model.encode_image(images) |
|
|
|
|
image_features /= image_features.norm(dim=-1, keepdim=True) |
|
|
|
|
return image_features |
|
|
|
|
|
|
|
|
|
def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: |
|
|
|
@ -197,26 +197,21 @@ class Interrogator():
|
|
|
|
|
|
|
|
|
|
return best_prompt |
|
|
|
|
|
|
|
|
|
def rank_top(self, image_features, text_array: List[str]) -> str: |
|
|
|
|
def rank_top(self, image_features: torch.Tensor, text_array: List[str]) -> str: |
|
|
|
|
text_tokens = self.tokenize([text for text in text_array]).to(self.device) |
|
|
|
|
with torch.no_grad(): |
|
|
|
|
text_features = self.clip_model.encode_text(text_tokens).float() |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
|
similarity = torch.zeros((1, len(text_array)), device=self.device) |
|
|
|
|
for i in range(image_features.shape[0]): |
|
|
|
|
similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) |
|
|
|
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
|
|
|
text_features = self.clip_model.encode_text(text_tokens) |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
similarity = text_features @ image_features.T |
|
|
|
|
return text_array[similarity.argmax().item()] |
|
|
|
|
|
|
|
|
|
_, top_labels = similarity.cpu().topk(1, dim=-1) |
|
|
|
|
return text_array[top_labels[0][0].numpy()] |
|
|
|
|
|
|
|
|
|
def similarity(self, image_features, text) -> np.float32: |
|
|
|
|
def similarity(self, image_features: torch.Tensor, text: str) -> float: |
|
|
|
|
text_tokens = self.tokenize([text]).to(self.device) |
|
|
|
|
with torch.no_grad(): |
|
|
|
|
text_features = self.clip_model.encode_text(text_tokens).float() |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T |
|
|
|
|
return similarity[0][0] |
|
|
|
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
|
|
|
text_features = self.clip_model.encode_text(text_tokens) |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
similarity = text_features @ image_features.T |
|
|
|
|
return similarity[0][0].item() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LabelTable(): |
|
|
|
@ -247,10 +242,10 @@ class LabelTable():
|
|
|
|
|
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) |
|
|
|
|
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): |
|
|
|
|
text_tokens = self.tokenize(chunk).to(self.device) |
|
|
|
|
with torch.no_grad(): |
|
|
|
|
text_features = clip_model.encode_text(text_tokens).float() |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
text_features = text_features.half().cpu().numpy() |
|
|
|
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
|
|
|
text_features = clip_model.encode_text(text_tokens) |
|
|
|
|
text_features /= text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
text_features = text_features.half().cpu().numpy() |
|
|
|
|
for i in range(text_features.shape[0]): |
|
|
|
|
self.embeds.append(text_features[i]) |
|
|
|
|
|
|
|
|
@ -263,16 +258,15 @@ class LabelTable():
|
|
|
|
|
"model": config.clip_model_name |
|
|
|
|
}, f) |
|
|
|
|
|
|
|
|
|
def _rank(self, image_features, text_embeds, top_count=1): |
|
|
|
|
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1) -> str: |
|
|
|
|
top_count = min(top_count, len(text_embeds)) |
|
|
|
|
similarity = torch.zeros((1, len(text_embeds))).to(self.device) |
|
|
|
|
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(self.device) |
|
|
|
|
for i in range(image_features.shape[0]): |
|
|
|
|
similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1) |
|
|
|
|
_, top_labels = similarity.cpu().topk(top_count, dim=-1) |
|
|
|
|
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) |
|
|
|
|
with torch.cuda.amp.autocast(): |
|
|
|
|
similarity = image_features @ text_embeds.T |
|
|
|
|
_, top_labels = similarity.float().cpu().topk(top_count, dim=-1) |
|
|
|
|
return [top_labels[0][i].numpy() for i in range(top_count)] |
|
|
|
|
|
|
|
|
|
def rank(self, image_features, top_count=1) -> List[str]: |
|
|
|
|
def rank(self, image_features: torch.Tensor, top_count: int=1) -> List[str]: |
|
|
|
|
if len(self.labels) <= self.chunk_size: |
|
|
|
|
tops = self._rank(image_features, self.embeds, top_count=top_count) |
|
|
|
|
return [self.labels[i] for i in tops] |
|
|
|
@ -292,7 +286,7 @@ class LabelTable():
|
|
|
|
|
return [top_labels[i] for i in tops] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_list(data_path, filename) -> List[str]: |
|
|
|
|
def _load_list(data_path: str, filename: str) -> List[str]: |
|
|
|
|
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f: |
|
|
|
|
items = [line.strip() for line in f.readlines()] |
|
|
|
|
return items |
|
|
|
|