{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c6f2d1e-8b4a-4e57-9f3a-2d7b5c9e1a40",
   "metadata": {},
   "source": [
    "# 3D visualization of the product vector store\n",
    "\n",
    "Reads the embeddings stored in the persistent Chroma collection `products` (database folder `products_vectorstore`), projects them to three dimensions with t-SNE, and draws an interactive Plotly scatter plot coloured by product category.\n",
    "\n",
    "This notebook only *reads* the collection; it must already have been populated. The loading cell asserts that it is non-empty."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993a2a24-1a58-42be-8034-6d116fb8d786",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "# NOTE(review): only numpy, chromadb, TSNE and plotly are used in this notebook;\n",
    "# the remaining imports look carried over from sibling notebooks and could be dropped.\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "import numpy as np\n",
    "import pickle\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from datasets import load_dataset\n",
    "import chromadb\n",
    "from items import Item\n",
    "from sklearn.manifold import TSNE\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4aab95e-d719-4476-b6e7-e248120df25a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder of the persistent Chroma database, relative to the current working directory\n",
    "DB = \"products_vectorstore\"\n",
    "client = chromadb.PersistentClient(path=DB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f95dafd-ab80-464e-ba8a-dec7a2424780",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Opens the existing collection, or creates an empty one if it is missing.\n",
    "# The count shown below should be non-zero; 0 usually means DB points at the wrong folder.\n",
    "collection = client.get_or_create_collection('products')\n",
    "collection.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "525fc313-8a16-4ac0-8c42-6a6d1ba1c9b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One display colour per category, paired by position\n",
    "CATEGORIES = ['Appliances', 'Automotive', 'Cell_Phones_and_Accessories', 'Electronics','Musical_Instruments', 'Office_Products', 'Tools_and_Home_Improvement', 'Toys_and_Games']\n",
    "COLORS = ['red', 'blue', 'brown', 'orange', 'yellow', 'green' , 'purple', 'cyan']\n",
    "assert len(CATEGORIES) == len(COLORS), \"each category needs exactly one colour\"\n",
    "CATEGORY_COLORS = dict(zip(CATEGORIES, COLORS))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e9a7b21-5c4d-4f0e-8a6b-1d2c3e4f5a61",
   "metadata": {},
   "source": [
    "## Load embeddings and labels from Chroma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4cf1c9a-1ced-48d4-974c-3c850905034e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch everything stored in the collection. t-SNE gets slow on large stores;\n",
    "# pass e.g. limit=20_000 to collection.get() to plot a subset instead.\n",
    "result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n",
    "vectors = np.array(result['embeddings'])\n",
    "documents = result['documents']\n",
    "categories = [metadata['category'] for metadata in result['metadatas']]\n",
    "\n",
    "# Fail fast with a readable message instead of deep inside t-SNE / the colour lookup\n",
    "assert len(vectors) > 0, f\"No embeddings found in {DB!r}; populate the 'products' collection first\"\n",
    "unknown_categories = set(categories).difference(CATEGORY_COLORS)\n",
    "assert not unknown_categories, f\"No colour defined for: {sorted(unknown_categories)}\"\n",
    "\n",
    "colors = [CATEGORY_COLORS[c] for c in categories]\n",
    "vectors.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b2d4f60-9e1a-4c3b-b5d8-6f0a1e2c3d72",
   "metadata": {},
   "source": [
    "## Project to 3D with t-SNE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c54df150-c8d8-4bc3-8877-6759691eeb42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# max_iter requires scikit-learn >= 1.5 (older releases call it n_iter).\n",
    "# perplexity must be smaller than the number of samples, so the default of 30 is capped for tiny collections.\n",
    "tsne = TSNE(\n",
    "    n_components=3,\n",
    "    perplexity=min(30.0, len(vectors) - 1),\n",
    "    random_state=42,\n",
    "    max_iter=250,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "reduced_vectors = tsne.fit_transform(vectors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1c3e5f7-2b4d-4e6f-8a0c-9d1b3f5e7c83",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8fb2a63-24c5-4dce-9e63-aa208272f82d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hover text: category plus the first 100 characters of the document.\n",
    "# Plotly hover labels use <br> for line breaks; documents may be None for items stored without text.\n",
    "hover_text = [f\"Category: {c}<br>Text: {(d or '')[:100]}...\" for c, d in zip(categories, documents)]\n",
    "\n",
    "fig = go.Figure(data=[go.Scatter3d(\n",
    "    x=reduced_vectors[:, 0],\n",
    "    y=reduced_vectors[:, 1],\n",
    "    z=reduced_vectors[:, 2],\n",
    "    mode='markers',\n",
    "    marker=dict(size=3, color=colors, opacity=0.7),\n",
    "    text=hover_text,\n",
    "    hoverinfo='text'\n",
    ")])\n",
    "\n",
    "fig.update_layout(\n",
    "    title='3D Chroma Vector Store Visualization',\n",
    "    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n",
    "    width=1200,\n",
    "    height=800,\n",
    "    margin=dict(r=20, b=10, l=10, t=40)\n",
    ")\n",
    "\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}