@@ -28,6 +28,7 @@ class Config:
     blip_max_length: int = 32
     blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
     blip_num_beams: int = 8
+    blip_offload: bool = False
 
     # clip settings
     clip_model_name: str = 'ViT-H-14/laion2b_s32b_b79k'
@@ -93,6 +94,8 @@ class Interrogator():
         self.trendings = LabelTable(trending_list, "trendings", self.clip_model, self.tokenize, config)
 
     def generate_caption(self, pil_image: Image) -> str:
+        if self.config.blip_offload:
+            self.blip_model = self.blip_model.to(self.device)
         size = self.config.blip_image_eval_size
         gpu_image = transforms.Compose([
             transforms.Resize((size, size), interpolation=InterpolationMode.BICUBIC),
@@ -108,6 +111,8 @@ class Interrogator():
                 max_length=self.config.blip_max_length,
                 min_length=5
             )
+        if self.config.blip_offload:
+            self.blip_model = self.blip_model.to("cpu")
         return caption[0]
 
     def image_to_features(self, image: Image) -> torch.Tensor:
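For context, a minimal usage sketch of the `blip_offload` option introduced above, assuming the package's public `Config`/`Interrogator` API; the image path is only illustrative:

```python
from PIL import Image
from clip_interrogator import Config, Interrogator

# With blip_offload=True, the BLIP captioning model is moved to the GPU only
# for the duration of generate_caption() and back to the CPU afterwards,
# per the diff above, leaving more GPU memory for the CLIP model.
config = Config(blip_offload=True)
ci = Interrogator(config)

# "example.jpg" is a placeholder path for illustration only.
image = Image.open("example.jpg").convert("RGB")
print(ci.generate_caption(image))
```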