@@ -81,12 +81,12 @@ class Interrogator():
             self.clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(
                 clip_model_name,
                 pretrained=clip_model_pretrained_name,
-                precision='fp16',
+                precision='fp16' if config.device == 'cuda' else 'fp32',
                 device=config.device,
                 jit=False,
                 cache_dir=config.clip_model_path
             )
-            self.clip_model.half().to(config.device).eval()
+            self.clip_model.to(config.device).eval()
         else:
             self.clip_model = config.clip_model
             self.clip_preprocess = config.clip_preprocess
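
This hunk picks the model precision per device instead of hard-coding fp16, and drops the unconditional .half() cast, so the weights stay in fp32 when no CUDA device is present. A minimal standalone sketch of the same pattern, assuming a runtime device probe (the open_clip call and its precision/device/jit arguments come from the hunk above; the model name and pretrained tag here are placeholders, not part of the patch):

    import open_clip
    import torch

    # Sketch: choose precision from the runtime device so CPU-only machines load
    # fp32 weights instead of fp16, which CPU kernels handle poorly or not at all.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-L-14',          # hypothetical model name for illustration
        pretrained='openai', # hypothetical pretrained tag for illustration
        precision='fp16' if device == 'cuda' else 'fp32',
        device=device,
        jit=False,
    )
    model.to(device).eval()  # no .half(): precision is already set by the factory call
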
@@ -256,6 +256,8 @@ class LabelTable():
                     if data.get('hash') == hash:
                         self.labels = data['labels']
                         self.embeds = data['embeds']
+                        if self.device == 'cpu':
+                            self.embeds = [e.astype(np.float32) for e in self.embeds]
             except Exception as e:
                 print(f"Error loading cached table {desc}: {e}")
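
This hunk promotes cached label embeddings to float32 when running on CPU, so later similarity math does not operate on half-precision arrays. A rough sketch of the cast in isolation (the embedding dimension and the stand-in list below are made up for illustration; only the astype cast and the device check mirror the added lines):

    import numpy as np

    device = 'cpu'  # assumed for this sketch

    # Hypothetical stand-in for the unpickled cache: a list of fp16 embedding vectors.
    embeds = [np.zeros(768, dtype=np.float16) for _ in range(4)]

    # Same cast as the added lines: promote to float32 on CPU before any similarity math.
    if device == 'cpu':
        embeds = [e.astype(np.float32) for e in embeds]

    print(embeds[0].dtype)  # float32
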