From 1221871c1b2e56d04c3495cc509f7fe69e813143 Mon Sep 17 00:00:00 2001
From: pharmapsychotic
Date: Fri, 25 Nov 2022 10:33:13 -0600
Subject: [PATCH] Handle differences in how open_clip does prompt truncation,
 run_gradio support for all the open_clip models and --share option.

---
 clip_interrogator/clip_interrogator.py | 32 ++++++++++++++------------
 run_gradio.py                          | 16 +++++++++----
 2 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py
index 1079c6c..d58b273 100644
--- a/clip_interrogator/clip_interrogator.py
+++ b/clip_interrogator/clip_interrogator.py
@@ -187,15 +187,13 @@ class Interrogator():
 
         extended_flavors = set(flaves)
         for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
-            try:
-                best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
-                flave = best[len(best_prompt)+2:]
-                if not check(flave):
-                    break
-                extended_flavors.remove(flave)
-            except:
-                # exceeded max prompt length
+            best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
+            flave = best[len(best_prompt)+2:]
+            if not check(flave):
                 break
+            if _prompt_at_max_len(best_prompt, self.tokenize):
+                break
+            extended_flavors.remove(flave)
 
         return best_prompt
 
@@ -306,11 +304,15 @@ def _merge_tables(tables: List[LabelTable], config: Config) -> LabelTable:
         m.embeds.extend(table.embeds)
     return m
 
+def _prompt_at_max_len(text: str, tokenize) -> bool:
+    tokens = tokenize([text])
+    return tokens[0][-1] != 0
+
 def _truncate_to_fit(text: str, tokenize) -> str:
-    while True:
-        try:
-            _ = tokenize([text])
-            return text
-        except:
-            text = ",".join(text.split(",")[:-1])
-            
\ No newline at end of file
+    parts = text.split(', ')
+    new_text = parts[0]
+    for part in parts[1:]:
+        if _prompt_at_max_len(new_text + part, tokenize):
+            break
+        new_text += ', ' + part
+    return new_text
diff --git a/run_gradio.py b/run_gradio.py
index 3d92498..a93ea5e 100755
--- a/run_gradio.py
+++ b/run_gradio.py
@@ -1,8 +1,13 @@
 #!/usr/bin/env python3
-import clip
+import argparse
 import gradio as gr
+import open_clip
 from clip_interrogator import Interrogator, Config
 
+parser = argparse.ArgumentParser()
+parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
+args = parser.parse_args()
+
 ci = Interrogator(Config())
 
 def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
@@ -19,11 +24,13 @@ def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
         return ci.interrogate_classic(image)
     else:
         return ci.interrogate_fast(image)
-    
+
+models = ['/'.join(x) for x in open_clip.list_pretrained()]
+
 inputs = [
     gr.inputs.Image(type='pil'),
     gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
-    gr.Dropdown(clip.available_models(), value='ViT-L/14', label='CLIP Model'),
+    gr.Dropdown(models, value='ViT-H-14/laion2b_s32b_b79k', label='CLIP Model'),
     gr.Number(value=32, label='Caption Max Length'),
     gr.Number(value=64, label='Caption Num Beams'),
 ]
@@ -38,4 +45,5 @@ io = gr.Interface(
     title="🕵️‍♂️ CLIP Interrogator 🕵️‍♂️",
     allow_flagging=False,
 )
-io.launch()
+io.launch(share=args.share)
+
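
Note on the new _prompt_at_max_len helper: unlike the original OpenAI clip.tokenize, which
raises an error when a prompt exceeds the context window, open_clip's tokenize silently
truncates and zero-pads the result to 77 tokens, so overflow has to be detected by checking
whether the last token slot is still zero. A minimal sketch of that idea (the over-long
sample prompt is made up for illustration and is not part of the patch):

    import open_clip

    # open_clip.tokenize returns a [batch, 77] tensor, zero-padded on the right
    tokens = open_clip.tokenize(["a photo of a cat, " * 40])  # deliberately over-long prompt
    at_max_len = tokens[0][-1] != 0  # non-zero last slot => prompt filled the context window
    print(bool(at_max_len))          # True here; a short prompt would print False

With the run_gradio.py changes, the demo can be exposed publicly by running
python3 run_gradio.py --share (or -s), which is passed through to io.launch(share=...).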