Browse Source

Free Colab is too much of a potato to handle model swapping so just pick it in the setup cell.

pull/34/head
pharmapsychotic 2 years ago
parent
commit
8b689592aa
  1. 24
      clip_interrogator.ipynb

24
clip_interrogator.ipynb

@ -63,7 +63,10 @@
"\n", "\n",
"setup()\n", "setup()\n",
"\n", "\n",
"# download cache files\n", "\n",
"clip_model_name = 'ViT-L-14/openai' #@param [\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"]\n",
"\n",
"\n",
"print(\"Download preprocessed cache files...\")\n", "print(\"Download preprocessed cache files...\")\n",
"CACHE_URLS = [\n", "CACHE_URLS = [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_artists.pkl',\n",
@ -71,6 +74,7 @@
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_mediums.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_movements.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-L-14_openai_trendings.pkl',\n",
"] if clip_model_name == 'ViT-L-14/openai' else [\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',\n",
" 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n", " 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',\n",
@ -81,6 +85,7 @@
"for url in CACHE_URLS:\n", "for url in CACHE_URLS:\n",
" print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", " print(subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n",
"\n", "\n",
"\n",
"import sys\n", "import sys\n",
"sys.path.append('src/blip')\n", "sys.path.append('src/blip')\n",
"sys.path.append('clip-interrogator')\n", "sys.path.append('clip-interrogator')\n",
@ -91,16 +96,12 @@
"config = Config()\n", "config = Config()\n",
"config.blip_num_beams = 64\n", "config.blip_num_beams = 64\n",
"config.blip_offload = False\n", "config.blip_offload = False\n",
"config.chunk_size = 2048\n", "config.clip_model_name = clip_model_name\n",
"config.flavor_intermediate_count = 2048\n",
"\n",
"ci = Interrogator(config)\n", "ci = Interrogator(config)\n",
"\n", "\n",
"def inference(image, mode, clip_model_name, best_max_flavors=32):\n", "def inference(image, mode, best_max_flavors=32):\n",
" if clip_model_name != ci.config.clip_model_name:\n", " ci.config.chunk_size = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.config.clip_model_name = clip_model_name\n", " ci.config.flavor_intermediate_count = 2048 if ci.config.clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" ci.load_clip_model()\n",
" ci.config.flavor_intermediate_count = 2048 if clip_model_name == \"ViT-L-14/openai\" else 1024\n",
" image = image.convert('RGB')\n", " image = image.convert('RGB')\n",
" if mode == 'best':\n", " if mode == 'best':\n",
" return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n",
@ -161,8 +162,7 @@
" \n", " \n",
"inputs = [\n", "inputs = [\n",
" gr.inputs.Image(type='pil'),\n", " gr.inputs.Image(type='pil'),\n",
" gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", " gr.Radio(['best', 'fast'], label='', value='best'),\n",
" gr.Dropdown([\"ViT-L-14/openai\", \"ViT-H-14/laion2b_s32b_b79k\"], value='ViT-L-14/openai', label='CLIP Model'),\n",
" gr.Number(value=16, label='best mode max flavors'),\n", " gr.Number(value=16, label='best mode max flavors'),\n",
"]\n", "]\n",
"outputs = [\n", "outputs = [\n",
@ -175,7 +175,7 @@
" outputs, \n", " outputs, \n",
" allow_flagging=False,\n", " allow_flagging=False,\n",
")\n", ")\n",
"io.launch()\n" "io.launch(debug=False)\n"
] ]
}, },
{ {

Loading…
Cancel
Save