
Fixes device detection for choosing precision

In `load_clip_model`, GPU detection used to check whether `config.device`
== "cuda". This is fine as long as every user passes a str for the device.
Unfortunately, many users (including the `run_{cli,gradio}.py` scripts)
instead pass a `torch.device`, and `torch.device("cuda") != "cuda"`.

This commit compares the device's `.type` instead, which is always a
string, so the condition passes and float16 is used when possible.
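A minimal repro of the mismatch, using nothing but PyTorch (the comparison
behaviour shown is the one this commit message asserts):

    import torch

    device = torch.device("cuda")

    # A torch.device does not compare equal to the plain string spelling
    # of itself, so the old `config.device == "cuda"` check silently
    # fell back to fp32 on GPU setups that passed a torch.device.
    print(device == "cuda")        # False
    print(device.type)             # "cuda" -- .type is a plain str
    print(device.type == "cuda")   # True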
Branch: pull/46/head
bolshoytoster committed 2 years ago · commit 487f21c7bf
1 changed file: clip_interrogator/clip_interrogator.py (18 changed lines)
@@ -15,7 +15,7 @@ from PIL import Image
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
-from typing import List
+from typing import List, Union
 
 BLIP_MODELS = {
     "base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
@@ -64,7 +64,7 @@ class Config:
     )
     chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
     data_path: str = os.path.join(os.path.dirname(__file__), "data")
-    device: str = (
+    device: Union[str, torch.device] = (
         "mps"
         if torch.backends.mps.is_available()
         else "cuda"
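With the widened annotation, both spellings of a device are accepted. A
hypothetical usage sketch (the `Interrogator(Config(...))` call mirrors the
library's usual entry point; the exact kwargs are illustrative):

    import torch
    from clip_interrogator import Config, Interrogator

    # Either of these is now consistent with Config.device's type hint:
    ci = Interrogator(Config(device="cuda"))
    ci = Interrogator(Config(device=torch.device("cuda:0")))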
@@ -89,6 +89,7 @@ class Interrogator:
         blip_path = os.path.dirname(inspect.getfile(blip_decoder))
         configs_path = os.path.join(os.path.dirname(blip_path), "configs")
         med_config = os.path.join(configs_path, "med_config.json")
+
         blip_model = blip_decoder(
             pretrained=BLIP_MODELS[config.blip_model_type],
             image_size=config.blip_image_eval_size,
@@ -137,7 +138,14 @@ class Interrogator:
         ) = open_clip.create_model_and_transforms(
             clip_model_name,
             pretrained=clip_model_pretrained_name,
-            precision="fp16" if config.device == "cuda" else "fp32",
+            precision="fp16"
+            if (
+                config.device.type
+                if isinstance(config.device, torch.device)
+                else config.device
+            )
+            == "cuda"
+            else "fp32",
             device="cpu",
             jit=False,
             cache_dir=config.clip_model_path,
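The new inline conditional just normalizes either accepted type to its
string name before comparing. An equivalent standalone sketch (the helper
name `_device_type` is hypothetical, not part of this commit):

    def _device_type(device):
        # torch.device carries its backend name as a str in .type;
        # a plain str already is that name.
        return device.type if isinstance(device, torch.device) else device

    precision = "fp16" if _device_type(config.device) == "cuda" else "fp32"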
@@ -480,7 +488,9 @@ class Interrogator:
         )
 
         fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
-        classic_prompt = self.interrogate_classic(caption, image_features, max_flavours)
+        classic_prompt = self._interrogate_classic(
+            caption, image_features, max_flavours
+        )
         candidates = [caption, classic_prompt, fast_prompt, best_prompt]
         return candidates[np.argmax(self.similarities(image_features, candidates))]
 
