
Gradio version plus classic and fast modes

pharmapsychotic committed 2 years ago
commit b62cca2097
Changed files:
  1. clip_interrogator/clip_interrogator.py (65 lines changed)
  2. clip_interrogator/data/flavors.txt (2 lines changed)
  3. run_cli.py (15 lines changed)
  4. run_gradio.py (41 lines changed)

clip_interrogator/clip_interrogator.py (65 lines changed)

@@ -25,9 +25,9 @@ class Config:
     # blip settings
     blip_image_eval_size: int = 384
-    blip_max_length: int = 20
+    blip_max_length: int = 32
     blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
-    blip_num_beams: int = 3
+    blip_num_beams: int = 8
 
     # clip settings
     clip_model_name: str = 'ViT-L/14'
@@ -40,12 +40,6 @@ class Config:
     flavor_intermediate_count: int = 2048
 
-def _load_list(data_path, filename) -> List[str]:
-    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
-        items = [line.strip() for line in f.readlines()]
-    return items
-
 class Interrogator():
     def __init__(self, config: Config):
         self.config = config
@@ -110,13 +104,40 @@ class Interrogator():
         )
         return caption[0]
 
-    def interrogate(self, image: Image) -> str:
-        caption = self.generate_caption(image)
-
+    def image_to_features(self, image: Image) -> torch.Tensor:
         images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
         with torch.no_grad():
             image_features = self.clip_model.encode_image(images).float()
         image_features /= image_features.norm(dim=-1, keepdim=True)
+        return image_features
+
+    def interrogate_classic(self, image: Image, max_flaves: int=3) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
+
+        medium = self.mediums.rank(image_features, 1)[0]
+        artist = self.artists.rank(image_features, 1)[0]
+        trending = self.trendings.rank(image_features, 1)[0]
+        movement = self.movements.rank(image_features, 1)[0]
+        flaves = ", ".join(self.flavors.rank(image_features, max_flaves))
+
+        if caption.startswith(medium):
+            prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
+        else:
+            prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
+
+        return _truncate_to_fit(prompt)
+
+    def interrogate_fast(self, image: Image) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
+        merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
+        tops = merged.rank(image_features, 32)
+        return _truncate_to_fit(caption + ", " + ", ".join(tops))
+
+    def interrogate(self, image: Image) -> str:
+        caption = self.generate_caption(image)
+        image_features = self.image_to_features(image)
 
         flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
         best_medium = self.mediums.rank(image_features, 1)[0]
@@ -258,3 +279,25 @@ class LabelTable():
         tops = self._rank(image_features, top_embeds, top_count=top_count)
         return [top_labels[i] for i in tops]
+
+
+def _load_list(data_path, filename) -> List[str]:
+    with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
+        items = [line.strip() for line in f.readlines()]
+    return items
+
+def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
+    m = LabelTable([], None, None, config)
+    for table in tables:
+        m.labels.extend(table.labels)
+        m.embeds.extend(table.embeds)
+    return m
+
+def _truncate_to_fit(text: str) -> str:
+    while True:
+        try:
+            _ = clip.tokenize([text])
+            return text
+        except:
+            text = ",".join(text.split(",")[:-1])

clip_interrogator/data/flavors.txt (2 lines changed)

@@ -44137,7 +44137,7 @@ bionic cyborg implants
 venetian mask
 renaissance mural
 digital art of an elegant
-beautifull cyberpunk woman model
+beautiful cyberpunk woman model
 cute face. dark fantasy
 terminator tech
 seasonal

run_cli.py (15 lines changed)

@@ -1,19 +1,16 @@
 #!/usr/bin/env python3
 import argparse
 import clip
 import requests
 import torch
 from PIL import Image
 from clip_interrogator import Interrogator, Config
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--image', help='image file or url')
     parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
+    parser.add_argument('-i', '--image', help='image file or url')
+    parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
     args = parser.parse_args()
 
     if not args.image:
if not args.image: if not args.image:
@@ -40,7 +37,13 @@ def main():
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     config = Config(device=device, clip_model_name=args.clip)
     ci = Interrogator(config)
-    print(ci.interrogate(image))
+    if args.mode == 'best':
+        prompt = ci.interrogate(image)
+    elif args.mode == 'classic':
+        prompt = ci.interrogate_classic(image)
+    else:
+        prompt = ci.interrogate_fast(image)
+    print(prompt)
 
 if __name__ == "__main__":
     main()
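
With the new -m/--mode flag, the CLI dispatches to the matching method; note that any value other than best or classic falls through to interrogate_fast. Assuming the script is run from the repo root, usage looks like `python3 run_cli.py -i example.jpg -m classic`, or with a URL, `python3 run_cli.py -i https://example.com/photo.jpg` (both the path and URL here are placeholders; -i accepts either, with URLs presumably fetched via the requests import).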

run_gradio.py (41 lines changed)

@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+import clip
+import gradio as gr
+from clip_interrogator import Interrogator, Config
+
+ci = Interrogator(Config())
+
+def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
+    global ci
+    if clip_model_name != ci.config.clip_model_name:
+        ci = Interrogator(Config(clip_model_name=clip_model_name))
+    ci.config.blip_max_length = int(blip_max_length)
+    ci.config.blip_num_beams = int(blip_num_beams)
+
+    image = image.convert('RGB')
+    if mode == 'best':
+        return ci.interrogate(image)
+    elif mode == 'classic':
+        return ci.interrogate_classic(image)
+    else:
+        return ci.interrogate_fast(image)
+
+inputs = [
+    gr.inputs.Image(type='pil'),
+    gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
+    gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
+    gr.Number(value=32, label='Caption Max Length'),
+    gr.Number(value=64, label='Caption Num Beams'),
+]
+outputs = [
+    gr.outputs.Textbox(label="Output"),
+]
+
+io = gr.Interface(
+    inference,
+    inputs,
+    outputs,
+    title="🕵 CLIP Interrogator 🕵",
+    allow_flagging=False,
+)
+io.launch()
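
The Gradio app reuses the same mode dispatch: the Radio picks the mode, the Dropdown swaps CLIP models (rebuilding the Interrogator only when the selection actually changes), and the two Number fields override blip_max_length and blip_num_beams per request. It launches with `python3 run_gradio.py` and serves on the local URL that io.launch() prints. The gr.inputs.Image / gr.outputs.Textbox names are the legacy Gradio component namespace; as a sketch only (an assumption about Gradio 3.x, not part of this commit), the same component list under the newer top-level API would look like:

    import clip
    import gradio as gr

    # legacy gr.inputs.* / gr.outputs.* become top-level components in Gradio 3.x
    inputs = [
        gr.Image(type='pil'),
        gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
        gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
        gr.Number(value=32, label='Caption Max Length'),
        gr.Number(value=64, label='Caption Num Beams'),
    ]
    outputs = gr.Textbox(label='Output')
    # allow_flagging=False also becomes allow_flagging='never' in the newer API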