{ "cells": [ { "cell_type": "markdown", "id": "28a0673e-96b5-43f2-8a8b-bd033bf851b0", "metadata": {}, "source": [ "# The Big Project begins!!\n", "\n", "## The Product Pricer\n", "\n", "A model that can estimate how much something costs, from its description.\n", "\n", "## Data Curation Part 1\n", "\n", "Today we'll begin scrubbing and curating our dataset by focusing on a subset of the data: Home Appliances.\n", "\n", "The dataset is here: \n", "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023\n", "\n", "And the folder with all the product datasets is here: \n", "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories" ] }, { "cell_type": "code", "execution_count": null, "id": "67cedf85-8125-4322-998e-9375fe745597", "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "import os\n", "from dotenv import load_dotenv\n", "from huggingface_hub import login\n", "from datasets import load_dataset, Dataset, DatasetDict\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "id": "7390a6aa-79cb-4dea-b6d7-de7e4b13e472", "metadata": {}, "outputs": [], "source": [ "# environment\n", "\n", "load_dotenv()\n", "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')" ] }, { "cell_type": "code", "execution_count": null, "id": "0732274a-aa6a-44fc-aee2-40dc8a8e4451", "metadata": {}, "outputs": [], "source": [ "# Log in to HuggingFace\n", "\n", "hf_token = os.environ['HF_TOKEN']\n", "login(hf_token, add_to_git_credential=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "b5521526-0da9-42d7-99e3-f950fab71662", "metadata": {}, "outputs": [], "source": [ "# One more import - the Item class\n", "# If you get an error that you need to agree to Meta's terms when 
you run this, then follow the link it provides you and follow their instructions\n", "# You should get approved by Meta within minutes\n", "# Any problems - message me or email me!\n", "# With thanks to student Dr John S. for pointing out that this import needs to come after signing in to HF\n", "\n", "from items import Item" ] }, { "cell_type": "code", "execution_count": null, "id": "1adcf323-de9d-4c24-a9c3-d7ae554d06ca", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "id": "049885d4-fdfa-4ff0-a932-4a2ed73928e2", "metadata": {}, "outputs": [], "source": [ "# Load in our dataset\n", "\n", "dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Appliances\", split=\"full\", trust_remote_code=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "cde08860-b393-49b8-a620-06a8c0990a64", "metadata": {}, "outputs": [], "source": [ "print(f\"Number of Appliances: {len(dataset):,}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "3e29a5ab-ca61-41cc-9b33-22d374681b85", "metadata": {}, "outputs": [], "source": [ "# Investigate a particular datapoint\n", "datapoint = dataset[2]\n" ] }, { "cell_type": "code", "execution_count": null, "id": "40a4e10f-6710-4780-a95e-6c0030c3fb87", "metadata": {}, "outputs": [], "source": [ "# Investigate\n", "\n", "print(datapoint[\"title\"])\n", "print(datapoint[\"description\"])\n", "print(datapoint[\"features\"])\n", "print(datapoint[\"details\"])\n", "print(datapoint[\"price\"])" ] }, { "cell_type": "code", "execution_count": null, "id": "9d356c6f-b6e8-4e01-98cd-c562d132aafa", "metadata": {}, "outputs": [], "source": [ "# How many have prices?\n", "\n", "prices = 0\n", "for datapoint in dataset:\n", "    try:\n", "        price = float(datapoint[\"price\"])\n", "        if price > 0:\n", "            prices += 1\n", "    except ValueError:\n", "        # price is missing or not a number\n", "        pass\n", "\n", "print(f\"There are {prices:,} items with prices, which is {prices/len(dataset)*100:,.1f}%\")" ] }, { 
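"cell_type": "markdown", "id": "5f3c2a9e-1b7d-4c8e-9a60-2d4e8b1c7f01", "metadata": {}, "source": [ "The `try` / `float()` / `except ValueError` pattern above recurs in several of the cells below. One way to factor it out is a small helper; the name `parse_price` is hypothetical and not part of the course code. This is a sketch that assumes the raw `price` field is either a numeric string or a non-numeric placeholder, and it also catches `TypeError` in case a genuine `None` ever appears, which the original loop would not survive." ] }, { "cell_type": "code", "execution_count": null, "id": "8c41d7b2-6e0a-4f95-b3d8-7a2c9e5f1b24", "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper (not part of the original notebook): parse the raw price field in one place\n", "\n", "def parse_price(raw):\n", "    \"\"\"Return the price as a positive float, or None if it is missing, unparseable or not positive.\"\"\"\n", "    try:\n", "        price = float(raw)\n", "    except (ValueError, TypeError):\n", "        # e.g. a non-numeric string, an empty string, or an actual None value\n", "        return None\n", "    return price if price > 0 else None\n", "\n", "# Sanity check: this should reproduce the count from the previous cell\n", "priced = sum(1 for datapoint in dataset if parse_price(datapoint[\"price\"]) is not None)\n", "print(f\"There are {priced:,} items with prices, which is {priced/len(dataset)*100:,.1f}%\")" ] }, { 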
"cell_type": "code", "execution_count": null, "id": "bd890259-aa25-4097-9524-f91c2bdd719b", "metadata": {}, "outputs": [], "source": [ "# For those with prices, gather the price and the length\n", "\n", "prices = []\n", "lengths = []\n", "for datapoint in dataset:\n", "    try:\n", "        price = float(datapoint[\"price\"])\n", "        if price > 0:\n", "            prices.append(price)\n", "            contents = datapoint[\"title\"] + str(datapoint[\"description\"]) + str(datapoint[\"features\"]) + str(datapoint[\"details\"])\n", "            lengths.append(len(contents))\n", "    except ValueError:\n", "        pass" ] }, { "cell_type": "code", "execution_count": null, "id": "89078cb1-9679-4eb0-b295-599b8586bcd1", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of lengths\n", "\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Lengths: Avg {sum(lengths)/len(lengths):,.0f} and highest {max(lengths):,}\\n\")\n", "plt.xlabel('Length (chars)')\n", "plt.ylabel('Count')\n", "plt.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 6000, 100))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "c38e0c43-9f7a-450e-a911-c94d37d9b9c3", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of prices\n", "\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.2f} and highest {max(prices):,}\\n\")\n", "plt.xlabel('Price ($)')\n", "plt.ylabel('Count')\n", "plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 10))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "eabc7c61-0cd2-41f4-baa1-b85400bbf87f", "metadata": {}, "outputs": [], "source": [ "# So what is this item??\n", "\n", "for datapoint in dataset:\n", "    try:\n", "        price = float(datapoint[\"price\"])\n", "        if price > 21000:\n", "            print(datapoint['title'])\n", "    except ValueError:\n", "        pass" ] }, { "cell_type": "markdown", "id": "3668ae25-3461-4e6e-9ccb-221c1925a497", "metadata": {}, "source": [ "This is the closest I can 
find - looks like it's going at a bargain price!!\n", "\n", "https://www.amazon.com/TurboChef-Electric-Countertop-Microwave-Convection/dp/B01D05U9NO/" ] }, { "cell_type": "markdown", "id": "a0d02f58-23f6-4f81-a779-7c0555afd13d", "metadata": {}, "source": [ "## Now it's time to curate our dataset\n", "\n", "We select items that cost between 1 and 999 USD.\n", "\n", "We will create Item instances, which truncate the text to fit within 180 tokens using the right tokenizer.\n", "\n", "Each Item also creates a prompt to be used during training.\n", "\n", "Items will be rejected if they don't have sufficient characters." ] }, { "cell_type": "code", "execution_count": null, "id": "430b432f-b769-41da-9506-a238cb5cf1b6", "metadata": {}, "outputs": [], "source": [ "# Create an Item object for each datapoint with a price\n", "\n", "items = []\n", "for datapoint in dataset:\n", "    try:\n", "        price = float(datapoint[\"price\"])\n", "        if price > 0:\n", "            item = Item(datapoint, price)\n", "            if item.include:\n", "                items.append(item)\n", "    except ValueError:\n", "        pass\n", "\n", "print(f\"There are {len(items):,} items\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0d570794-6f1d-462e-b567-a46bae3556a1", "metadata": {}, "outputs": [], "source": [ "# Look at the first item\n", "\n", "items[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "70219e99-22cc-4e08-9121-51f9707caef0", "metadata": {}, "outputs": [], "source": [ "# Investigate the prompt that will be used during training - the model learns to complete this\n", "\n", "print(items[100].prompt)" ] }, { "cell_type": "code", "execution_count": null, "id": "d9998b8d-d746-4541-9ac2-701108e0e8fb", "metadata": {}, "outputs": [], "source": [ "# Investigate the prompt that will be used during testing - the model has to complete this\n", "\n", "print(items[100].test_prompt())" ] }, { "cell_type": "code", "execution_count": null, "id": "7a116369-335a-412b-b70c-2add6675c2e3", "metadata": {}, "outputs": [], "source": [ "# 
Plot the distribution of token counts\n", "\n", "tokens = [item.token_count for item in items]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n", "plt.xlabel('Length (tokens)')\n", "plt.ylabel('Count')\n", "plt.hist(tokens, rwidth=0.7, color=\"green\", bins=range(0, 300, 10))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "8d1744aa-71e7-435e-876e-91f06583211a", "metadata": {}, "outputs": [], "source": [ "# Plot the distribution of prices\n", "\n", "prices = [item.price for item in items]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n", "plt.xlabel('Price ($)')\n", "plt.ylabel('Count')\n", "plt.hist(prices, rwidth=0.7, color=\"purple\", bins=range(0, 1000, 10))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "2b58dc61-747f-46f7-b9e0-c205db4f3e5e", "metadata": {}, "source": [ "## Sidenote\n", "\n", "If you like the variety of colors that matplotlib can use in its charts, you should bookmark this:\n", "\n", "https://matplotlib.org/stable/gallery/color/named_colors.html\n", "\n", "## Todos for you:\n", "\n", "- Review the Item class and check you're comfortable with it\n", "- Examine some Item objects, look at the training prompt with `item.prompt` and test prompt with `item.test_prompt()`\n", "- Make some more histograms to better understand the data\n", "\n", "## Next time we will combine this with many other types of products\n", "\n", "Like Electronics and Automotive. This will give us a massive dataset, and we can then be picky about choosing a subset that will be most suitable for training." 
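] }, { "cell_type": "markdown", "id": "b7e2d4f1-3a9c-4e58-8d16-0c5f2a7b9e43", "metadata": {}, "source": [ "As one possible starting point for the histogram todo: if the price histograms above show most items clustered at the cheap end, plotting prices with logarithmically spaced bins can make the shape of the tail easier to see. This is a sketch that assumes `items` from the cells above and that every `item.price` is positive (the curation step keeps items between 1 and 999 USD); `numpy` is assumed to be available because matplotlib depends on it, and `teal` is just one of the named matplotlib colors." ] }, { "cell_type": "code", "execution_count": null, "id": "e19a6c3b-5d82-4f07-a4b1-6f8d3c0e2a75", "metadata": {}, "outputs": [], "source": [ "# One possible extra histogram: item prices with log-spaced bins\n", "\n", "import numpy as np  # installed as a dependency of matplotlib\n", "\n", "prices = [item.price for item in items]\n", "plt.figure(figsize=(15, 6))\n", "plt.title(f\"Prices (log-spaced bins): median {np.median(prices):,.2f}\\n\")\n", "plt.xlabel('Price ($, log scale)')\n", "plt.ylabel('Count')\n", "# 50 bin edges evenly spaced in log10 between 10**0 = 1 and 10**3 = 1000\n", "plt.hist(prices, rwidth=0.7, color=\"teal\", bins=np.logspace(0, 3, 50))\n", "plt.xscale('log')\n", "plt.show()" 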
] }, { "cell_type": "code", "execution_count": null, "id": "01401283-d111-40a7-96e5-0ca05bb20857", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 5 }