
Handle differences in how open_clip does prompt truncation; add run_gradio support for all the open_clip models and a --share option.

pull/18/head
pharmapsychotic committed 2 years ago
commit 1221871c1b
Changed files:
  clip_interrogator/clip_interrogator.py (32 lines changed)
  run_gradio.py (16 lines changed)
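
Background on the truncation fix, as a minimal sketch (assumes both the OpenAI clip package and open_clip are installed; the overlong prompt is invented for illustration). OpenAI's clip.tokenize raises a RuntimeError when a prompt exceeds the 77-token context length, which the old code caught with a bare except; open_clip's tokenizer truncates silently instead, so that exception never fires and the length check has to be made explicit:

import clip        # OpenAI CLIP
import open_clip

too_long = ", ".join(["a very ornate detailed flavor"] * 40)  # far past 77 tokens

try:
    clip.tokenize([too_long])                # raises: input is too long for context length 77
except RuntimeError as err:
    print("clip raised:", err)

tokens = open_clip.tokenize([too_long])      # no exception: silently truncated to 77 tokens
print(tokens.shape)                          # torch.Size([1, 77])
print(bool(tokens[0][-1] != 0))              # True: last slot occupied, prompt hit the limit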

clip_interrogator/clip_interrogator.py

@@ -187,15 +187,13 @@ class Interrogator():
         extended_flavors = set(flaves)
         for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
-            try:
-                best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
-                flave = best[len(best_prompt)+2:]
-                if not check(flave):
-                    break
-                extended_flavors.remove(flave)
-            except:
-                # exceeded max prompt length
+            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
                 break
+            if _prompt_at_max_len(best_prompt, self.tokenize):
+                break
+            extended_flavors.remove(flave)
 
         return best_prompt
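
Why the explicit _prompt_at_max_len check replaces the old try/except, in a small demonstration (prompt strings invented): once a prompt fills open_clip's 77-token window, any further ", flavor" suffix is silently cut off, so every candidate tokenizes to the same tensor and rank_top can no longer tell them apart; the old code expected an exception here that open_clip never raises.

import open_clip
import torch

full = ", ".join(["intricate"] * 80)                    # already beyond the 77-token context
a = open_clip.tokenize([full + ", dramatic lighting"])
b = open_clip.tokenize([full + ", octane render"])
print(torch.equal(a, b))                                # True: both truncate to identical tokens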
@@ -306,11 +304,15 @@ def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
         m.embeds.extend(table.embeds)
     return m
 
+def _prompt_at_max_len(text: str, tokenize) -> bool:
+    tokens = tokenize([text])
+    return tokens[0][-1] != 0
+
 def _truncate_to_fit(text: str, tokenize) -> str:
-    while True:
-        try:
-            _ = tokenize([text])
-            return text
-        except:
-            text = ",".join(text.split(",")[:-1])
+    parts = text.split(', ')
+    new_text = parts[0]
+    for part in parts[1:]:
+        if _prompt_at_max_len(new_text + part, tokenize):
+            break
+        new_text += ', ' + part
+    return new_text
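
Usage sketch for the two new helpers (assumes _prompt_at_max_len and _truncate_to_fit are in scope, e.g. pasted from the file above; example prompts invented). open_clip pads unused context slots with 0 after the end-of-text token, so a nonzero final token means the prompt occupies the full 77-token window; _truncate_to_fit keeps whole ", "-separated terms instead of letting the tokenizer cut mid-prompt:

from open_clip import tokenize

short = "a painting of a cat"
long_prompt = ", ".join(["masterpiece"] * 60)

print(_prompt_at_max_len(short, tokenize))        # False: trailing pad tokens are 0
print(_prompt_at_max_len(long_prompt, tokenize))  # True: the window is completely filled
print(_truncate_to_fit(long_prompt, tokenize))    # longest ", "-joined prefix that still fits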

run_gradio.py

@@ -1,8 +1,13 @@
 #!/usr/bin/env python3
-import clip
+import argparse
 import gradio as gr
+import open_clip
 from clip_interrogator import Interrogator, Config
 
+parser = argparse.ArgumentParser()
+parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
+args = parser.parse_args()
+
 ci = Interrogator(Config())
 
 def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
@@ -19,11 +24,13 @@ def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
         return ci.interrogate_classic(image)
     else:
         return ci.interrogate_fast(image)
 
+models = ['/'.join(x) for x in open_clip.list_pretrained()]
+
 inputs = [
     gr.inputs.Image(type='pil'),
     gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
-    gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
+    gr.Dropdown(models, value='ViT-H-14/laion2b_s32b_b79k', label='CLIP Model'),
     gr.Number(value=32, label='Caption Max Length'),
     gr.Number(value=64, label='Caption Num Beams'),
 ]
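
For reference, open_clip.list_pretrained() returns (architecture, pretrained_tag) tuples, and the '/'.join above turns each pair into the name/tag string shown in the dropdown (sample output; exact entries vary by open_clip version):

import open_clip

pairs = open_clip.list_pretrained()
print(pairs[:2])     # e.g. [('RN50', 'openai'), ('RN50', 'yfcc15m')]
models = ['/'.join(x) for x in pairs]
print(models[:2])    # e.g. ['RN50/openai', 'RN50/yfcc15m']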
@@ -38,4 +45,5 @@ io = gr.Interface(
     title="🕵 CLIP Interrogator 🕵",
     allow_flagging=False,
 )
-io.launch()
+
+io.launch(share=args.share)
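
Usage note: running python run_gradio.py --share makes gradio print a public shareable URL in addition to the local one; without the flag, args.share is False and launch() behaves as before.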
