In [1]:
# imports

import os
import re
import math
import json
from tqdm import tqdm
import random
from dotenv import load_dotenv
from huggingface_hub import login
import numpy as np
import pickle
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import chromadb
from items import Item
from sklearn.manifold import TSNE
import plotly.graph_objects as go

In [2]:
# CONSTANTS

# Hugging Face account that owns the dataset: put your own HF username here,
# or keep this one if you just want to reproduce the original results.
HF_USER = "ed-donner"
DATASET_NAME = HF_USER + "/pricer-data"
# Question header expected at the start of each item prompt (stripped by description()).
QUESTION = "How much does this cost to the nearest dollar?\n\n"
# Directory handed to chromadb.PersistentClient as the on-disk vector store location.
DB = "products_vectorstore"

In [3]:
# environment

# Pull variables from a local .env file (if one is found) into os.environ.
load_dotenv()
# Make sure both keys are present in os.environ. An existing value (shell or .env)
# is kept as-is; otherwise a placeholder string is stored. Prefer putting the real
# keys in .env over editing the placeholder here.
os.environ.setdefault('OPENAI_API_KEY', 'your-key-if-not-using-env')
os.environ.setdefault('HF_TOKEN', 'your-key-if-not-using-env')

In [4]:
# Log in to HuggingFace

# Authenticate using the token placed in os.environ by the environment cell.
# add_to_git_credential=True additionally stores the token in the configured git
# credential helper, as the cell output reports.
hf_token = os.environ['HF_TOKEN']
login(token=hf_token, add_to_git_credential=True)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (osxkeychain).
Your token has been saved to /Users/ed/.cache/huggingface/token
Login successful


In [5]:
# Let's avoid curating all our data again! Load in the pickle files:
# NOTE: pickle.load can execute arbitrary code while unpickling - only load a
# train.pkl produced by your own earlier curation step, never an untrusted one.
with open('train.pkl', mode='rb') as pkl_file:
    train = pickle.load(pkl_file)

In [6]:
# From here on the training split is referred to as `items` (same object, no copy).
items = train
len(train)  # identical to len(items); displayed as the cell output

400000

In [11]:
client = chromadb.PersistentClient(
    path=DB,  # on-disk location of the Chroma vector store
)

In [12]:
# Check if the collection exists and delete it if it does, so re-running the
# notebook starts from an empty "products" collection instead of stacking ids.
collection_name = "products"
# Depending on the chromadb version, list_collections() yields Collection objects
# (with a .name attribute) or, in the 0.6 series, plain name strings. Fall back to
# the element itself so both shapes work.
existing_collection_names = [getattr(c, "name", c) for c in client.list_collections()]
if collection_name in existing_collection_names:
    client.delete_collection(collection_name)
    print(f"Deleted existing collection: {collection_name}")

collection = client.create_collection(collection_name)

Deleted existing collection: products


In [13]:
# Sentence-embedding model used below to encode each product description.
model = SentenceTransformer(
    model_name_or_path='sentence-transformers/all-MiniLM-L6-v2',
)

In [14]:
def description(item, question=None):
    """Extract the bare product description from ``item.prompt``.

    The prompt is expected to be ``question + <description> + "\\n\\nPrice is $..."``.
    The leading question is removed only when it is actually a prefix (text that
    merely contains the question elsewhere is left intact). Everything from the
    first ``"\\n\\nPrice is $"`` onward is dropped.

    Args:
        item: object with a string ``prompt`` attribute (presumably an ``items.Item``).
        question: prefix to strip; ``None`` means use the notebook's ``QUESTION``
            constant, so the header text is defined in one place only.

    Returns:
        The description text, without the question header or the price line.
    """
    if question is None:
        # Looked up at call time so the function follows the constants cell.
        question = QUESTION
    text = item.prompt
    if text.startswith(question):
        text = text[len(question):]
    return text.split("\n\nPrice is $")[0]

In [15]:
# Embed every item's description and store it in Chroma, BATCH_SIZE items per call.
BATCH_SIZE = 1000

for i in tqdm(range(0, len(items), BATCH_SIZE)):
    batch = items[i: i + BATCH_SIZE]
    documents = [description(item) for item in batch]
    # astype(float) + tolist() turns the encoder's array output into plain Python floats.
    vectors = model.encode(documents).astype(float).tolist()
    metadatas = [{"category": item.category, "price": item.price} for item in batch]
    # Size ids from the actual batch: the final slice is shorter than BATCH_SIZE
    # whenever len(items) is not a multiple of it, and collection.add needs
    # ids/documents/embeddings/metadatas of equal length.
    ids = [f"doc_{j}" for j in range(i, i + len(batch))]
    collection.add(
        ids=ids,
        documents=documents,
        embeddings=vectors,
        metadatas=metadatas,
    )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [21:47<00:00, 3.27s/it]


In [None]:
# Product categories and a colour list of the same length.
# NOTE(review): presumably paired by position for plotting - confirm in later cells.
CATEGORIES = [
    'Appliances',
    'Automotive',
    'Cell_Phones_and_Accessories',
    'Electronics',
    'Musical_Instruments',
    'Office_Products',
    'Tools_and_Home_Improvement',
    'Toys_and_Games',
]
COLORS = [
    'red',
    'blue',
    'brown',
    'orange',
    'yellow',
    'green',
    'purple',
    'cyan',
]