Browse Source

Batch processing in notebook and run_cli!

replicate
pharmapsychotic 2 years ago
parent
commit
31b1d22e82
  1. 124
      clip_interrogator.ipynb
  2. 67
      run_cli.py

124
clip_interrogator.ipynb

@ -6,7 +6,7 @@
"id": "3jm8RYrLqvzz" "id": "3jm8RYrLqvzz"
}, },
"source": [ "source": [
"# CLIP Interrogator 2 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n", "# CLIP Interrogator 2.1 by [@pharmapsychotic](https://twitter.com/pharmapsychotic) \n",
"\n", "\n",
"<br>\n", "<br>\n",
"\n", "\n",
@ -70,25 +70,34 @@
"import torch\n", "import torch\n",
"from clip_interrogator import Interrogator, Config\n", "from clip_interrogator import Interrogator, Config\n",
"\n", "\n",
"ci = Interrogator(Config())\n" "ci = Interrogator(Config())\n",
"\n",
"def inference(image, mode):\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n",
" else:\n",
" return ci.interrogate_fast(image)\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 4,
"metadata": { "metadata": {
"cellView": "form",
"colab": { "colab": {
"base_uri": "https://localhost:8080/", "base_uri": "https://localhost:8080/",
"height": 677 "height": 677
}, },
"cellView": "form",
"id": "Pf6qkFG6MPRj", "id": "Pf6qkFG6MPRj",
"outputId": "5f959af5-f6dd-43f2-f8df-8331a422d317" "outputId": "8d542b56-8be7-453d-bf27-d0540a774c7d"
}, },
"outputs": [ "outputs": [
{ {
"output_type": "stream",
"name": "stdout", "name": "stdout",
"output_type": "stream",
"text": [ "text": [
"Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n",
"\n", "\n",
@ -99,64 +108,29 @@
] ]
}, },
{ {
"output_type": "display_data",
"data": { "data": {
"application/javascript": "(async (port, path, width, height, cache, element) => {\n if (!google.colab.kernel.accessAllowed && !cache) {\n return;\n }\n element.appendChild(document.createTextNode(''));\n const url = await google.colab.kernel.proxyPort(port, {cache});\n\n const external_link = document.createElement('div');\n external_link.innerHTML = `\n <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n https://localhost:${port}${path}\n </a>\n </div>\n `;\n element.appendChild(external_link);\n\n const iframe = document.createElement('iframe');\n iframe.src = new URL(path, url).toString();\n iframe.height = height;\n iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n iframe.width = width;\n iframe.style.border = 0;\n element.appendChild(iframe);\n })(7860, \"/\", \"100%\", 500, false, window.element)",
"text/plain": [ "text/plain": [
"<IPython.core.display.Javascript object>" "<IPython.core.display.Javascript object>"
],
"application/javascript": [
"(async (port, path, width, height, cache, element) => {\n",
" if (!google.colab.kernel.accessAllowed && !cache) {\n",
" return;\n",
" }\n",
" element.appendChild(document.createTextNode(''));\n",
" const url = await google.colab.kernel.proxyPort(port, {cache});\n",
"\n",
" const external_link = document.createElement('div');\n",
" external_link.innerHTML = `\n",
" <div style=\"font-family: monospace; margin-bottom: 0.5rem\">\n",
" Running on <a href=${new URL(path, url).toString()} target=\"_blank\">\n",
" https://localhost:${port}${path}\n",
" </a>\n",
" </div>\n",
" `;\n",
" element.appendChild(external_link);\n",
"\n",
" const iframe = document.createElement('iframe');\n",
" iframe.src = new URL(path, url).toString();\n",
" iframe.height = height;\n",
" iframe.allow = \"autoplay; camera; microphone; clipboard-read; clipboard-write;\"\n",
" iframe.width = width;\n",
" iframe.style.border = 0;\n",
" element.appendChild(iframe);\n",
" })(7866, \"/\", \"100%\", 500, false, window.element)"
] ]
}, },
"metadata": {} "metadata": {},
"output_type": "display_data"
}, },
{ {
"output_type": "execute_result",
"data": { "data": {
"text/plain": [ "text/plain": [
"(<gradio.routes.App at 0x7f6f06fc3450>, 'http://127.0.0.1:7866/', None)" "(<gradio.routes.App at 0x7f894e553710>, 'http://127.0.0.1:7860/', None)"
] ]
}, },
"execution_count": 4,
"metadata": {}, "metadata": {},
"execution_count": 9 "output_type": "execute_result"
} }
], ],
"source": [ "source": [
"#@title Run!\n", "#@title Image to prompt! 🖼 -> 📝\n",
"\n", " \n",
"def inference(image, mode):\n",
" image = image.convert('RGB')\n",
" if mode == 'best':\n",
" return ci.interrogate(image)\n",
" elif mode == 'classic':\n",
" return ci.interrogate_classic(image)\n",
" else:\n",
" return ci.interrogate_fast(image)\n",
" \n",
"inputs = [\n", "inputs = [\n",
" gr.inputs.Image(type='pil'),\n", " gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n",
@ -173,6 +147,58 @@
")\n", ")\n",
"io.launch()\n" "io.launch()\n"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "OGmvkzITN4Hz"
},
"outputs": [],
"source": [
"#@title Batch process a folder of images 📁 -> 📝\n",
"\n",
"#@markdown This will generate prompts for every image in a folder and save results to desc.csv in the same folder.\n",
"#@markdown You can use the generated csv in the [Stable Diffusion Finetuning](https://colab.research.google.com/drive/1vrh_MUSaAMaC5tsLWDxkFILKJ790Z4Bl?usp=sharing)\n",
"#@markdown notebook or use it to train a different model or just print it out for fun. \n",
"#@markdown If you make something cool, I'd love to see it! Tag me on twitter or something. 😀\n",
"\n",
"import csv\n",
"import os\n",
"from IPython.display import display\n",
"from PIL import Image\n",
"from tqdm import tqdm\n",
"\n",
"folder_path = \"/content/my_images\" #@param {type:\"string\"}\n",
"mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n",
"\n",
"\n",
"files = [f for f in os.listdir(folder_path) if f.endswith('.jpg') or f.endswith('.png')] if os.path.exists(folder_path) else []\n",
"prompts = []\n",
"for file in files:\n",
" image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n",
" prompt = inference(image, mode)\n",
" prompts.append(prompt)\n",
"\n",
" thumb = image.copy()\n",
" thumb.thumbnail([256, 256])\n",
" display(thumb)\n",
"\n",
" print(prompt)\n",
"\n",
"if len(prompts):\n",
" csv_path = os.path.join(folder_path, 'desc.csv')\n",
" with open(csv_path, 'w', encoding='utf-8', newline='') as f:\n",
" w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)\n",
" w.writerow(['image', 'prompt'])\n",
" for file, prompt in zip(files, prompts):\n",
" w.writerow([file, prompt])\n",
"\n",
" print(f\"\\n\\n\\n\\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!\")\n",
"else:\n",
" print(f\"Sorry, I couldn't find any images in {folder_path}\")\n"
]
} }
], ],
"metadata": { "metadata": {

67
run_cli.py

@ -1,30 +1,36 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import clip import clip
import csv
import os
import requests import requests
import torch import torch
from PIL import Image from PIL import Image
from clip_interrogator import Interrogator, Config from clip_interrogator import Interrogator, Config
def inference(ci, image, mode):
image = image.convert('RGB')
if mode == 'best':
return ci.interrogate(image)
elif mode == 'classic':
return ci.interrogate_classic(image)
else:
return ci.interrogate_fast(image)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use') parser.add_argument('-c', '--clip', default='ViT-L/14', help='name of CLIP model to use')
parser.add_argument('-f', '--folder', help='path to folder of images')
parser.add_argument('-i', '--image', help='image file or url') parser.add_argument('-i', '--image', help='image file or url')
parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast') parser.add_argument('-m', '--mode', default='best', help='best, classic, or fast')
args = parser.parse_args() args = parser.parse_args()
if not args.image: if not args.folder and not args.image:
parser.print_help() parser.print_help()
exit(1) exit(1)
# load image if args.folder is not None and args.image is not None:
image_path = args.image print("Specify a folder or batch processing or a single image, not both")
if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
else:
image = Image.open(image_path).convert('RGB')
if not image:
print(f'Error opening image {image_path}')
exit(1) exit(1)
# validate clip model name # validate clip model name
@ -37,13 +43,42 @@ def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Config(device=device, clip_model_name=args.clip) config = Config(device=device, clip_model_name=args.clip)
ci = Interrogator(config) ci = Interrogator(config)
if args.mode == 'best':
prompt = ci.interrogate(image) # process single image
elif args.mode == 'classic': if args.image is not None:
prompt = ci.interrogate_classic(image) image_path = args.image
else: if str(image_path).startswith('http://') or str(image_path).startswith('https://'):
prompt = ci.interrogate_fast(image) image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
print(prompt) else:
image = Image.open(image_path).convert('RGB')
if not image:
print(f'Error opening image {image_path}')
exit(1)
print(inference(ci, image, args.mode))
# process folder of images
elif args.folder is not None:
if not os.path.exists(args.folder):
print(f'The folder {args.folder} does not exist!')
exit(1)
files = [f for f in os.listdir(args.folder) if f.endswith('.jpg') or f.endswith('.png')]
prompts = []
for file in files:
image = Image.open(os.path.join(args.folder, file)).convert('RGB')
prompt = inference(ci, image, args.mode)
prompts.append(prompt)
print(prompt)
if len(prompts):
csv_path = os.path.join(args.folder, 'desc.csv')
with open(csv_path, 'w', encoding='utf-8', newline='') as f:
w = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
w.writerow(['image', 'prompt'])
for file, prompt in zip(files, prompts):
w.writerow([file, prompt])
print(f"\n\n\n\nGenerated {len(prompts)} and saved to {csv_path}, enjoy!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

Loading…
Cancel
Save