@@ -25,9 +25,9 @@ class Config:
 
     # blip settings
     blip_image_eval_size: int = 384
-    blip_max_length: int = 20
+    blip_max_length: int = 32
     blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
-    blip_num_beams: int = 3
+    blip_num_beams: int = 8
 
     # clip settings
     clip_model_name: str = 'ViT-L/14'
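Both bumped defaults trade captioning speed for quality: blip_max_length goes from 20 to 32 tokens for longer captions, and blip_num_beams from 3 to 8 for a wider BLIP beam search. A minimal sketch of overriding the new defaults, assuming Config instantiates with the fields shown in this diff:

    config = Config()
    config.blip_num_beams = 3    # hypothetical override: restore the old, faster beam width
    config.blip_max_length = 20  # shorter captions, as before this change
    interrogator = Interrogator(config)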
@@ -40,12 +40,6 @@ class Config:
     flavor_intermediate_count: int = 2048
 
-
-def _load_list(data_path, filename) -> List[str]:
-    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
-        items = [line.strip() for line in f.readlines()]
-    return items
-
 
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -110,13 +104,40 @@ class Interrogator():
         )
         return caption[0]
 
-    def interrogate(self, image: Image) -> str:
-        caption = self.generate_caption(image)
-
+    def image_to_features(self, image: Image) -> torch.Tensor:
         images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
         with torch.no_grad():
             image_features = self.clip_model.encode_image(images).float()
         image_features /= image_features.norm(dim=-1, keepdim=True)
+        return image_features
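Since the embedding is L2-normalized here, downstream ranking can score labels with a plain dot product, which equals cosine similarity on unit vectors. A minimal sketch of that scoring, assuming label embeddings normalized the same way (the tensor names are hypothetical):

    import torch

    # hypothetical stand-ins: one image embedding and 100 label embeddings, all unit-norm
    feats = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1)
    label_embeds = torch.nn.functional.normalize(torch.randn(100, 768), dim=-1)
    sims = feats @ label_embeds.T           # cosine similarities, shape (1, 100)
    top5 = sims.topk(5, dim=-1).indices     # indices of the five closest labels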
+
+    def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
+
+        medium = self.mediums.rank(image_features, 1)[0]
+        artist = self.artists.rank(image_features, 1)[0]
+        trending = self.trendings.rank(image_features, 1)[0]
+        movement = self.movements.rank(image_features, 1)[0]
+        flaves = ", ".join(self.flavors.rank(image_features, max_flavors))
+
+        if caption.startswith(medium):
+            prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
+        else:
+            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
+
+        return _truncate_to_fit(prompt)
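A usage sketch for the classic mode, assuming the Interrogator built above and a local image file (hypothetical path); the output always follows the caption/medium/artist/trending/movement scaffold plus up to max_flavors flavor terms:

    from PIL import Image

    img = Image.open('example.jpg').convert('RGB')   # hypothetical input
    print(interrogator.interrogate_classic(img, max_flavors=3))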
+
+    def interrogate_fast(self, image: Image) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        tops = merged.rank(image_features, 32)
+        return _truncate_to_fit(caption + ", " + ", ".join(tops))
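The fast mode makes a single rank() call over one merged table instead of five per-category calls, trading the structured prompt for speed. Note that the merged table is rebuilt on every invocation; caching it on self could be a cheap follow-up. Usage mirrors the classic mode:

    print(interrogator.interrogate_fast(img))   # caption plus the top 32 terms across all tables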
+
+    def interrogate(self, image: Image) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
 
         flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
         best_medium = self.mediums.rank(image_features, 1)[0]
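The default interrogate() is unchanged past this point (the hunk ends in shared context): it pulls a flavor candidate pool flavor_intermediate_count deep (2048 by default) and refines it into the final prompt, making it the most thorough of the three modes. Calling it is unchanged:

    print(interrogator.interrogate(img))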
@@ -258,3 +279,25 @@ class LabelTable():
 
         tops = self._rank(image_features, top_embeds, top_count=top_count)
         return [top_labels[i] for i in tops]
+
+
+def _load_list(data_path, filename) -> List[str]:
+    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
+        items = [line.strip() for line in f.readlines()]
+    return items
+
+
+def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
+    m = LabelTable([], None, None, config)
+    for table in tables:
+        m.labels.extend(table.labels)
+        m.embeds.extend(table.embeds)
+    return m
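_merge_tables starts from an empty LabelTable and concatenates every table's labels and embeds; the two lists stay index-aligned because both are extended in the same order, so one similarity ranking covers all categories at once. A sketch of the idea, assuming the attributes shown in this diff:

    # hypothetical: one combined table replaces several per-category rank() calls
    merged = _merge_tables([interrogator.artists, interrogator.mediums], interrogator.config)
    top10 = merged.rank(interrogator.image_to_features(img), 10)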
+
+
+def _truncate_to_fit(text: str) -> str:
+    while True:
+        try:
+            _ = clip.tokenize([text])
+            return text
+        except:
+            text = ",".join(text.split(",")[:-1])
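_truncate_to_fit relies on clip.tokenize raising a RuntimeError once the text exceeds CLIP's 77-token context, dropping the last comma-separated term each round until the prompt fits; narrowing the bare except to RuntimeError would avoid swallowing unrelated errors, including KeyboardInterrupt. A quick check of the behavior, assuming the openai CLIP package and an over-long prompt:

    import clip

    long_prompt = "a painting, " + ", ".join(["intricate detail"] * 200)
    print(_truncate_to_fit(long_prompt))   # trimmed until clip.tokenize accepts it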