{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "993a2a24-1a58-42be-8034-6d116fb8d786", "metadata": {}, "outputs": [], "source": [ "# imports\n", "# NOTE(review): re, math, json, random, numpy, load_dataset, TSNE and go are not used by the cells below; presumably kept for planned plotting/analysis - confirm before removing.\n", "# Item is not referenced directly below; presumably the objects stored in train.pkl are items.Item instances - confirm.\n", "\n", "import os\n", "import re\n", "import math\n", "import json\n", "from tqdm import tqdm\n", "import random\n", "from dotenv import load_dotenv\n", "from huggingface_hub import login\n", "import numpy as np\n", "import pickle\n", "from sentence_transformers import SentenceTransformer\n", "from datasets import load_dataset\n", "import chromadb\n", "from items import Item\n", "from sklearn.manifold import TSNE\n", "import plotly.graph_objects as go" ] }, { "cell_type": "code", "execution_count": 2, "id": "0e31676f-6f31-465f-a80e-02d51ff8425a", "metadata": {}, "outputs": [], "source": [ "# CONSTANTS\n", "# QUESTION matches the prompt prefix that description() strips before embedding; DB is the directory passed to chromadb.PersistentClient.\n", "# DATASET_NAME is not used below: the training items are loaded from train.pkl instead.\n", "\n", "HF_USER = \"ed-donner\" # your HF name here! Or use mine if you just want to reproduce my results.\n", "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n", "QUESTION = \"How much does this cost to the nearest dollar?\\n\\n\"\n", "DB = \"products_vectorstore\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "2359ccc0-dbf2-4b1e-9473-e472b32f548b", "metadata": {}, "outputs": [], "source": [ "# environment\n", "# load_dotenv() populates os.environ from a .env file; any key still unset is replaced by the placeholder 'your-key-if-not-using-env', which presumably will be rejected when used.\n", "\n", "load_dotenv()\n", "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')" ] }, { "cell_type": "code", "execution_count": 4, "id": "a29fcc4e-e4d7-4c54-aa6b-e5d1111ea9c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Token is valid (permission: write).\n", "Your token has been saved in your configured git credential helpers (osxkeychain).\n", "Your token has been saved to /Users/ed/.cache/huggingface/token\n", "Login successful\n" ] } ], "source": [ "# Log in to HuggingFace\n", "# add_to_git_credential=True also stores the token in the configured git credential helper (see the saved output).\n", "# NOTE(review): the saved output shows a local home-directory path; consider clearing outputs before sharing the notebook.\n", "\n", "hf_token = os.environ['HF_TOKEN']\n", "login(hf_token, add_to_git_credential=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": 
"688bd995-ec3e-43cd-8179-7fe14b275877", "metadata": {}, "outputs": [], "source": [ "# Let's avoid curating all our data again! Load in the pickle files:\n", "# NOTE: pickle.load can execute arbitrary code; only load a train.pkl you produced yourself.\n", "\n", "with open('train.pkl', 'rb') as file:\n", "    train = pickle.load(file)" ] }, { "cell_type": "code", "execution_count": 6, "id": "2817eaf5-4302-4a18-9148-d1062e3b3dbb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "400000" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "items = train\n", "len(items)" ] }, { "cell_type": "code", "execution_count": 11, "id": "f4aab95e-d719-4476-b6e7-e248120df25a", "metadata": {}, "outputs": [], "source": [ "client = chromadb.PersistentClient(path=DB)" ] }, { "cell_type": "code", "execution_count": 12, "id": "5f95dafd-ab80-464e-ba8a-dec7a2424780", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Deleted existing collection: products\n" ] } ], "source": [ "# Check if the collection exists and delete it if it does, so every run starts from an empty collection\n", "collection_name = \"products\"\n", "# list_collections() yields Collection objects in some chromadb versions and plain name strings in others (e.g. 0.6.x); accept both\n", "existing_collection_names = [c if isinstance(c, str) else c.name for c in client.list_collections()]\n", "if collection_name in existing_collection_names:\n", "    client.delete_collection(collection_name)\n", "    print(f\"Deleted existing collection: {collection_name}\")\n", "\n", "collection = client.create_collection(collection_name)" ] }, { "cell_type": "code", "execution_count": 13, "id": "a87db200-d19d-44bf-acbd-15c45c70f5c9", "metadata": {}, "outputs": [], "source": [ "model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')" ] }, { "cell_type": "code", "execution_count": 14, "id": "38de1bf8-c9b5-45b4-9f4b-86af93b3f80d", "metadata": {}, "outputs": [], "source": [ "def description(item):\n", "    \"\"\"Return the product description embedded in item.prompt.\n", "\n", "    Removes the QUESTION prefix and everything from the 'Price is $' marker onward,\n", "    leaving only the text that is embedded into the vector store.\n", "    \"\"\"\n", "    # Use the shared QUESTION constant rather than a duplicated literal so the two cannot drift apart.\n", "    text = item.prompt.replace(QUESTION, \"\")\n", "    return text.split(\"\\n\\nPrice is $\")[0]" ] }, { "cell_type": "code", "execution_count": 15, "id": "8c79e2fe-1f50-4ebf-9a93-34f3088f2996", 
"metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [21:47<00:00, 3.27s/it]\n" ] } ], "source": [ "# Embed each item's description and add it to the vector store in batches.\n", "BATCH_SIZE = 1000  # items encoded and added to Chroma per call\n", "\n", "for start in tqdm(range(0, len(items), BATCH_SIZE)):\n", "    batch = items[start: start + BATCH_SIZE]\n", "    documents = [description(item) for item in batch]\n", "    vectors = model.encode(documents).astype(float).tolist()\n", "    metadatas = [{\"category\": item.category, \"price\": item.price} for item in batch]\n", "    # Size ids from the actual batch: the final batch may be shorter than BATCH_SIZE,\n", "    # and ids, documents, embeddings and metadatas passed to add() must all have the same length.\n", "    ids = [f\"doc_{j}\" for j in range(start, start + len(batch))]\n", "    collection.add(\n", "        ids=ids,\n", "        documents=documents,\n", "        embeddings=vectors,\n", "        metadatas=metadatas\n", "    )" ] }, { "cell_type": "code", "execution_count": null, "id": "525fc313-8a16-4ac0-8c42-6a6d1ba1c9b8", "metadata": {}, "outputs": [], "source": [ "CATEGORIES = ['Appliances', 'Automotive', 'Cell_Phones_and_Accessories', 'Electronics','Musical_Instruments', 'Office_Products', 'Tools_and_Home_Improvement', 'Toys_and_Games']\n", "COLORS = ['red', 'blue', 'brown', 'orange', 'yellow', 'green' , 'purple', 'cyan']" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 5 }