{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "28a0673e-96b5-43f2-8a8b-bd033bf851b0",
   "metadata": {},
   "source": [
    "# The Big Project begins!!\n",
    "\n",
    "## The Product Pricer\n",
    "\n",
    "A model that can estimate how much something costs, from its description.\n",
    "\n",
    "## Data Curation Part 1\n",
    "\n",
    "Today we'll begin our scrubbing and curating our dataset by focusing on a subset of the data: Home Appliances.\n",
    "\n",
    "The dataset is here:  \n",
    "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023\n",
    "\n",
    "And the folder with all the product datasets is here:  \n",
    "https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67cedf85-8125-4322-998e-9375fe745597",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7390a6aa-79cb-4dea-b6d7-de7e4b13e472",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "\n",
    "load_dotenv()\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
    "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
    "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0732274a-aa6a-44fc-aee2-40dc8a8e4451",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log in to HuggingFace\n",
    "\n",
    "hf_token = os.environ['HF_TOKEN']\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5521526-0da9-42d7-99e3-f950fab71662",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One more import - the Item class\n",
    "# If you get an error that you need to agree to Meta's terms when you run this, then follow the link it provides you and follow their instructions\n",
    "# You should get approved by Meta within minutes\n",
    "# Any problems - message me or email me!\n",
    "# With thanks to student Dr John S. for pointing out that this import needs to come after signing in to HF\n",
    "\n",
    "from items import Item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1adcf323-de9d-4c24-a9c3-d7ae554d06ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049885d4-fdfa-4ff0-a932-4a2ed73928e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load in our dataset\n",
    "\n",
    "dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_Appliances\", split=\"full\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cde08860-b393-49b8-a620-06a8c0990a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Number of Appliances: {len(dataset):,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e29a5ab-ca61-41cc-9b33-22d374681b85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Investigate a particular datapoint\n",
    "datapoint = dataset[2]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40a4e10f-6710-4780-a95e-6c0030c3fb87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Investigate\n",
    "\n",
    "print(datapoint[\"title\"])\n",
    "print(datapoint[\"description\"])\n",
    "print(datapoint[\"features\"])\n",
    "print(datapoint[\"details\"])\n",
    "print(datapoint[\"price\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d356c6f-b6e8-4e01-98cd-c562d132aafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many have prices?\n",
    "\n",
    "prices = 0\n",
    "for datapoint in dataset:\n",
    "    try:\n",
    "        price = float(datapoint[\"price\"])\n",
    "        if price > 0:\n",
    "            prices += 1\n",
    "    except ValueError as e:\n",
    "        pass\n",
    "\n",
    "print(f\"There are {prices:,} with prices which is {prices/len(dataset)*100:,.1f}%\")"
   ]
  },
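  {
   "cell_type": "markdown",
   "id": "sketch-parse-price-note",
   "metadata": {},
   "source": [
    "The loop above uses a try/except pattern that comes up again in later cells. As a minimal sketch, it could be wrapped in a small helper. Note that `parse_price` and the sample records are hypothetical, made up for illustration; they are not part of the dataset or the `items` module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-parse-price-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: parse_price is a hypothetical helper, not part of items.py\n",
    "\n",
    "def parse_price(raw):\n",
    "    # Return the price as a positive float, or None if it can't be used\n",
    "    try:\n",
    "        price = float(raw)\n",
    "    except (ValueError, TypeError):\n",
    "        return None\n",
    "    return price if price > 0 else None\n",
    "\n",
    "# Made-up records standing in for dataset rows\n",
    "sample_records = [{\"price\": \"12.99\"}, {\"price\": \"None\"}, {\"price\": \"0\"}, {\"price\": None}]\n",
    "sample_prices = [parse_price(record[\"price\"]) for record in sample_records]\n",
    "print(sample_prices)"
   ]
  },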
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd890259-aa25-4097-9524-f91c2bdd719b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For those with prices, gather the price and the length\n",
    "\n",
    "prices = []\n",
    "lengths = []\n",
    "for datapoint in dataset:\n",
    "    try:\n",
    "        price = float(datapoint[\"price\"])\n",
    "        if price > 0:\n",
    "            prices.append(price)\n",
    "            contents = datapoint[\"title\"] + str(datapoint[\"description\"]) + str(datapoint[\"features\"]) + str(datapoint[\"details\"])\n",
    "            lengths.append(len(contents))\n",
    "    except ValueError as e:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89078cb1-9679-4eb0-b295-599b8586bcd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of lengths\n",
    "\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Lengths: Avg {sum(lengths)/len(lengths):,.0f} and highest {max(lengths):,}\\n\")\n",
    "plt.xlabel('Length (chars)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 6000, 100))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c38e0c43-9f7a-450e-a911-c94d37d9b9c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of prices\n",
    "\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.2f} and highest {max(prices):,}\\n\")\n",
    "plt.xlabel('Price ($)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 10))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eabc7c61-0cd2-41f4-baa1-b85400bbf87f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# So what is this item??\n",
    "\n",
    "for datapoint in dataset:\n",
    "    try:\n",
    "        price = float(datapoint[\"price\"])\n",
    "        if price > 21000:\n",
    "            print(datapoint['title'])\n",
    "    except ValueError as e:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3668ae25-3461-4e6e-9ccb-221c1925a497",
   "metadata": {},
   "source": [
    "This is the closest I can find - looks like it's going at a bargain price!!\n",
    "\n",
    "https://www.amazon.com/TurboChef-Electric-Countertop-Microwave-Convection/dp/B01D05U9NO/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0d02f58-23f6-4f81-a779-7c0555afd13d",
   "metadata": {},
   "source": [
    "## Now it's time to curate our dataset\n",
    "\n",
    "We select items that cost between 1 and 999 USD\n",
    "\n",
    "We will be create Item instances, which truncate the text to fit within 180 tokens using the right Tokenizer\n",
    "\n",
    "And will create a prompt to be used during Training.\n",
    "\n",
    "Items will be rejected if they don't have sufficient characters."
   ]
  },
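  {
   "cell_type": "markdown",
   "id": "sketch-truncate-tokens-note",
   "metadata": {},
   "source": [
    "To make the truncation and rejection steps concrete, here is a rough sketch of cutting text down to a token budget and checking for a minimum length. It assumes a plain whitespace split as a stand-in tokenizer, so the counts will differ from the real tokenizer used by the Item class. `truncate_to_tokens`, `has_enough_text`, `MAX_TOKENS_SKETCH` and `MIN_CHARS_SKETCH` are hypothetical names for illustration only; 180 comes from the text above, while the minimum character count is a made-up value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-truncate-tokens-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: whitespace splitting stands in for a real tokenizer\n",
    "\n",
    "MAX_TOKENS_SKETCH = 180  # token budget mentioned above\n",
    "MIN_CHARS_SKETCH = 300   # hypothetical threshold, not taken from items.py\n",
    "\n",
    "def truncate_to_tokens(text, max_tokens=MAX_TOKENS_SKETCH):\n",
    "    # Keep at most max_tokens whitespace-separated tokens\n",
    "    tokens = text.split()\n",
    "    return \" \".join(tokens[:max_tokens])\n",
    "\n",
    "def has_enough_text(text, min_chars=MIN_CHARS_SKETCH):\n",
    "    # Reject text that is too short to be useful\n",
    "    return len(text) >= min_chars\n",
    "\n",
    "# Made-up text standing in for a product description\n",
    "sample_text = \"stainless steel replacement filter \" * 100\n",
    "truncated = truncate_to_tokens(sample_text)\n",
    "print(len(truncated.split()), has_enough_text(sample_text))"
   ]
  },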
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "430b432f-b769-41da-9506-a238cb5cf1b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an Item object for each with a price\n",
    "\n",
    "items = []\n",
    "for datapoint in dataset:\n",
    "    try:\n",
    "        price = float(datapoint[\"price\"])\n",
    "        if price > 0:\n",
    "            item = Item(datapoint, price)\n",
    "            if item.include:\n",
    "                items.append(item)\n",
    "    except ValueError as e:\n",
    "        pass\n",
    "\n",
    "print(f\"There are {len(items):,} items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d570794-6f1d-462e-b567-a46bae3556a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look at the first item\n",
    "\n",
    "items[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70219e99-22cc-4e08-9121-51f9707caef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Investigate the prompt that will be used during training - the model learns to complete this\n",
    "\n",
    "print(items[100].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9998b8d-d746-4541-9ac2-701108e0e8fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Investigate the prompt that will be used during testing - the model has to complete this\n",
    "\n",
    "print(items[100].test_prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a116369-335a-412b-b70c-2add6675c2e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of token counts\n",
    "\n",
    "tokens = [item.token_count for item in items]\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
    "plt.xlabel('Length (tokens)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(tokens, rwidth=0.7, color=\"green\", bins=range(0, 300, 10))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d1744aa-71e7-435e-876e-91f06583211a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of prices\n",
    "\n",
    "prices = [item.price for item in items]\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
    "plt.xlabel('Price ($)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(prices, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b58dc61-747f-46f7-b9e0-c205db4f3e5e",
   "metadata": {},
   "source": [
    "## Sidenote\n",
    "\n",
    "If you like the variety of colors that matplotlib can use in its charts, you should bookmark this:\n",
    "\n",
    "https://matplotlib.org/stable/gallery/color/named_colors.html\n",
    "\n",
    "## Todos for you:\n",
    "\n",
    "- Review the Item class and check you're comfortable with it\n",
    "- Examine some Item objects, look at the training prompt with `item.prompt` and test prompt with `item.test_prompt()`\n",
    "- Make some more histograms to better understand the data\n",
    "\n",
    "## Next time we will combine with many other types of product\n",
    "\n",
    "Like Electronics and Automotive. This will give us a massive dataset, and we can then be picky about choosing a subset that will be most suitable for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01401283-d111-40a7-96e5-0ca05bb20857",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}