@@ -5,6 +5,7 @@ import numpy as np
 import open_clip
 import os
 import pickle
+import time
 import torch
 
 from dataclasses import dataclass
@@ -32,6 +33,7 @@ class Config:
 
     # clip settings
     clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
+    clip_model_path: str = None
 
     # interrogator settings
     cache_path: str = 'cache'
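
The `clip_model_path` field added above feeds straight into `create_model_and_transforms(..., cache_dir=...)` in the hunk below, so open_clip resolves the ViT-H-14 weights from a local directory instead of downloading them on first use. A minimal sketch of how a caller might set it (the directory name is a placeholder, not from this diff):

```python
from clip_interrogator import Config, Interrogator

# Point open_clip's cache_dir at a pre-downloaded weights folder so the
# laion2b ViT-H-14 checkpoint is read from disk rather than fetched.
config = Config(
    clip_model_name='ViT-H-14/laion2b_s32b_b79k',
    clip_model_path='models/clip',  # placeholder: any local cache dir
)
ci = Interrogator(config)
```
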
@@ -65,12 +67,25 @@ class Interrogator():
         else:
             self.blip_model = config.blip_model
 
+        self.load_clip_model()
+
+    def load_clip_model(self):
+        start_time = time.time()
+        config = self.config
+
         if config.clip_model is None:
             if not config.quiet:
                 print("Loading CLIP model...")
 
             clip_model_name, clip_model_pretrained_name = config.clip_model_name.split('/', 2)
-            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_model_name, pretrained=clip_model_pretrained_name)
+            self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
+                clip_model_name,
+                pretrained=clip_model_pretrained_name,
+                precision='fp16',
+                device=config.device,
+                jit=False,
+                cache_dir=config.clip_model_path
+            )
             self.clip_model.half().to(config.device).eval()
         else:
             self.clip_model = config.clip_model
@@ -93,6 +108,10 @@ class Interrogator():
         self.movements = LabelTable(_load_list(config.data_path, 'movements.txt'), "movements", self.clip_model, self.tokenize, config)
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
 
+        end_time = time.time()
+        if not config.quiet:
+            print(f"Loaded CLIP model and data in {end_time-start_time:.2f} seconds.")
+
     def generate_caption(self, pil_image: Image) -> str:
         if self.config.blip_offload:
             self.blip_model = self.blip_model.to(self.device)
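
Because the refactor moves all CLIP and LabelTable setup into `load_clip_model()`, the CLIP model can be swapped on a live `Interrogator` by editing its config and calling the method again. A hedged sketch of that pattern, reusing `ci` from the example above ('ViT-L-14/openai' is one valid open_clip name/pretrained pair; any pair the library ships would do):

```python
# Switch to a smaller CLIP model without rebuilding the Interrogator.
# load_clip_model() re-reads self.config, reloads the model in fp16 on
# config.device, and rebuilds the label tables against the new embeddings.
ci.config.clip_model_name = 'ViT-L-14/openai'
ci.config.clip_model = None  # ensure the open_clip loading branch runs
ci.load_clip_model()
```
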