Browse Source

Gradio version plus classic and fast modes

replicate
pharmapsychotic 2 years ago
parent
commit
b62cca2097
  1. 65
      clip_interrogator/clip_interrogator.py
  2. 2
      clip_interrogator/data/flavors.txt
  3. 15
      run_cli.py
  4. 41
      run_gradio.py

65
clip_interrogator/clip_interrogator.py

@ -25,9 +25,9 @@ class Config:
# blip settings
blip_image_eval_size: int = 384
blip_max_length: int = 20
blip_max_length: int = 32
blip_model_url: str = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_num_beams: int = 3
blip_num_beams: int = 8
# clip settings
clip_model_name: str = 'ViT-L/14'
@ -40,12 +40,6 @@ class Config:
flavor_intermediate_count: int = 2048
def _load_list(data_path, filename) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
class Interrogator():
def __init__(self, config: Config):
self.config = config
@ -110,13 +104,40 @@ class Interrogator():
)
return caption[0]
def interrogate(self, image: Image) -> str:
caption = self.generate_caption(image)
def image_to_features(self, image: Image) -> torch.Tensor:
images = self.clip_preprocess(image).unsqueeze(0).to(self.device)
with torch.no_grad():
image_features = self.clip_model.encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features
def interrogate_classic(self, image: Image, max_flaves: int=3) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
medium = self.mediums.rank(image_features, 1)[0]
artist = self.artists.rank(image_features, 1)[0]
trending = self.trendings.rank(image_features, 1)[0]
movement = self.movements.rank(image_features, 1)[0]
flaves = ", ".join(self.flavors.rank(image_features, max_flaves))
if caption.startswith(medium):
prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
else:
prompt = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
return _truncate_to_fit(prompt)
def interrogate_fast(self, image: Image) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, 32)
return _truncate_to_fit(caption + ", " + ", ".join(tops))
def interrogate(self, image: Image) -> str:
caption = self.generate_caption(image)
image_features = self.image_to_features(image)
flaves = self.flavors.rank(image_features, self.config.flavor_intermediate_count)
best_medium = self.mediums.rank(image_features, 1)[0]
@ -258,3 +279,25 @@ class LabelTable():
tops = self._rank(image_features, top_embeds, top_count=top_count)
return [top_labels[i] for i in tops]
def _load_list(data_path, filename) -> List[str]:
with open(os.path.join(data_path, filename), 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
m = LabelTable([], None, None, config)
for table in tables:
m.labels.extend(table.labels)
m.embeds.extend(table.embeds)
return m
def _truncate_to_fit(text: str) -> str:
while True:
try:
_ = clip.tokenize([text])
return text
except:
text = ",".join(text.split(",")[:-1])

2
clip_interrogator/data/flavors.txt

@ -44137,7 +44137,7 @@ bionic cyborg implants
venetian mask
renaissance mural
digital art of an elegant
beautifull cyberpunk woman model
beautiful cyberpunk woman model
cute face. dark fantasy
terminator tech
seasonal

15
run_cli.py

@ -1,19 +1,16 @@
#!/usr/bin/env python3
import argparse
import clip
import requests
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
args = parser.parse_args()
if not args.image:
@ -40,7 +37,13 @@ def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Config(device=device, clip_model_name=args.clip)
ci = Interrogator(config)
print(ci.interrogate(image))
if args.mode == 'best':
prompt = ci.interrogate(image)
elif args.mode == 'classic':
prompt = ci.interrogate_classic(image)
else:
prompt = ci.interrogate_fast(image)
print(prompt)
if __name__ == "__main__":
main()

41
run_gradio.py

@ -0,0 +1,41 @@
#!/usr/bin/env python3
import clip
import gradio as gr
from clip_interrogator import Interrogator, Config
ci = Interrogator(Config())
def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
global ci
if clip_model_name != ci.config.clip_model_name:
ci = Interrogator(Config(clip_model_name=clip_model_name))
ci.config.blip_max_length = int(blip_max_length)
ci.config.blip_num_beams = int(blip_num_beams)
image = image.convert('RGB')
if mode == 'best':
return ci.interrogate(image)
elif mode == 'classic':
return ci.interrogate_classic(image)
else:
return ci.interrogate_fast(image)
inputs = [
gr.inputs.Image(type='pil'),
gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
gr.Number(value=32, label='Caption Max Length'),
gr.Number(value=64, label='Caption Num Beams'),
]
outputs = [
gr.outputs.Textbox(label="Output"),
]
io = gr.Interface(
inference,
inputs,
outputs,
title="🕵 CLIP Interrogator 🕵",
allow_flagging=False,
)
io.launch()
Loading…
Cancel
Save