Merge 21392325e0 into bc07ce62c1

pull/84/merge
Ben Cherry, 2 years ago, committed by GitHub
parent
commit 57e8f18a49
1 changed file: clip_interrogator/clip_interrogator.py (12 changed lines)

@@ -197,7 +197,7 @@ class Interrogator():
     def image_to_features(self, image: Image) -> torch.Tensor:
         self._prepare_clip()
         images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
             image_features = self.clip_model.encode_image(images)
             image_features /= image_features.norm(dim=-1, keepdim=True)
         return image_features
@@ -257,7 +257,7 @@ class Interrogator():
     def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str:
         self._prepare_clip()
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
@@ -268,7 +268,7 @@ class Interrogator():
     def similarity(self, image_features: torch.Tensor, text: str) -> float:
         self._prepare_clip()
         text_tokens = self.tokenize([text]).to(self.device)
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
@@ -277,7 +277,7 @@ class Interrogator():
     def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]:
         self._prepare_clip()
         text_tokens = self.tokenize([text for text in text_array]).to(self.device)
-        with torch.no_grad(), torch.cuda.amp.autocast():
+        with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
             text_features = self.clip_model.encode_text(text_tokens)
             text_features /= text_features.norm(dim=-1, keepdim=True)
             similarity = text_features @ image_features.T
@@ -319,7 +319,7 @@ class LabelTable():
         chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
         for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
             text_tokens = self.tokenize(chunk).to(self.device)
-            with torch.no_grad(), torch.cuda.amp.autocast():
+            with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
                 text_features = clip_model.encode_text(text_tokens)
                 text_features /= text_features.norm(dim=-1, keepdim=True)
                 text_features = text_features.half().cpu().numpy()
@@ -373,7 +373,7 @@ class LabelTable():
     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'):
             similarity = image_features @ text_embeds.T
             if reverse:
                 similarity = -similarity
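Every hunk makes the same substitution: the CUDA-only torch.cuda.amp.autocast() context manager, which newer PyTorch releases flag as deprecated, is replaced by the device-agnostic torch.amp.autocast(device_type=...), so autocast also engages when self.device is 'cpu'. A minimal standalone sketch of the pattern, using an illustrative device string and a stand-in linear layer rather than the real CLIP model:

import torch

# Illustrative setup, not from the repository: a toy "encoder" and a device string.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(512, 512).to(device)
x = torch.randn(1, 512, device=device)

# Old form (CUDA-only, deprecated in recent PyTorch releases):
#     with torch.no_grad(), torch.cuda.amp.autocast():
#         y = model(x)

# New form adopted by this commit: one context manager for both backends.
with torch.no_grad(), torch.amp.autocast(device_type=device):
    y = model(x)

print(y.dtype)  # typically float16 under CUDA autocast, bfloat16 under CPU autocast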
