{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "993a2a24-1a58-42be-8034-6d116fb8d786",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"import json\n",
"from tqdm import tqdm\n",
"import random\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"import numpy as np\n",
"import pickle\n",
"from sentence_transformers import SentenceTransformer\n",
"from datasets import load_dataset\n",
"import chromadb\n",
"from items import Item\n",
"from sklearn.manifold import TSNE\n",
"import plotly.graph_objects as go"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0e31676f-6f31-465f-a80e-02d51ff8425a",
"metadata": {},
"outputs": [],
"source": [
"# CONSTANTS\n",
"\n",
"# Hugging Face account that owns the dataset: put your own HF name here,\n",
"# or keep mine if you just want to reproduce my results.\n",
"HF_USER = \"ed-donner\"\n",
"DATASET_NAME = HF_USER + \"/pricer-data\"\n",
"\n",
"# Question text that description() strips from each item prompt\n",
"QUESTION = \"How much does this cost to the nearest dollar?\\n\\n\"\n",
"\n",
"# Path handed to the Chroma persistent client\n",
"DB = \"products_vectorstore\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2359ccc0-dbf2-4b1e-9473-e472b32f548b",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"load_dotenv()\n",
"\n",
"# Re-export both credentials through os.environ; the placeholder string is what\n",
"# gets stored when a variable is not set at all.\n",
"for env_var in ('OPENAI_API_KEY', 'HF_TOKEN'):\n",
"    os.environ[env_var] = os.getenv(env_var, 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a29fcc4e-e4d7-4c54-aa6b-e5d1111ea9c4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Token is valid (permission: write).\n",
"Your token has been saved in your configured git credential helpers (osxkeychain).\n",
"Your token has been saved to /Users/ed/.cache/huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"# Log in to HuggingFace\n",
"# (the token is the HF_TOKEN value placed in os.environ by the environment cell)\n",
"\n",
"hf_token = os.environ['HF_TOKEN']\n",
"login(token=hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "688bd995-ec3e-43cd-8179-7fe14b275877",
"metadata": {},
"outputs": [],
"source": [
"# Let's avoid curating all our data again! Load in the pickle files:\n",
"# (unpickling can run arbitrary code, so only load a train.pkl you produced yourself)\n",
"\n",
"with open('train.pkl', 'rb') as pickle_file:\n",
"    train = pickle.load(pickle_file)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "2817eaf5-4302-4a18-9148-d1062e3b3dbb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"400000"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"items = train\n",
"len(items)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "f4aab95e-d719-4476-b6e7-e248120df25a",
"metadata": {},
"outputs": [],
"source": [
"client = chromadb.PersistentClient(path=DB)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "5f95dafd-ab80-464e-ba8a-dec7a2424780",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Deleted existing collection: products\n"
]
}
],
"source": [
"# Start from an empty \"products\" collection: delete it if it already exists, then create it\n",
"collection_name = \"products\"\n",
"\n",
"# Depending on the chromadb release, list_collections() yields Collection objects or\n",
"# plain name strings, so fall back to the entry itself when it has no .name attribute\n",
"existing_collection_names = [\n",
"    getattr(entry, \"name\", entry) for entry in client.list_collections()\n",
"]\n",
"if collection_name in existing_collection_names:\n",
"    client.delete_collection(collection_name)\n",
"    print(f\"Deleted existing collection: {collection_name}\")\n",
"\n",
"collection = client.create_collection(collection_name)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a87db200-d19d-44bf-acbd-15c45c70f5c9",
"metadata": {},
"outputs": [],
"source": [
"model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "38de1bf8-c9b5-45b4-9f4b-86af93b3f80d",
"metadata": {},
"outputs": [],
"source": [
"def description(item):\n",
"    \"\"\"Return the product text of item.prompt without the question or the price.\n",
"\n",
"    Every occurrence of the QUESTION constant is removed from the prompt, and\n",
"    everything from the first 'Price is $' marker (together with the blank line\n",
"    before it) onwards is dropped. A prompt without that marker is returned\n",
"    with only the question removed.\n",
"    \"\"\"\n",
"    # Use the shared constant rather than a second copy of the literal, so the two cannot drift apart\n",
"    text = item.prompt.replace(QUESTION, \"\")\n",
"    return text.split(\"\\n\\nPrice is $\")[0]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8c79e2fe-1f50-4ebf-9a93-34f3088f2996",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [21:47<00:00, 3.27s/it]\n"
]
}
],
"source": [
"# Embed the item descriptions and add them to the collection one batch at a time\n",
"BATCH_SIZE = 1000\n",
"\n",
"for start in tqdm(range(0, len(items), BATCH_SIZE)):\n",
"    batch = items[start: start + BATCH_SIZE]\n",
"    documents = [description(item) for item in batch]\n",
"    vectors = model.encode(documents).astype(float).tolist()\n",
"    metadatas = [{\"category\": item.category, \"price\": item.price} for item in batch]\n",
"    # Number the ids from the actual batch length, so a short final batch still\n",
"    # gets exactly one id per document\n",
"    ids = [f\"doc_{j}\" for j in range(start, start + len(batch))]\n",
"    collection.add(\n",
"        ids=ids,\n",
"        documents=documents,\n",
"        embeddings=vectors,\n",
"        metadatas=metadatas\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "525fc313-8a16-4ac0-8c42-6a6d1ba1c9b8",
"metadata": {},
"outputs": [],
"source": [
"# Product categories and the marker colour used for each one (paired by position)\n",
"CATEGORIES = [\n",
"    'Appliances',\n",
"    'Automotive',\n",
"    'Cell_Phones_and_Accessories',\n",
"    'Electronics',\n",
"    'Musical_Instruments',\n",
"    'Office_Products',\n",
"    'Tools_and_Home_Improvement',\n",
"    'Toys_and_Games',\n",
"]\n",
"COLORS = [\n",
"    'red',\n",
"    'blue',\n",
"    'brown',\n",
"    'orange',\n",
"    'yellow',\n",
"    'green',\n",
"    'purple',\n",
"    'cyan',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4cf1c9a-1ced-48d4-974c-3c850905034e",
"metadata": {},
"outputs": [],
"source": [
"# Prework\n",
"\n",
"# Read a sample back from the collection so that this cell defines everything the plots\n",
"# use: the embeddings, the description text and the category of every plotted point.\n",
"# The sample is capped because t-SNE gets slow on large inputs.\n",
"MAXIMUM_DATAPOINTS = 30_000\n",
"result = collection.get(include=['embeddings', 'documents', 'metadatas'], limit=MAXIMUM_DATAPOINTS)\n",
"\n",
"vectors_np = np.array(result['embeddings'])\n",
"descriptions = result['documents']\n",
"categories = [metadata['category'] for metadata in result['metadatas']]\n",
"colors = [COLORS[CATEGORIES.index(t)] for t in categories]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c6718b3-e0fd-4319-a1b5-d9d34d6b1dd9",
"metadata": {},
"outputs": [],
"source": [
"# We humans find it easier to visualize things in 2D!\n",
"# Reduce the dimensionality of the vectors to 2D using t-SNE\n",
"# (t-distributed stochastic neighbor embedding)\n",
"\n",
"tsne = TSNE(n_components=2, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors_np)\n",
"\n",
"# Create the 2D scatter plot\n",
"fig = go.Figure(data=[go.Scatter(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    mode='markers',\n",
"    marker=dict(size=3, color=colors, opacity=0.8),\n",
"    # <br> is the line break Plotly renders inside hover text\n",
"    text=[f\"Category: {c}<br>Text: {d[:100]}...\" for c, d in zip(categories, descriptions)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"# Axis titles of a 2D figure are set directly on the layout; `scene` only affects 3D plots\n",
"fig.update_layout(\n",
"    title='2D Chroma Vector Store Visualization',\n",
"    xaxis_title='x',\n",
"    yaxis_title='y',\n",
"    width=1200,\n",
"    height=800,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c54df150-c8d8-4bc3-8877-6759691eeb42",
"metadata": {},
"outputs": [],
"source": [
"# Let's try 3D!\n",
"# Same t-SNE reduction as the 2D plot, this time down to three components\n",
"\n",
"tsne = TSNE(n_components=3, random_state=42)\n",
"reduced_vectors = tsne.fit_transform(vectors_np)\n",
"\n",
"# Create the 3D scatter plot\n",
"fig = go.Figure(data=[go.Scatter3d(\n",
"    x=reduced_vectors[:, 0],\n",
"    y=reduced_vectors[:, 1],\n",
"    z=reduced_vectors[:, 2],\n",
"    mode='markers',\n",
"    marker=dict(size=3, color=colors, opacity=0.7),\n",
"    # <br> is the line break Plotly renders inside hover text\n",
"    text=[f\"Category: {c}<br>Text: {d[:100]}...\" for c, d in zip(categories, descriptions)],\n",
"    hoverinfo='text'\n",
")])\n",
"\n",
"fig.update_layout(\n",
"    title='3D Chroma Vector Store Visualization',\n",
"    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
"    width=1200,\n",
"    height=800,\n",
"    margin=dict(r=20, b=10, l=10, t=40)\n",
")\n",
"\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8fb2a63-24c5-4dce-9e63-aa208272f82d",
"metadata": {},
"outputs": [],
"source": [
"# TODO: unfinished cell - it held only a dangling `def `, which is a SyntaxError under Run All; add the function or delete this cell"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}