Browse Source

Update notebook batch processing with option to rename files so can be used with [filewords] in Dreambooth!

- new `quiet` config option so CLIP Interrogator doesn't print and tqdm
- `max_flavors` option to each interrogate method
pull/22/head
pharmapsychotic 2 years ago
parent
commit
55c922a48a
  1. 1
      .gitignore
  2. 80
      clip_interrogator.ipynb
  3. 2
      clip_interrogator/__init__.py
  4. 28
      clip_interrogator/clip_interrogator.py
  5. 2
      setup.py

1
.gitignore vendored

@ -1,6 +1,7 @@
*.pyc *.pyc
.cog/ .cog/
.vscode/ .vscode/
bench/
cache/ cache/
clip-interrogator/ clip-interrogator/
clip_interrogator.egg-info/ clip_interrogator.egg-info/

80
clip_interrogator.ipynb

@ -25,7 +25,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"id": "aP9FjmWxtLKJ" "id": "aP9FjmWxtLKJ"
@ -38,7 +38,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 2,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"id": "xpPKQR40qvz2" "id": "xpPKQR40qvz2"
@ -46,7 +46,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Setup\n", "#@title Setup\n",
"import argparse, subprocess, sys, time\n", "import subprocess\n",
"\n", "\n",
"def setup():\n", "def setup():\n",
" install_cmds = [\n", " install_cmds = [\n",
@ -65,10 +65,8 @@
"sys.path.append('src/clip')\n", "sys.path.append('src/clip')\n",
"sys.path.append('clip-interrogator')\n", "sys.path.append('clip-interrogator')\n",
"\n", "\n",
"import clip\n",
"import gradio as gr\n", "import gradio as gr\n",
"import torch\n", "from clip_interrogator import Config, Interrogator\n",
"from clip_interrogator import Interrogator, Config\n",
"\n", "\n",
"ci = Interrogator(Config())\n", "ci = Interrogator(Config())\n",
"\n", "\n",
@ -84,7 +82,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": { "colab": {
@ -150,7 +148,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 3,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"id": "OGmvkzITN4Hz" "id": "OGmvkzITN4Hz"
@ -159,43 +157,69 @@
"source": [ "source": [
"#@title Batch process a folder of images 📁 -> 📝\n", "#@title Batch process a folder of images 📁 -> 📝\n",
"\n", "\n",
"#@markdown This will generate prompts for every image in a folder and save results to desc.csv in the same folder.\n", "#@markdown This will generate prompts for every image in a folder and either save results \n",
"#@markdown to a desc.csv file in the same folder or rename the files to contain their prompts.\n",
"#@markdown The renamed files work well for [DreamBooth extension](https://github.com/d8ahazard/sd_dreambooth_extension)\n",
"#@markdown in the [Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui).\n",
"#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n", "#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
"#@markdown notebook or use it to train a different model or just print it out for fun. \n",
"#@markdown If you make something cool, I'd love to see it! Tag me on twitter or something. 😀\n",
"\n", "\n",
"import csv\n", "import csv\n",
"import os\n", "import os\n",
"from IPython.display import display\n", "from IPython.display import clear_output, display\n",
"from PIL import Image\n", "from PIL import Image\n",
"from tqdm import tqdm\n", "from tqdm import tqdm\n",
"\n", "\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n", "folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
"output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n",
"max_filename_len = 128 #@param {type:\"integer\"}\n",
"\n", "\n",
"\n", "\n",
"files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n",
" name = \"\".join(c for c in prompt if (c.isalnum() or c in \",._-! \"))\n",
" name = name.strip()[:(max_len-4)] # extra space for extension\n",
" return name\n",
"\n",
"ci.config.quiet = True\n",
"\n",
"files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
"prompts = []\n", "prompts = []\n",
"for file in files:\n", "for idx, file in enumerate(tqdm(files, desc='Generating prompts')):\n",
" if idx > 0 and idx % 100 == 0:\n",
" clear_output(wait=True)\n",
"\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, mode)\n", " prompt = inference(image, prompt_mode)\n",
" prompts.append(prompt)\n", " prompts.append(prompt)\n",
"\n", "\n",
" print(prompt)\n",
" thumb = image.copy()\n", " thumb = image.copy()\n",
" thumb.thumbnail([256, 256])\n", " thumb.thumbnail([256, 256])\n",
" display(thumb)\n", " display(thumb)\n",
"\n", "\n",
" print(prompt)\n", " if output_mode == 'rename':\n",
" name = sanitize_for_filename(prompt, max_filename_len)\n",
" ext = os.path.splitext(file)[1]\n",
" filename = name + ext\n",
" idx = 1\n",
" while os.path.exists(os.path.join(folder_path, filename)):\n",
" print(f'File {filename} already exists, trying {idx+1}...')\n",
" filename = f\"{name}_{idx}{ext}\"\n",
" idx += 1\n",
" os.rename(os.path.join(folder_path, file), os.path.join(folder_path, filename))\n",
"\n", "\n",
"if len(prompts):\n", "if len(prompts):\n",
" csv_path = os.path.join(folder_path, 'desc.csv')\n", " if output_mode == 'desc.csv':\n",
" with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n", " csv_path = os.path.join(folder_path, 'desc.csv')\n",
" w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n", " with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
" w.writerow(['image', 'prompt'])\n", " w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
" for file, prompt in zip(files, prompts):\n", " w.writerow(['image', 'prompt'])\n",
" w.writerow([file, prompt])\n", " for file, prompt in zip(files, prompts):\n",
"\n", " w.writerow([file, prompt])\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!\")\n", "\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n",
" else:\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and renamed your files, enjoy!\")\n",
"else:\n", "else:\n",
" print(f\"Sorry, I couldn't find any images in {folder_path}\")\n" " print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
] ]
@ -208,7 +232,7 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3.8.10 ('venv': venv)", "display_name": "Python 3.9.5 ('venv': venv)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -222,12 +246,12 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.10" "version": "3.9.5"
}, },
"orig_nbformat": 4, "orig_nbformat": 4,
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "f7a8d9541664ade9cff251487a19c76f2dd1b4c864d158f07ee26d1b0fd5c9a1" "hash": "10f7ca63a88f18f789e6adaf7a045f1bcd3706c5534a32f168d622925241605d"
} }
} }
}, },

2
clip_interrogator/__init__.py

@ -1,4 +1,4 @@
from .clip_interrogator import Interrogator, Config from .clip_interrogator import Interrogator, Config
__version__ = '0.1.4' __version__ = '0.2.0'
__author__ = 'pharmapsychotic' __author__ = 'pharmapsychotic'

28
clip_interrogator/clip_interrogator.py

@ -37,7 +37,8 @@ class Config:
chunk_size: int = 2048 chunk_size: int = 2048
data_path: str = os.path.join(os.path.dirname(__file__), 'data') data_path: str = os.path.join(os.path.dirname(__file__), 'data')
device: str = 'cuda' if torch.cuda.is_available() else 'cpu' device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
flavor_intermediate_count: int = 2048 flavor_intermediate_count: int = 2048
quiet: bool = False # when quiet progress bars are not shown
class Interrogator(): class Interrogator():
@ -46,7 +47,8 @@ class Interrogator():
self.device = config.device self.device = config.device
if config.blip_model is None: if config.blip_model is None:
print("Loading BLIP model...") if not config.quiet:
print("Loading BLIP model...")
blip_path = os.path.dirname(inspect.getfile(blip_decoder)) blip_path = os.path.dirname(inspect.getfile(blip_decoder))
configs_path = os.path.join(os.path.dirname(blip_path), 'configs') configs_path = os.path.join(os.path.dirname(blip_path), 'configs')
med_config = os.path.join(configs_path, 'med_config.json') med_config = os.path.join(configs_path, 'med_config.json')
@ -63,7 +65,8 @@ class Interrogator():
self.blip_model = config.blip_model self.blip_model = config.blip_model
if config.clip_model is None: if config.clip_model is None:
print("Loading CLIP model...") if not config.quiet:
print("Loading CLIP model...")
self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device) self.clip_model, self.clip_preprocess = clip.load(config.clip_model_name, device=config.device)
self.clip_model.to(config.device).eval() self.clip_model.to(config.device).eval()
else: else:
@ -111,7 +114,7 @@ class Interrogator():
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features return image_features
def interrogate_classic(self, image: Image, max_flaves: int=3) -> str: def interrogate_classic(self, image: Image, max_flavors: int=3) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
@ -119,7 +122,7 @@ class Interrogator():
artist = self.artists.rank(image_features, 1)[0] artist = self.artists.rank(image_features, 1)[0]
trending = self.trendings.rank(image_features, 1)[0] trending = self.trendings.rank(image_features, 1)[0]
movement = self.movements.rank(image_features, 1)[0] movement = self.movements.rank(image_features, 1)[0]
flaves = ", ".join(self.flavors.rank(image_features, max_flaves)) flaves = ", ".join(self.flavors.rank(image_features, max_flavors))
if caption.startswith(medium): if caption.startswith(medium):
prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}" prompt = f"{caption} {artist}, {trending}, {movement}, {flaves}"
@ -128,14 +131,14 @@ class Interrogator():
return _truncate_to_fit(prompt) return _truncate_to_fit(prompt)
def interrogate_fast(self, image: Image) -> str: def interrogate_fast(self, image: Image, max_flavors: int = 32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config) merged = _merge_tables([self.artists, self.flavors, self.mediums, self.movements, self.trendings], self.config)
tops = merged.rank(image_features, 32) tops = merged.rank(image_features, max_flavors)
return _truncate_to_fit(caption + ", " + ", ".join(tops)) return _truncate_to_fit(caption + ", " + ", ".join(tops))
def interrogate(self, image: Image) -> str: def interrogate(self, image: Image, max_flavors: int=32) -> str:
caption = self.generate_caption(image) caption = self.generate_caption(image)
image_features = self.image_to_features(image) image_features = self.image_to_features(image)
@ -175,7 +178,7 @@ class Interrogator():
check_multi_batch([best_medium, best_artist, best_trending, best_movement]) check_multi_batch([best_medium, best_artist, best_trending, best_movement])
extended_flavors = set(flaves) extended_flavors = set(flaves)
for _ in tqdm(range(25), desc="Flavor chain"): for _ in tqdm(range(max_flavors), desc="Flavor chain", disable=self.config.quiet):
try: try:
best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors]) best = self.rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
flave = best[len(best_prompt)+2:] flave = best[len(best_prompt)+2:]
@ -213,9 +216,10 @@ class Interrogator():
class LabelTable(): class LabelTable():
def __init__(self, labels:List[str], desc:str, clip_model, config: Config): def __init__(self, labels:List[str], desc:str, clip_model, config: Config):
self.chunk_size = config.chunk_size self.chunk_size = config.chunk_size
self.config = config
self.device = config.device self.device = config.device
self.labels = labels
self.embeds = [] self.embeds = []
self.labels = labels
hash = hashlib.sha256(",".join(labels).encode()).hexdigest() hash = hashlib.sha256(",".join(labels).encode()).hexdigest()
@ -234,7 +238,7 @@ class LabelTable():
if len(self.labels) != len(self.embeds): if len(self.labels) != len(self.embeds):
self.embeds = [] self.embeds = []
chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size))
for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None): for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet):
text_tokens = clip.tokenize(chunk).to(self.device) text_tokens = clip.tokenize(chunk).to(self.device)
with torch.no_grad(): with torch.no_grad():
text_features = clip_model.encode_text(text_tokens).float() text_features = clip_model.encode_text(text_tokens).float()
@ -270,7 +274,7 @@ class LabelTable():
keep_per_chunk = int(self.chunk_size / num_chunks) keep_per_chunk = int(self.chunk_size / num_chunks)
top_labels, top_embeds = [], [] top_labels, top_embeds = [], []
for chunk_idx in tqdm(range(num_chunks)): for chunk_idx in tqdm(range(num_chunks), disable=self.config.quiet):
start = chunk_idx*self.chunk_size start = chunk_idx*self.chunk_size
stop = min(start+self.chunk_size, len(self.embeds)) stop = min(start+self.chunk_size, len(self.embeds))
tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk) tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)

2
setup.py

@ -5,7 +5,7 @@ from setuptools import setup, find_packages
setup( setup(
name="clip-interrogator", name="clip-interrogator",
version="0.1.4", version="0.2.0",
license='MIT', license='MIT',
author='pharmapsychotic', author='pharmapsychotic',
author_email='me@pharmapsychotic.com', author_email='me@pharmapsychotic.com',

Loading…
Cancel
Save