Browse Source

Change load_list calls to directly use path so users can experiment with custom list with one line change

pull/1/head
pharmapsychotic 2 years ago
parent
commit
f36e13caef
  1. 18
      clip_interrogator.ipynb

18
clip_interrogator.ipynb

@ -50,12 +50,10 @@
"\n", "\n",
"import clip\n", "import clip\n",
"import gc\n", "import gc\n",
"import io\n",
"import math\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n",
"import pandas as pd\n", "import pandas as pd\n",
"import requests\n", "import requests\n",
"import sys\n",
"import torch\n", "import torch\n",
"import torchvision.transforms as T\n", "import torchvision.transforms as T\n",
"import torchvision.transforms.functional as TF\n", "import torchvision.transforms.functional as TF\n",
@ -87,8 +85,8 @@
" caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)\n", " caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)\n",
" return caption[0]\n", " return caption[0]\n",
"\n", "\n",
"def load_list(name):\n", "def load_list(filename):\n",
" with open(f\"/content/clip-interrogator/data/{name}.txt\", 'r', encoding='utf-8', errors='replace') as f:\n", " with open(filename, 'r', encoding='utf-8', errors='replace') as f:\n",
" items = [line.strip() for line in f.readlines()]\n", " items = [line.strip() for line in f.readlines()]\n",
" return items\n", " return items\n",
"\n", "\n",
@ -157,11 +155,13 @@
" else:\n", " else:\n",
" print(f\"\\n\\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}\")\n", " print(f\"\\n\\n{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}\")\n",
"\n", "\n",
"data_path = \"../clip-interrogator/data/\"\n",
"\n",
"artists = load_list(os.path.join(data_path, 'artists.txt'))\n",
"flavors = load_list(os.path.join(data_path, 'flavors.txt'))\n",
"mediums = load_list(os.path.join(data_path, 'mediums.txt'))\n",
"movements = load_list(os.path.join(data_path, 'movements.txt'))\n",
"\n", "\n",
"artists = load_list('artists')\n",
"flavors = load_list('flavors')\n",
"mediums = load_list('mediums')\n",
"movements = load_list('movements')\n",
"sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']\n", "sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']\n",
"trending_list = [site for site in sites]\n", "trending_list = [site for site in sites]\n",
"trending_list.extend([\"trending on \"+site for site in sites])\n", "trending_list.extend([\"trending on \"+site for site in sites])\n",

Loading…
Cancel
Save