From efee3fe0d70f1e32ea7a826a10fbd735025277e1 Mon Sep 17 00:00:00 2001 From: pharmapsychotic Date: Fri, 25 Nov 2022 11:41:44 -0600 Subject: [PATCH] Tweak defaults so runtime acceptable on T4 GPU in Colab --- clip_interrogator.ipynb | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/clip_interrogator.ipynb b/clip_interrogator.ipynb index f2c954e..5f05826 100644 --- a/clip_interrogator.ipynb +++ b/clip_interrogator.ipynb @@ -6,9 +6,7 @@ "id": "3jm8RYrLqvzz" }, "source": [ - "# CLIP Interrogator 2.1 ViTH special edition!\n", - "\n", - "### Please note \"best\" mode is currently not working on Colab properly with 16GB VRAM GPU. Your prompt will be cut quite short. Try out **classic** and **fast** in the mean time. Fix to come!\n", + "# CLIP Interrogator 2.1 ViT-H special edition!\n", "\n", "
\n", "\n", @@ -16,7 +14,7 @@ "\n", "
\n", "\n", - "This version is specialized for producing nice prompts for use with **Stable Diffusion 2.0** using the ViT-H-14 OpenCLIP model!" + "This version is specialized for producing nice prompts for use with **Stable Diffusion 2.0** using the **ViT-H-14** OpenCLIP model!\n" ] }, { @@ -78,14 +76,16 @@ "\n", "config = Config()\n", "config.blip_offload = True\n", - "config.chunk_size = 1024\n", - "config.flavor_intermediate_count = 1024\n", + "config.chunk_size = 2048\n", + "config.flavor_intermediate_count = 512\n", + "config.blip_num_beams = 64\n", + "\n", "ci = Interrogator(config)\n", "\n", - "def inference(image, mode):\n", + "def inference(image, mode, best_max_flavors):\n", " image = image.convert('RGB')\n", " if mode == 'best':\n", - " return ci.interrogate(image)\n", + " return ci.interrogate(image, max_flavors=int(best_max_flavors))\n", " elif mode == 'classic':\n", " return ci.interrogate_classic(image)\n", " else:\n", @@ -144,6 +144,7 @@ "inputs = [\n", " gr.inputs.Image(type='pil'),\n", " gr.Radio(['best', 'classic', 'fast'], label='', value='best'),\n", + " gr.Number(value=4, label='best mode max flavors'),\n", "]\n", "outputs = [\n", " gr.outputs.Textbox(label=\"Output\"),\n", @@ -185,6 +186,7 @@ "prompt_mode = 'best' #@param [\"best\",\"classic\", \"fast\"]\n", "output_mode = 'rename' #@param [\"desc.csv\",\"rename\"]\n", "max_filename_len = 128 #@param {type:\"integer\"}\n", + "best_max_flavors = 4 #@param {type:\"integer\"}\n", "\n", "\n", "def sanitize_for_filename(prompt: str, max_len: int) -> str:\n", @@ -201,7 +203,7 @@ " clear_output(wait=True)\n", "\n", " image = Image.open(os.path.join(folder_path, file)).convert('RGB')\n", - " prompt = inference(image, prompt_mode)\n", + " prompt = inference(image, prompt_mode, best_max_flavors=best_max_flavors)\n", " prompts.append(prompt)\n", "\n", " print(prompt)\n",