Browse Source

Batch processing in notebook and run_cli!

replicate
pharmapsychotic 2 years ago
parent
commit
31b1d22e82
  1. 126
      clip_interrogator.ipynb
  2. 67
      run_cli.py

126
clip_interrogator.ipynb

@ -6,7 +6,7 @@
"id": "3jm8RYrLqvzz"
},
"source": [
"# CLIP Interrogator 2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n",
"<br>\n",
"\n",
@ -70,25 +70,34 @@
"import torch\n",
"from clip_interrogator import Interrogator, Config\n",
"\n",
"ci = Interrogator(Config())\n"
"ci = Interrogator(Config())\n",
"\n",
"def inference(image, mode):\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n",
" else:\n",
" return ci.interrogate_fast(image)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 677
},
"cellView": "form",
"id": "Pf6qkFG6MPRj",
"outputId": "5f959af5-f6dd-43f2-f8df-8331a422d317"
"outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"output_type": "stream",
"text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n",
@ -99,64 +108,29 @@
]
},
{
"output_type": "display_data",
"data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"(async (port, path, width, height, cache, element) => {\n",
" if (!google.colab.kernel.accessAllowed && !cache) {\n",
" return;\n",
" }\n",
" element.appendChild(document.createTextNode(''));\n",
" const url = await google.colab.kernel.proxyPort(port, {cache});\n",
"\n",
" const external_link = document.createElement('div');\n",
" external_link.innerHTML = `\n",
" <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n",
" Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n",
" https://localhost:${port}${path}\n",
" </a>\n",
" </div>\n",
" `;\n",
" element.appendChild(external_link);\n",
"\n",
" const iframe = document.createElement('iframe');\n",
" iframe.src = new URL(path, url).toString();\n",
" iframe.height = height;\n",
" iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n",
" iframe.width = width;\n",
" iframe.style.border = 0;\n",
" element.appendChild(iframe);\n",
" })(7866, \"/\", \"100%\", 500, false, window.element)"
]
},
"metadata": {}
"metadata": {},
"output_type": "display_data"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(<gradio.routes.App at 0x7f6f06fc3450>, 'http://127.0.0.1:7866/', None)"
"(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
]
},
"execution_count": 4,
"metadata": {},
"execution_count": 9
"output_type": "execute_result"
}
],
"source": [
"#@title Run!\n",
"\n",
"def inference(image, mode):\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n",
" else:\n",
" return ci.interrogate_fast(image)\n",
" \n",
"#@title Image to prompt! 🖼 -> 📝\n",
" \n",
"inputs = [\n",
" gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
@ -173,6 +147,58 @@
")\n",
"io.launch()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "OGmvkzITN4Hz"
},
"outputs": [],
"source": [
"#@title Batch process a folder of images 📁 -> 📝\n",
"\n",
"#@markdown This will generate prompts for every image in a folder and save results to desc.csv in the same folder.\n",
"#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
"#@markdown notebook or use it to train a different model or just print it out for fun. \n",
"#@markdown If you make something cool, I'd love to see it! Tag me on twitter or something. 😀\n",
"\n",
"import csv\n",
"import os\n",
"from IPython.display import display\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
"\n",
"\n",
"files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
"prompts = []\n",
"for file in files:\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, mode)\n",
" prompts.append(prompt)\n",
"\n",
" thumb = image.copy()\n",
" thumb.thumbnail([256, 256])\n",
" display(thumb)\n",
"\n",
" print(prompt)\n",
"\n",
"if len(prompts):\n",
" csv_path = os.path.join(folder_path, 'desc.csv')\n",
" with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
" w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
" w.writerow(['image', 'prompt'])\n",
" for file, prompt in zip(files, prompts):\n",
" w.writerow([file, prompt])\n",
"\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!\")\n",
"else:\n",
" print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
]
}
],
"metadata": {
@ -207,4 +233,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

67
run_cli.py

@ -1,30 +1,36 @@
#!/usr/bin/env python3
import argparse
import clip
import csv
import os
import requests
import torch
from PIL import Image
from clip_interrogator import Interrogator, Config
def inference(ci, image, mode):
    """Generate a text prompt for `image` with the given Interrogator.

    Args:
        ci: a clip_interrogator.Interrogator instance.
        image: a PIL image; converted to RGB before interrogation.
        mode: 'best', 'classic', or anything else for the fast mode.

    Returns:
        The prompt string produced by the selected interrogation method.
    """
    rgb = image.convert('RGB')
    # Guard-clause dispatch: any unrecognized mode falls through to fast.
    if mode == 'best':
        return ci.interrogate(rgb)
    if mode == 'classic':
        return ci.interrogate_classic(rgb)
    return ci.interrogate_fast(rgb)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
parser.add_argument('-f', '--folder', help='path to folder of images')
parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
args = parser.parse_args()
if not args.image:
if not args.folder and not args.image:
parser.print_help()
exit(1)
# load image
image_path = args.image
if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
else:
image = Image.open(image_path).convert('RGB')
if not image:
print(f'Error opening image {image_path}')
if args.folder is not None and args.image is not None:
print("Specify a folder for batch processing or a single image, not both")
exit(1)
# validate clip model name
@ -37,13 +43,42 @@ def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Config(device=device, clip_model_name=args.clip)
ci = Interrogator(config)
if args.mode == 'best':
prompt = ci.interrogate(image)
elif args.mode == 'classic':
prompt = ci.interrogate_classic(image)
else:
prompt = ci.interrogate_fast(image)
print(prompt)
# process single image
if args.image is not None:
image_path = args.image
if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
else:
image = Image.open(image_path).convert('RGB')
if not image:
print(f'Error opening image {image_path}')
exit(1)
print(inference(ci, image, args.mode))
# process folder of images
elif args.folder is not None:
if not os.path.exists(args.folder):
print(f'The folder {args.folder} does not exist!')
exit(1)
files = [f for f in os.listdir(args.folder) if f.endswith('.jpg') or f.endswith('.png')]
prompts = []
for file in files:
image = Image.open(os.path.join(args.folder, file)).convert('RGB')
prompt = inference(ci, image, args.mode)
prompts.append(prompt)
print(prompt)
if len(prompts):
csv_path = os.path.join(args.folder, 'desc.csv')
with open(csv_path, 'w', encoding='utf-8', newline='') as f:
w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
w.writerow(['image', 'prompt'])
for file, prompt in zip(files, prompts):
w.writerow([file, prompt])
print(f"\n\n\n\nGenerated {len(prompts)} prompts and saved to {csv_path}, enjoy!")
# Standard script entry point: run the CLI only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()

Loading…
Cancel
Save