|
|
|
@ -17,25 +17,29 @@ from torchvision.transforms.functional import InterpolationMode
|
|
|
|
|
from tqdm import tqdm |
|
|
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
|
from safetensors.numpy import load_file, save_file |
|
|
|
|
|
|
|
|
|
BLIP_MODELS = { |
|
|
|
|
"base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth", |
|
|
|
|
"large": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
CACHE_URLS_VITL = [ |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl", |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_flavors.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_negative.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.safetensors', |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
CACHE_URLS_VITH = [ |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl", |
|
|
|
|
"https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl", |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_negative.safetensors', |
|
|
|
|
'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.safetensors', |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -97,7 +101,8 @@ class Interrogator:
|
|
|
|
|
med_config=med_config, |
|
|
|
|
) |
|
|
|
|
blip_model.eval() |
|
|
|
|
blip_model = blip_model.to(config.device) |
|
|
|
|
if not self.config.blip_offload: |
|
|
|
|
blip_model = blip_model.to(config.device) |
|
|
|
|
self.blip_model = blip_model |
|
|
|
|
else: |
|
|
|
|
self.blip_model = config.blip_model |
|
|
|
@ -558,23 +563,8 @@ class LabelTable:
|
|
|
|
|
self.tokenize = tokenize |
|
|
|
|
|
|
|
|
|
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() |
|
|
|
|
|
|
|
|
|
cache_filepath = None |
|
|
|
|
if config.cache_path is not None and desc is not None: |
|
|
|
|
os.makedirs(config.cache_path, exist_ok=True) |
|
|
|
|
sanitized_name = config.clip_model_name.replace("/", "_").replace("@", "_") |
|
|
|
|
cache_filepath = os.path.join( |
|
|
|
|
config.cache_path, f"{sanitized_name}_{desc}.pkl" |
|
|
|
|
) |
|
|
|
|
if desc is not None and os.path.exists(cache_filepath): |
|
|
|
|
with open(cache_filepath, "rb") as f: |
|
|
|
|
try: |
|
|
|
|
data = pickle.load(f) |
|
|
|
|
if data.get("hash") == hash: |
|
|
|
|
self.labels = data["labels"] |
|
|
|
|
self.embeds = data["embeds"] |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"Error loading cached table {desc}: {e}") |
|
|
|
|
sanitized_name = self.config.clip_model_name.replace('/', '_').replace('@', '_') |
|
|
|
|
self._load_cached(desc, hash, sanitized_name) |
|
|
|
|
|
|
|
|
|
if len(self.labels) != len(self.embeds): |
|
|
|
|
self.embeds = [] |
|
|
|
@ -594,28 +584,50 @@ class LabelTable:
|
|
|
|
|
for i in range(text_features.shape[0]): |
|
|
|
|
self.embeds.append(text_features[i]) |
|
|
|
|
|
|
|
|
|
if cache_filepath is not None: |
|
|
|
|
with open(cache_filepath, "wb") as f: |
|
|
|
|
pickle.dump( |
|
|
|
|
{ |
|
|
|
|
"labels": self.labels, |
|
|
|
|
"embeds": self.embeds, |
|
|
|
|
"hash": hash, |
|
|
|
|
"model": config.clip_model_name, |
|
|
|
|
}, |
|
|
|
|
f, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if self.device == "cpu" or self.device == torch.device("cpu"): |
|
|
|
|
if desc and self.config.cache_path: |
|
|
|
|
os.makedirs(self.config.cache_path, exist_ok=True) |
|
|
|
|
cache_filepath = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") |
|
|
|
|
tensors = { |
|
|
|
|
"embeds": np.stack(self.embeds), |
|
|
|
|
"hash": np.array([ord(c) for c in hash], dtype=np.int8) |
|
|
|
|
} |
|
|
|
|
save_file(tensors, cache_filepath) |
|
|
|
|
|
|
|
|
|
if self.device == 'cpu' or self.device == torch.device('cpu'): |
|
|
|
|
self.embeds = [e.astype(np.float32) for e in self.embeds] |
|
|
|
|
|
|
|
|
|
def _rank( |
|
|
|
|
self, |
|
|
|
|
image_features: torch.Tensor, |
|
|
|
|
text_embeds: torch.Tensor, |
|
|
|
|
top_count: int = 1, |
|
|
|
|
reverse: bool = False, |
|
|
|
|
) -> str: |
|
|
|
|
def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool: |
|
|
|
|
if self.config.cache_path is None or desc is None: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
# load from old pkl format if it exists |
|
|
|
|
cached_pkl = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.pkl") |
|
|
|
|
if os.path.exists(cached_pkl): |
|
|
|
|
with open(cached_pkl, 'rb') as f: |
|
|
|
|
try: |
|
|
|
|
data = pickle.load(f) |
|
|
|
|
if data.get('hash') == hash: |
|
|
|
|
self.labels = data['labels'] |
|
|
|
|
self.embeds = data['embeds'] |
|
|
|
|
return True |
|
|
|
|
except Exception as e: |
|
|
|
|
print(f"Error loading cached table {desc}: {e}") |
|
|
|
|
|
|
|
|
|
# load from new safetensors format if it exists |
|
|
|
|
cached_safetensors = os.path.join(self.config.cache_path, f"{sanitized_name}_{desc}.safetensors") |
|
|
|
|
if os.path.exists(cached_safetensors): |
|
|
|
|
tensors = load_file(cached_safetensors) |
|
|
|
|
if 'hash' in tensors and 'embeds' in tensors: |
|
|
|
|
if np.array_equal(tensors['hash'], np.array([ord(c) for c in hash], dtype=np.int8)): |
|
|
|
|
self.embeds = tensors['embeds'] |
|
|
|
|
if len(self.embeds.shape) == 2: |
|
|
|
|
self.embeds = [self.embeds[i] for i in range(self.embeds.shape[0])] |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: |
|
|
|
|
top_count = min(top_count, len(text_embeds)) |
|
|
|
|
text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to( |
|
|
|
|
self.device |
|
|
|
|