@@ -17,25 +17,27 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from typing import List
 
+from safetensors.numpy import load_file, save_file
+
 BLIP_MODELS = {
     'base': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
     'large': 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
 }
 
 CACHE_URLS_VITL = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors',
 ]
 
 CACHE_URLS_VITH = [
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
-    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors',
+    'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors',
 ]
 
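
This first hunk points the prebuilt embedding caches at .safetensors files instead of .pkl. A minimal sketch of fetching and opening one of the published cache files, using plain urllib rather than whatever download helper the project itself uses (`cache_dir` is illustrative, not the project's configuration):

```python
import os
import urllib.request
from safetensors.numpy import load_file

# Hypothetical cache directory; the project reads config.cache_path instead.
cache_dir = "cache"
os.makedirs(cache_dir, exist_ok=True)

url = "https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors"
filepath = os.path.join(cache_dir, url.split("/")[-1])
if not os.path.exists(filepath):
    urllib.request.urlretrieve(url, filepath)

tensors = load_file(filepath)   # dict of named numpy arrays
print(tensors["embeds"].shape)  # one embedding row per label
```

Unlike pickle, safetensors files cannot execute arbitrary code on load, which is the motivation for the migration.
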
@@ -316,21 +318,8 @@ class LabelTable():
         self.tokenize = tokenize
 
         hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
-
-        cache_filepath = None
-        if config.cache_path is not None and desc is not None:
-            os.makedirs(config.cache_path, exist_ok=True)
-            sanitized_name = config.clip_model_name.replace('/', '_').replace('@', '_')
-            cache_filepath = os.path.join(config.cache_path, f"{sanitized_name}_{desc}.pkl")
-            if desc is not None and os.path.exists(cache_filepath):
-                with open(cache_filepath, 'rb') as f:
-                    try:
-                        data = pickle.load(f)
-                        if data.get('hash') == hash:
-                            self.labels = data['labels']
-                            self.embeds = data['embeds']
-                    except Exception as e:
-                        print(f"Error loading cached table {desc}: {e}")
+        sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_')
+        self._load_cached(desc, hash, sanitized_name)
 
         if len(self.labels) != len(self.embeds):
             self.embeds = []
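
The constructor now delegates the inline pickle-loading block to a single `_load_cached` call, keyed on a SHA-256 digest of the comma-joined labels so that any edit to the label list invalidates the cached embeddings. A minimal sketch of that invalidation idea in isolation (the function name here is illustrative, not the project's):

```python
import hashlib

def labels_digest(labels: list[str]) -> str:
    # Same keying scheme as the diff: hash the comma-joined label strings.
    return hashlib.sha256(",".join(labels).encode()).hexdigest()

cached_hash = labels_digest(["impressionism", "cubism"])
current_hash = labels_digest(["impressionism", "cubism", "fauvism"])
assert cached_hash != current_hash  # edited labels force a rebuild
```
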
@@ -344,18 +333,49 @@ class LabelTable():
                 for i in range(text_features.shape[0]):
                     self.embeds.append(text_features[i])
 
-            if cache_filepath is not None:
-                with open(cache_filepath, 'wb') as f:
-                    pickle.dump({
-                        "labels": self.labels,
-                        "embeds": self.embeds,
-                        "hash": hash,
-                        "model": config.clip_model_name
-                    }, f)
+            if desc and self.config.cache_path:
+                os.makedirs(self.config.cache_path, exist_ok=True)
+                cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
+                tensors = {
+                    "embeds": np.stack(self.embeds),
+                    "hash": np.array([ord(c) for c in hash], dtype=np.int8)
+                }
+                save_file(tensors, cache_filepath)
 
         if self.device == 'cpu' or self.device == torch.device('cpu'):
             self.embeds = [e.astype(np.float32) for e in self.embeds]
 
+    def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool:
+        if self.config.cache_path is None or desc is None:
+            return False
+
+        # load from old pkl format if it exists
+        cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl")
+        if os.path.exists(cached_pkl):
+            with open(cached_pkl, 'rb') as f:
+                try:
+                    data = pickle.load(f)
+                    if data.get('hash') == hash:
+                        self.labels = data['labels']
+                        self.embeds = data['embeds']
+                        return True
+                except Exception as e:
+                    print(f"Error loading cached table {desc}: {e}")
+
+        # load from new safetensors format if it exists
+        cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors")
+        if os.path.exists(cached_safetensors):
+            tensors = load_file(cached_safetensors)
+            if 'hash' in tensors and 'embeds' in tensors:
+                if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)):
+                    self.embeds = tensors['embeds']
+                    if len(self.embeds.shape) == 2:
+                        self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])]
+                    return True
+
+        return False
+
     def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str:
         top_count = min(top_count, len(text_embeds))
         text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device)
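
Since safetensors stores named tensors only, the new save path encodes the hex digest as an int8 array of character codes and stacks the per-label embeddings into one 2-D array; `_load_cached` reverses both steps. A self-contained sketch of that round-trip under the same scheme (file name and shapes are illustrative):

```python
import hashlib
import numpy as np
from safetensors.numpy import load_file, save_file

labels = ["oil painting", "watercolor"]  # stand-in label table
hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
embeds = [np.ones(768, dtype=np.float16) * i for i in range(len(labels))]

# Save: stack the list of 1-D embeddings; encode the digest as int8 codes.
save_file({
    "embeds": np.stack(embeds),
    "hash": np.array([ord(c) for c in hash], dtype=np.int8),
}, "example_table.safetensors")

# Load: verify the digest, then unstack the 2-D array back into a list.
tensors = load_file("example_table.safetensors")
assert np.array_equal(tensors["hash"], np.array([ord(c) for c in hash], dtype=np.int8))
restored = [tensors["embeds"][i] for i in range(tensors["embeds"].shape[0])]
assert all(np.array_equal(a, b) for a, b in zip(restored, embeds))
```

Note that `_load_cached` still tries the legacy .pkl file first, so existing local caches keep working; only newly built tables are written in the safetensors format.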