#!/usr/bin/env python3
"""Gradio web UI for CLIP Interrogator.

Builds a single Interrogator at startup and serves a form that takes an
image, an interrogation mode, a CLIP model name and two caption settings,
and shows the text returned by the interrogator.
"""
import argparse

import gradio as gr
import open_clip
import torch

from clip_interrogator import Interrogator, Config

parser = argparse.ArgumentParser()
parser.add_argument('-s', '--share', action='store_true', help='Create a public link')
args = parser.parse_args()

if not torch.cuda.is_available():
    print("CUDA is not available, using CPU. Warning: this will be very slow!")

# One shared interrogator for every request; both the cache and the CLIP
# model files are kept under ./cache.
ci = Interrogator(Config(cache_path="cache", clip_model_path="cache"))


def inference(image, mode, clip_model_name, blip_max_length, blip_num_beams):
    """Run the interrogator on one image and return its result.

    Args:
        image: Image supplied by the Gradio image input (configured with
            type='pil'); None when the form is submitted without an image.
        mode: 'best' or 'classic'; any other value selects the fast mode.
        clip_model_name: CLIP model identifier chosen in the dropdown.
        blip_max_length: Caption max length; coerced to int.
        blip_num_beams: Caption beam count; coerced to int.

    Returns:
        Whatever the selected ``ci.interrogate*`` method returns.

    Raises:
        ValueError: If no image was provided.
    """
    # Fail with a readable message instead of an AttributeError on None.
    if image is None:
        raise ValueError("No image provided: please upload an image first.")

    # Only reload the CLIP model when the selection actually changed.
    if clip_model_name != ci.config.clip_model_name:
        ci.config.clip_model_name = clip_model_name
        ci.load_clip_model()

    # gr.Number may deliver floats; the config fields are set as ints.
    ci.config.blip_max_length = int(blip_max_length)
    ci.config.blip_num_beams = int(blip_num_beams)

    image = image.convert('RGB')
    if mode == 'best':
        return ci.interrogate(image)
    if mode == 'classic':
        return ci.interrogate_classic(image)
    return ci.interrogate_fast(image)


# Each entry from open_clip.list_pretrained() is joined with '/' to form the
# dropdown choices (the default below follows the same "<a>/<b>" form).
models = ['/'.join(x) for x in open_clip.list_pretrained()]

# Top-level component classes are used throughout; the gr.inputs / gr.outputs
# namespaces are deprecated and were mixed with the top-level API before.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(['best', 'classic', 'fast'], label='Mode', value='best'),
    gr.Dropdown(models, value='ViT-L-14/openai', label='CLIP Model'),
    gr.Number(value=32, label='Caption Max Length'),
    gr.Number(value=64, label='Caption Num Beams'),
]
outputs = [
    gr.Textbox(label="Output"),
]

io = gr.Interface(
    inference,
    inputs,
    outputs,
    title="🕵️‍♂️ CLIP Interrogator 🕵️‍♂️",
    # String form replaces the deprecated boolean False.
    allow_flagging="never",
)
io.launch(share=args.share)