Image to prompt with BLIP and CLIP

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"<br>\n",
"\n",
"Want to figure out what a good prompt might be to create new images like an existing one? The CLIP Interrogator is here to get you answers!\n",
"\n",
"<br>\n",
"\n",
"This version is specialized for producing nice prompts for use with Stable Diffusion and achieves higher alignment between generated text prompt and source image. You can try out the old [version 1](https://colab.research.google.com/github/pharmapsychotic/clip-interrogator/blob/v1/clip_interrogator.ipynb) to see how different CLIP models ranks terms. \n",
"\n",
"<br>\n",
"\n",
"If this notebook is helpful to you please consider buying me a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following me on [twitter](https://twitter.com/pharmapsychotic) for more cool Ai stuff. 🙂\n",
"\n",
"And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).\n"
]
},
{
"cell_type": "code",
"source": [
"#@title Check GPU\n",
"!nvidia-smi -L"
],
"metadata": {
"cellView": "form",
"id": "aP9FjmWxtLKJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xpPKQR40qvz2"
},
"outputs": [],
"source": [
"#@title Setup\n",
"import argparse, subprocess, sys, time\n",
"\n",
"def setup():\n",
" install_cmds = [\n",
" ['pip', 'install', 'ftfy', 'regex', 'tqdm', 'transformers==4.21.2', 'timm', 'fairscale', 'requests'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n",
" ['pip', 'install', '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@main#egg=blip'],\n",
" ['git', 'clone', 'https://github.com/pharmapsychotic/clip-interrogator.git']\n",
" ]\n",
" for cmd in install_cmds:\n",
" print(subprocess.run(cmd, stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n",
"setup()\n",
"\n",
"import sys\n",
"sys.path.append('src/blip')\n",
"sys.path.append('src/clip')\n",
"\n",
"import clip\n",
"import hashlib\n",
"import io\n",
"import IPython\n",
"import ipywidgets as widgets\n",
"import math\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import requests\n",
"import torch\n",
"import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n",
"\n",
"from models.blip import blip_decoder\n",
"from PIL import Image\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torchvision import transforms\n",
"from torchvision.transforms.functional import InterpolationMode\n",
"from tqdm import tqdm\n",
"from zipfile import ZipFile\n",
"\n",
"\n",
"chunk_size = 2048\n",
"flavor_intermediate_count = 2048\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"print(\"Loading BLIP model...\")\n",
"blip_image_eval_size = 384\n",
"blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth' \n",
"blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='large', med_config='./src/blip/configs/med_config.json')\n",
"blip_model.eval()\n",
"blip_model = blip_model.to(device)\n",
"\n",
"print(\"Loading CLIP model...\")\n",
"clip_model_name = 'ViT-L/14' #@param ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px', 'RN101', 'RN50', 'RN50x4', 'RN50x16', 'RN50x64'] {type:'string'}\n",
"clip_model, clip_preprocess = clip.load(clip_model_name, device=\"cuda\")\n",
"clip_model.cuda().eval()\n",
"\n",
"\n",
"class LabelTable():\n",
" def __init__(self, labels, desc):\n",
" self.labels = labels\n",
" self.embeds = []\n",
"\n",
" hash = hashlib.sha256(\",\".join(labels).encode()).hexdigest()\n",
"\n",
" os.makedirs('./cache', exist_ok=True)\n",
" cache_filepath = f\"./cache/{desc}.pkl\"\n",
" if desc is not None and os.path.exists(cache_filepath):\n",
" with open(cache_filepath, 'rb') as f:\n",
" data = pickle.load(f)\n",
" if data.get('hash') == hash and data.get('model') == clip_model_name:\n",
" self.labels = data['labels']\n",
" self.embeds = data['embeds']\n",
"\n",
" if len(self.labels) != len(self.embeds):\n",
" self.embeds = []\n",
" chunks = np.array_split(self.labels, max(1, len(self.labels)/chunk_size))\n",
" for chunk in tqdm(chunks, desc=f\"Preprocessing {desc}\" if desc else None):\n",
" text_tokens = clip.tokenize(chunk).cuda()\n",
" with torch.no_grad():\n",
" text_features = clip_model.encode_text(text_tokens).float()\n",
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
" text_features = text_features.half().cpu().numpy()\n",
" for i in range(text_features.shape[0]):\n",
" self.embeds.append(text_features[i])\n",
"\n",
" with open(cache_filepath, 'wb') as f:\n",
" pickle.dump({\"labels\":self.labels, \"embeds\":self.embeds, \"hash\":hash, \"model\":clip_model_name}, f)\n",
" \n",
" def _rank(self, image_features, text_embeds, top_count=1):\n",
" top_count = min(top_count, len(text_embeds))\n",
" similarity = torch.zeros((1, len(text_embeds))).to(device)\n",
" text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)\n",
" for i in range(image_features.shape[0]):\n",
" similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1)\n",
" _, top_labels = similarity.cpu().topk(top_count, dim=-1)\n",
" return [top_labels[0][i].numpy() for i in range(top_count)]\n",
"\n",
" def rank(self, image_features, top_count=1):\n",
" if len(self.labels) <= chunk_size:\n",
" tops = self._rank(image_features, self.embeds, top_count=top_count)\n",
" return [self.labels[i] for i in tops]\n",
"\n",
" num_chunks = int(math.ceil(len(self.labels)/chunk_size))\n",
" keep_per_chunk = int(chunk_size / num_chunks)\n",
"\n",
" top_labels, top_embeds = [], []\n",
" for chunk_idx in tqdm(range(num_chunks)):\n",
" start = chunk_idx*chunk_size\n",
" stop = min(start+chunk_size, len(self.embeds))\n",
" tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)\n",
" top_labels.extend([self.labels[start+i] for i in tops])\n",
" top_embeds.extend([self.embeds[start+i] for i in tops])\n",
"\n",
" tops = self._rank(image_features, top_embeds, top_count=top_count)\n",
" return [top_labels[i] for i in tops]\n",
"\n",
"def generate_caption(pil_image):\n",
" gpu_image = transforms.Compose([\n",
" transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
" ])(pil_image).unsqueeze(0).to(device)\n",
"\n",
" with torch.no_grad():\n",
" caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)\n",
" return caption[0]\n",
"\n",
"def rank_top(image_features, text_array):\n",
" text_tokens = clip.tokenize([text for text in text_array]).cuda()\n",
" with torch.no_grad():\n",
" text_features = clip_model.encode_text(text_tokens).float()\n",
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
"\n",
" similarity = torch.zeros((1, len(text_array)), device=device)\n",
" for i in range(image_features.shape[0]):\n",
" similarity += (image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)\n",
"\n",
" _, top_labels = similarity.cpu().topk(1, dim=-1)\n",
" return text_array[top_labels[0][0].numpy()]\n",
"\n",
"def similarity(image_features, text):\n",
" text_tokens = clip.tokenize([text]).cuda()\n",
" with torch.no_grad():\n",
" text_features = clip_model.encode_text(text_tokens).float() \n",
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
" similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T\n",
" return similarity[0][0]\n",
"\n",
"def load_list(filename):\n",
" with open(filename, 'r', encoding='utf-8', errors='replace') as f:\n",
" items = [line.strip() for line in f.readlines()]\n",
" return items\n",
"\n",
"def interrogate(image):\n",
" caption = generate_caption(image)\n",
"\n",
" images = clip_preprocess(image).unsqueeze(0).cuda()\n",
" with torch.no_grad():\n",
" image_features = clip_model.encode_image(images).float()\n",
" image_features /= image_features.norm(dim=-1, keepdim=True)\n",
"\n",
" flaves = flavors.rank(image_features, flavor_intermediate_count)\n",
" best_medium = mediums.rank(image_features, 1)[0]\n",
" best_artist = artists.rank(image_features, 1)[0]\n",
" best_trending = trendings.rank(image_features, 1)[0]\n",
" best_movement = movements.rank(image_features, 1)[0]\n",
"\n",
" best_prompt = caption\n",
" best_sim = similarity(image_features, best_prompt)\n",
"\n",
" def check(addition):\n",
" nonlocal best_prompt, best_sim\n",
" prompt = best_prompt + \", \" + addition\n",
" sim = similarity(image_features, prompt)\n",
" if sim > best_sim:\n",
" best_sim = sim\n",
" best_prompt = prompt\n",
" return True\n",
" return False\n",
"\n",
" def check_multi_batch(opts):\n",
" nonlocal best_prompt, best_sim\n",
" prompts = []\n",
" for i in range(2**len(opts)):\n",
" prompt = best_prompt\n",
" for bit in range(len(opts)):\n",
" if i & (1 << bit):\n",
" prompt += \", \" + opts[bit]\n",
" prompts.append(prompt)\n",
"\n",
" t = LabelTable(prompts, None)\n",
" best_prompt = t.rank(image_features, 1)[0]\n",
" best_sim = similarity(image_features, best_prompt)\n",
"\n",
" check_multi_batch([best_medium, best_artist, best_trending, best_movement])\n",
"\n",
" extended_flavors = set(flaves)\n",
" for _ in tqdm(range(25), desc=\"Flavor chain\"):\n",
" try:\n",
" best = rank_top(image_features, [f\"{best_prompt}, {f}\" for f in extended_flavors])\n",
" flave = best[len(best_prompt)+2:]\n",
" if not check(flave):\n",
" break\n",
" extended_flavors.remove(flave)\n",
" except:\n",
" # exceeded max prompt length\n",
" break\n",
"\n",
" return best_prompt\n",
"\n",
"DATA_PATH = 'clip-interrogator/data'\n",
"\n",
"sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']\n",
"trending_list = [site for site in sites]\n",
"trending_list.extend([\"trending on \"+site for site in sites])\n",
"trending_list.extend([\"featured on \"+site for site in sites])\n",
"trending_list.extend([site+\" contest winner\" for site in sites])\n",
"\n",
"raw_artists = load_list(f'{DATA_PATH}/artists.txt')\n",
"artists = [f\"by {a}\" for a in raw_artists]\n",
"artists.extend([f\"inspired by {a}\" for a in raw_artists])\n",
"\n",
"artists = LabelTable(artists, \"artists\")\n",
"flavors = LabelTable(load_list(f'{DATA_PATH}/flavors.txt'), \"flavors\")\n",
"mediums = LabelTable(load_list(f'{DATA_PATH}/mediums.txt'), \"mediums\")\n",
"movements = LabelTable(load_list(f'{DATA_PATH}/movements.txt'), \"movements\")\n",
"trendings = LabelTable(trending_list, \"trendings\")\n",
"\n"
]
},
{
"cell_type": "code",
"source": [
"#@title Interrogate\n",
"\n",
"#@markdown Run this cell and then paste a link to an image or upload an image in the UI. Then click the Interrogate button to get a prompt suggestion.\n",
"\n",
"image_url = 'https://cdnb.artstation.com/p/assets/images/images/032/142/769/large/ignacio-bazan-lazcano-book-4-final.jpg'\n",
"\n",
"def show_ui():\n",
" go_button = widgets.Button(\n",
" description='Interrogate!',\n",
" disabled=False,\n",
" button_style='',\n",
" tooltip='Click me'\n",
" )\n",
" image_txt = widgets.Text(\n",
" value=image_url, \n",
" description='', \n",
" layout=widgets.Layout(width='50%')\n",
" )\n",
" uploader = widgets.FileUpload(accept='image/*', multiple=False)\n",
"\n",
" ui = widgets.VBox([\n",
" widgets.HBox([widgets.Label('image url:'), image_txt]),\n",
" widgets.HBox([widgets.Label('or upload:'), uploader]),\n",
" widgets.Label(''),\n",
" go_button\n",
" ])\n",
"\n",
" def go(btn):\n",
" image_url = image_txt.value\n",
" if len(uploader.value):\n",
" print(uploader.value)\n",
" print(uploader.value.items())\n",
" for name, file_info in uploader.value.items():\n",
" image = Image.open(io.BytesIO(file_info['content']))\n",
" break\n",
" else:\n",
" if str(image_url).startswith('http://') or str(image_url).startswith('https://'):\n",
" image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')\n",
" else:\n",
" image = Image.open(image_url).convert('RGB')\n",
"\n",
" IPython.display.clear_output()\n",
" print('\\n\\n')\n",
" thumb = image.copy()\n",
" thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])\n",
" print(\"Interrogating...\")\n",
" display(thumb)\n",
"\n",
" prompt = interrogate(image)\n",
" IPython.display.clear_output()\n",
" show_ui()\n",
"\n",
" print('\\n\\n')\n",
" display(thumb)\n",
" ui = widgets.VBox([\n",
" widgets.Textarea(\n",
" value=prompt,\n",
" description='prompt:',\n",
" layout=widgets.Layout(width='75%', height='6em')\n",
" )\n",
" ])\n",
" display(ui)\n",
" \n",
" go_button.on_click(go)\n",
" image_txt.on_submit(go)\n",
" display(ui)\n",
"\n",
"show_ui()"
],
"metadata": {
"cellView": "form",
"id": "34fmVUqjx3l7"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('pytorch')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.7"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "d4dd9c310c32a31bb53615812f2f2c6cba010b7aa4dfb14e2b192e650667fecd"
}
},
"colab": {
"provenance": [],
"collapsed_sections": []
},
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}