From 55c922a48a7d352847fdd9840b8a791271dd0569 Mon Sep 17 00:00:00 2001 From: pharmapsychotic Date: Mon, 21 Nov 2022 14:17:36 -0600 Subject: [PATCH] Update notebook batch processing with option to rename files so can be used with [filewords] in Dreambooth! - new `quiet` config option so CLIP Interrogator doesn't print and tqdm - `max_flavors` option to each interrogate method --- .gitignore | 1 + clip_interrogator.ipynb | 80 +++++++++++++++++--------- clip_interrogator/__init__.py | 2 +- clip_interrogator/clip_interrogator.py | 28 +++++---- setup.py | 2 +- 5 files changed, 71 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index b5d2001..6e6cd38 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.pyc .cog/ .vscode/ +bench/ cache/ clip-interrogator/ clip_interrogator.egg-info/ diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index 302b326..7329580 100644 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "cellView": "form", "id": "aP9FjmWxtLKJ" @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": { "cellView": "form", "id": "xpPKQR40qvz2" @@ -46,7 +46,7 @@ "outputs": [], "source": [ "#@title Setup\n", - "import argparse, subprocess, sys, time\n", + "import subprocess\n", "\n", "def setup():\n", " install_cmds = [\n", @@ -65,10 +65,8 @@ "sys.path.append('src/clip')\n", "sys.path.append('clip-interrogator')\n", "\n", - "import clip\n", "import gradio as gr\n", - "import torch\n", - "from clip_interrogator import Interrogator, Config\n", + "from clip_interrogator import Config, Interrogator\n", "\n", "ci = Interrogator(Config())\n", "\n", @@ -84,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "cellView": "form", "colab": { @@ -150,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "cellView": "form", "id": "OGmvkzITN4Hz" @@ -159,43 +157,69 @@ "source": [ "#@title Batch process a folder of images 📁 -> 📝\n", "\n", - "#@markdown This will generate prompts for every image in a folder and save results to desc.csv in the same folder.\n", + "#@markdown This will generate prompts for every image in a folder and either save results \n", + "#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n", + "#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n", + "#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n", "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n", - "#@markdown notebook or use it to train a different model or just print it out for fun. \n", - "#@markdown If you make something cool, I'd love to see it! Tag me on twitter or something. 😀\n", "\n", "import csv\n", "import os\n", - "from IPython.display import display\n", + "from IPython.display import clear_output, display\n", "from PIL import Image\n", "from tqdm import tqdm\n", "\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", - "mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", + "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", + "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", + "max_filename_len = 128 #@param {type:\"integer\"}\n", "\n", "\n", - "files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n", + "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", + " name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n", + " name = name.strip()[:(max_len-4)] # extra space for extension\n", + " return name\n", + "\n", + "ci.config.quiet = True\n", + "\n", + "files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n", "prompts = []\n", - "for file in files:\n", + "for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n", + " if idx > 0 and idx % 100 == 0:\n", + " clear_output(wait=True)\n", + "\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", - " prompt = inference(image, mode)\n", + " prompt = inference(image, prompt_mode)\n", " prompts.append(prompt)\n", "\n", + " print(prompt)\n", " thumb = image.copy()\n", " thumb.thumbnail([256, 256])\n", " display(thumb)\n", "\n", - " print(prompt)\n", + " if output_mode == 'rename':\n", + " name = sanitize_for_filename(prompt, max_filename_len)\n", + " ext = os.path.splitext(file)[1]\n", + " filename = name + ext\n", + " idx = 1\n", + " while os.path.exists(os.path.join(folder_path, filename)):\n", + " print(f'File {filename} already exists, trying {idx+1}...')\n", + " filename = f\"{name}_{idx}{ext}\"\n", + " idx += 1\n", + " os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n", "\n", "if len(prompts):\n", - " csv_path = os.path.join(folder_path, 'desc.csv')\n", - " with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n", - " w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n", - " w.writerow(['image', 'prompt'])\n", - " for file, prompt in zip(files, prompts):\n", - " w.writerow([file, prompt])\n", - "\n", - " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!\")\n", + " if output_mode == 'desc.csv':\n", + " csv_path = os.path.join(folder_path, 'desc.csv')\n", + " with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n", + " w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n", + " w.writerow(['image', 'prompt'])\n", + " for file, prompt in zip(files, prompts):\n", + " w.writerow([file, prompt])\n", + "\n", + " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n", + " else:\n", + " print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n", "else:\n", " print(f\"Sorry, I couldn't find any images in {folder_path}\")\n" ] @@ -208,7 +232,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "Python 3.8.10 ('venv': venv)", + "display_name": "Python 3.9.5 ('venv': venv)", "language": "python", "name": "python3" }, @@ -222,12 +246,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1" + "hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d" } } }, diff --git a/clip_interrogator/__init__.py b/clip_interrogator/__init__.py index 6cf7af1..f42edce 100644 --- a/clip_interrogator/__init__.py +++ b/clip_interrogator/__init__.py @@ -1,4 +1,4 @@ from .clip_interrogator import Interrogator, Config -__version__ = '0.1.4' +__version__ = '0.2.0' __author__ = 'pharmapsychotic' \ No newline at end of file diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index 6838031..fb26215 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -37,7 +37,8 @@ class Config: chunk_size: int = 2048 data_path: str = os.path.join(os.path.dirname(__file__), 'data') device: str = 'cuda' if torch.cuda.is_available() else 'cpu' - flavor_intermediate_count: int = 2048 + flavor_intermediate_count: int = 2048 + quiet: bool = False # when quiet progress bars are not shown class Interrogator(): @@ -46,7 +47,8 @@ class Interrogator(): self.device = config.device if config.blip_model is None: - print("Loading BLIP model...") + if not config.quiet: + print("Loading BLIP model...") blip_path = os.path.dirname(inspect.getfile(blip_decoder)) configs_path = os.path.join(os.path.dirname(blip_path), 'configs') med_config = os.path.join(configs_path, 'med_config.json') @@ -63,7 +65,8 @@ class Interrogator(): self.blip_model = config.blip_model if config.clip_model is None: - print("Loading CLIP model...") + if not config.quiet: + print("Loading CLIP model...") self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) self.clip_model.to(config.device).eval() else: @@ -111,7 +114,7 @@ class Interrogator(): image_features /= image_features.norm(dim=-1, keepdim=True) return image_features - def interrogate_classic(self, image: Image, max_flaves: int=3) -> str: + def interrogate_classic(self, image: Image, max_flavors: int=3) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) @@ -119,7 +122,7 @@ class Interrogator(): artist = self.artists.rank(image_features, 1)[0] trending = self.trendings.rank(image_features, 1)[0] movement = self.movements.rank(image_features, 1)[0] - flaves = ", ".join(self.flavors.rank(image_features, max_flaves)) + flaves = ", ".join(self.flavors.rank(image_features, max_flavors)) if caption.startswith(medium): prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}" @@ -128,14 +131,14 @@ class Interrogator(): return _truncate_to_fit(prompt) - def interrogate_fast(self, image: Image) -> str: + def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) - tops = merged.rank(image_features, 32) + tops = merged.rank(image_features, max_flavors) return _truncate_to_fit(caption + ", " + ", ".join(tops)) - def interrogate(self, image: Image) -> str: + def interrogate(self, image: Image, max_flavors: int=32) -> str: caption = self.generate_caption(image) image_features = self.image_to_features(image) @@ -175,7 +178,7 @@ class Interrogator(): check_multi_batch([best_medium, best_artist, best_trending, best_movement]) extended_flavors = set(flaves) - for _ in tqdm(range(25), desc="Flavor chain"): + for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet): try: best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) flave = best[len(best_prompt)+2:] @@ -213,9 +216,10 @@ class Interrogator(): class LabelTable(): def __init__(self, labels:List[str], desc:str, clip_model, config: Config): self.chunk_size = config.chunk_size + self.config = config self.device = config.device - self.labels = labels self.embeds = [] + self.labels = labels hash = hashlib.sha256(",".join(labels).encode()).hexdigest() @@ -234,7 +238,7 @@ class LabelTable(): if len(self.labels) != len(self.embeds): self.embeds = [] chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) - for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): + for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): text_tokens = clip.tokenize(chunk).to(self.device) with torch.no_grad(): text_features = clip_model.encode_text(text_tokens).float() @@ -270,7 +274,7 @@ class LabelTable(): keep_per_chunk = int(self.chunk_size / num_chunks) top_labels, top_embeds = [], [] - for chunk_idx in tqdm(range(num_chunks)): + for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet): start = chunk_idx*self.chunk_size stop = min(start+self.chunk_size, len(self.embeds)) tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) diff --git a/setup.py b/setup.py index 502c769..a08a50e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ from setuptools import setup, find_packages setup( name="clip-interrogator", - version="0.1.4", + version="0.2.0", license='MIT', author='pharmapsychotic', author_email='me@pharmapsychotic.com',