Browse Source

Fixes device detection for chooing precision

In `load_clip_model`, it used to check whether a GPU is being used by checking
if `config.device` == "cuda". This is fine, assuming all users will pass a str
for the device. Unfortunately, many users (including the `run_{cli,gradio}.py`
scripts instead pass a `torch.device`, and `torch.device("cuda") != "cuda"`

This commit makes it compare the `device.type` instead, which will be a string,
making this condition pass, and uses float16 when possible.
pull/46/head
bolshoytoster 2 years ago
parent
commit
487f21c7bf
  1. 18
      clip_interrogator/clip_interrogator.py

18
clip_interrogator/clip_interrogator.py

@ -15,7 +15,7 @@ from PIL import Image
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm from tqdm import tqdm
from typing import List from typing import List, Union
BLIP_MODELS = { BLIP_MODELS = {
"base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth", "base": "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth",
@ -64,7 +64,7 @@ class Config:
) )
chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM chunk_size: int = 2048 # batch size for CLIP, use smaller for lower VRAM
data_path: str = os.path.join(os.path.dirname(__file__), "data") data_path: str = os.path.join(os.path.dirname(__file__), "data")
device: str = ( device: Union[str, torch.device] = (
"mps" "mps"
if torch.backends.mps.is_available() if torch.backends.mps.is_available()
else "cuda" else "cuda"
@ -89,6 +89,7 @@ class Interrogator:
blip_path = os.path.dirname(inspect.getfile(blip_decoder)) blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), "configs") configs_path = os.path.join(os.path.dirname(blip_path), "configs")
med_config = os.path.join(configs_path, "med_config.json") med_config = os.path.join(configs_path, "med_config.json")
blip_model = blip_decoder( blip_model = blip_decoder(
pretrained=BLIP_MODELS[config.blip_model_type], pretrained=BLIP_MODELS[config.blip_model_type],
image_size=config.blip_image_eval_size, image_size=config.blip_image_eval_size,
@ -137,7 +138,14 @@ class Interrogator:
) = open_clip.create_model_and_transforms( ) = open_clip.create_model_and_transforms(
clip_model_name, clip_model_name,
pretrained=clip_model_pretrained_name, pretrained=clip_model_pretrained_name,
precision="fp16" if config.device == "cuda" else "fp32", precision="fp16"
if (
config.device.type
if isinstance(config.device, torch.device)
else config.device
)
== "cuda"
else "fp32",
device="cpu", device="cpu",
jit=False, jit=False,
cache_dir=config.clip_model_path, cache_dir=config.clip_model_path,
@ -480,7 +488,9 @@ class Interrogator:
) )
fast_prompt = self._interrogate_fast(caption, image_features, max_flavours) fast_prompt = self._interrogate_fast(caption, image_features, max_flavours)
classic_prompt = self.interrogate_classic(caption, image_features, max_flavours) classic_prompt = self._interrogate_classic(
caption, image_features, max_flavours
)
candidates = [caption, classic_prompt, fast_prompt, best_prompt] candidates = [caption, classic_prompt, fast_prompt, best_prompt]
return candidates[np.argmax(self.similarities(image_features, candidates))] return candidates[np.argmax(self.similarities(image_features, candidates))]

Loading…
Cancel
Save