{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "993a2a24-1a58-42be-8034-6d116fb8d786",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "import numpy as np\n",
    "import pickle\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from datasets import load_dataset\n",
    "import chromadb\n",
    "from items import Item\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0e31676f-6f31-465f-a80e-02d51ff8425a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide configuration\n",
    "\n",
    "# Hugging Face account that hosts the dataset; put your own name here,\n",
    "# or keep this one to reproduce the original results.\n",
    "HF_USER = \"ed-donner\"\n",
    "\n",
    "# Dataset id built from the account name above.\n",
    "DATASET_NAME = HF_USER + \"/pricer-data\"\n",
    "\n",
    "# Question text expected in the item prompts.\n",
    "QUESTION = \"How much does this cost to the nearest dollar?\\n\\n\"\n",
    "\n",
    "# Path handed to the persistent Chroma client.\n",
    "DB = \"products_vectorstore\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2359ccc0-dbf2-4b1e-9473-e472b32f548b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load settings via load_dotenv(), then make sure both keys exist in\n",
    "# os.environ, falling back to a placeholder when a key is not set.\n",
    "\n",
    "load_dotenv()\n",
    "for env_key in ('OPENAI_API_KEY', 'HF_TOKEN'):\n",
    "    os.environ.setdefault(env_key, 'your-key-if-not-using-env')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a29fcc4e-e4d7-4c54-aa6b-e5d1111ea9c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token is valid (permission: write).\n",
      "Your token has been saved in your configured git credential helpers (osxkeychain).\n",
      "Your token has been saved to /Users/ed/.cache/huggingface/token\n",
      "Login successful\n"
     ]
    }
   ],
   "source": [
    "# Authenticate with HuggingFace using the token held in the environment\n",
    "\n",
    "hf_token = os.environ['HF_TOKEN']\n",
    "login(token=hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "688bd995-ec3e-43cd-8179-7fe14b275877",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the data curated earlier instead of rebuilding it: read it back from the pickle file.\n",
    "# NOTE: pickle.load can execute arbitrary code, so only load files you created yourself.\n",
    "\n",
    "with open('train.pkl', 'rb') as pickle_file:\n",
    "    train = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2817eaf5-4302-4a18-9148-d1062e3b3dbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "400000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `items` is the list that gets embedded and stored below; the bare expression shows its size\n",
    "items = train\n",
    "len(items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f4aab95e-d719-4476-b6e7-e248120df25a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persistent Chroma client backed by the DB path from the constants cell\n",
    "client = chromadb.PersistentClient(path=DB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5f95dafd-ab80-464e-ba8a-dec7a2424780",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deleted existing collection: products\n"
     ]
    }
   ],
   "source": [
    "# Start from a fresh \"products\" collection: delete any existing one, then create it.\n",
    "# Depending on the chromadb version, list_collections() yields Collection objects\n",
    "# or plain name strings, so accept either form when gathering the names.\n",
    "collection_name = \"products\"\n",
    "existing_collection_names = [\n",
    "    getattr(existing, \"name\", existing) for existing in client.list_collections()\n",
    "]\n",
    "if collection_name in existing_collection_names:\n",
    "    client.delete_collection(collection_name)\n",
    "    print(f\"Deleted existing collection: {collection_name}\")\n",
    "\n",
    "collection = client.create_collection(collection_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a87db200-d19d-44bf-acbd-15c45c70f5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence-embedding model used below to encode the item descriptions\n",
    "model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "38de1bf8-c9b5-45b4-9f4b-86af93b3f80d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def description(item):\n",
    "    \"\"\"Return the text of `item.prompt` without the question or the price.\n",
    "\n",
    "    Every occurrence of QUESTION is removed, and everything from the first\n",
    "    price marker onwards is dropped.\n",
    "    \"\"\"\n",
    "    # Use the shared QUESTION constant rather than repeating its literal here.\n",
    "    text = item.prompt.replace(QUESTION, \"\")\n",
    "    return text.split(\"\\n\\nPrice is $\")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8c79e2fe-1f50-4ebf-9a93-34f3088f2996",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [21:47<00:00,  3.27s/it]\n"
     ]
    }
   ],
   "source": [
    "# Embed the item descriptions and add them to the collection in batches.\n",
    "BATCH_SIZE = 1000\n",
    "\n",
    "for i in tqdm(range(0, len(items), BATCH_SIZE)):\n",
    "    batch = items[i: i + BATCH_SIZE]\n",
    "    documents = [description(item) for item in batch]\n",
    "    vectors = model.encode(documents).astype(float).tolist()\n",
    "    metadatas = [{\"category\": item.category, \"price\": item.price} for item in batch]\n",
    "    # Derive the ids from the real batch length: the final batch can be shorter\n",
    "    # than BATCH_SIZE, and ids must line up one-to-one with the documents.\n",
    "    ids = [f\"doc_{j}\" for j in range(i, i + len(batch))]\n",
    "    collection.add(\n",
    "        ids=ids,\n",
    "        documents=documents,\n",
    "        embeddings=vectors,\n",
    "        metadatas=metadatas\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "525fc313-8a16-4ac0-8c42-6a6d1ba1c9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category names, followed by one colour name per category entry.\n",
    "# NOTE(review): presumably paired by position - confirm at the point of use.\n",
    "CATEGORIES = [\n",
    "    'Appliances',\n",
    "    'Automotive',\n",
    "    'Cell_Phones_and_Accessories',\n",
    "    'Electronics',\n",
    "    'Musical_Instruments',\n",
    "    'Office_Products',\n",
    "    'Tools_and_Home_Improvement',\n",
    "    'Toys_and_Games',\n",
    "]\n",
    "COLORS = [\n",
    "    'red',\n",
    "    'blue',\n",
    "    'brown',\n",
    "    'orange',\n",
    "    'yellow',\n",
    "    'green',\n",
    "    'purple',\n",
    "    'cyan',\n",
    "]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}