{
"cells": [
{
"cell_type": "markdown",
"id": "28a0673e-96b5-43f2-8a8b-bd033bf851b0",
"metadata": {},
"source": [
"# The Big Project begins!!\n",
"\n",
"## The Product Pricer\n",
"\n",
"A model that can estimate how much something costs, from its description.\n",
"\n",
"## Data Curation Part 1\n",
"\n",
"Today we'll begin scrubbing and curating our dataset, focusing on a subset of the data: Home Appliances.\n",
"\n",
"The dataset is here: \n",
"https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023\n",
"\n",
"And the folder with all the product datasets is here: \n",
"https://huggingface.co/datasets/McAuley-Lab/Amazon-Reviews-2023/tree/main/raw/meta_categories"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67cedf85-8125-4322-998e-9375fe745597",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7390a6aa-79cb-4dea-b6d7-de7e4b13e472",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"load_dotenv()\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0732274a-aa6a-44fc-aee2-40dc8a8e4451",
"metadata": {},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = os.environ['HF_TOKEN']\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5521526-0da9-42d7-99e3-f950fab71662",
"metadata": {},
"outputs": [],
"source": [
"# One more import - the Item class\n",
"# If running this raises an error saying you need to agree to Meta's terms, open the link in the error message and follow the instructions there\n",
"# You should get approved by Meta within minutes\n",
"# Any problems - message me or email me!\n",
"# With thanks to student Dr John S. for pointing out that this import needs to come after signing in to HF\n",
"\n",
"from items import Item"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1adcf323-de9d-4c24-a9c3-d7ae554d06ca",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "049885d4-fdfa-4ff0-a932-4a2ed73928e2",
"metadata": {},
"outputs": [],
"source": [
"# Load in our dataset\n",
"\n",
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Appliances\", split=\"full\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cde08860-b393-49b8-a620-06a8c0990a64",
"metadata": {},
"outputs": [],
"source": [
"print(f\"Number of Appliances: {len(dataset):,}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e29a5ab-ca61-41cc-9b33-22d374681b85",
"metadata": {},
"outputs": [],
"source": [
"# Investigate a particular datapoint\n",
"datapoint = dataset[2]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40a4e10f-6710-4780-a95e-6c0030c3fb87",
"metadata": {},
"outputs": [],
"source": [
"# Investigate\n",
"\n",
"print(datapoint[\"title\"])\n",
"print(datapoint[\"description\"])\n",
"print(datapoint[\"features\"])\n",
"print(datapoint[\"details\"])\n",
"print(datapoint[\"price\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d356c6f-b6e8-4e01-98cd-c562d132aafa",
"metadata": {},
"outputs": [],
"source": [
"# How many have prices?\n",
"\n",
"prices = 0\n",
"for datapoint in dataset:\n",
"    try:\n",
"        price = float(datapoint[\"price\"])\n",
"        if price > 0:\n",
"            prices += 1\n",
"    except ValueError:\n",
"        pass\n",
"\n",
"print(f\"There are {prices:,} items with prices, which is {prices/len(dataset)*100:,.1f}% of the dataset\")"
]
},
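{
"cell_type": "markdown",
"id": "sketch-parse-price-md",
"metadata": {},
"source": [
"The same try/except pattern for reading a price is repeated in several cells. As a minimal sketch (not part of the original course code), it could be pulled into one hypothetical helper, `parse_price`. This assumes the raw `price` field is either a numeric string such as `21.99` or a placeholder string like `None`; also catching `TypeError` covers an actual Python `None`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-parse-price-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper, not part of items.py: parse a raw price field\n",
"def parse_price(raw):\n",
"    # Return the price as a positive float, or None if it is missing or unparseable\n",
"    try:\n",
"        price = float(raw)\n",
"    except (TypeError, ValueError):\n",
"        # e.g. the string 'None', an empty string, or an actual None\n",
"        return None\n",
"    # Treat zero, negative and NaN prices as missing (NaN > 0 is False)\n",
"    return price if price > 0 else None\n",
"\n",
"# Quick check on a few illustrative raw values\n",
"print([parse_price(s) for s in ['21.99', 'None', '', None, '0']])  # [21.99, None, None, None, None]"
]
},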
{
"cell_type": "code",
"execution_count": null,
"id": "bd890259-aa25-4097-9524-f91c2bdd719b",
"metadata": {},
"outputs": [],
"source": [
"# For those with prices, gather the price and the length\n",
"\n",
"prices = []\n",
"lengths = []\n",
"for datapoint in dataset:\n",
"    try:\n",
"        price = float(datapoint[\"price\"])\n",
"        if price > 0:\n",
"            prices.append(price)\n",
"            contents = datapoint[\"title\"] + str(datapoint[\"description\"]) + str(datapoint[\"features\"]) + str(datapoint[\"details\"])\n",
"            lengths.append(len(contents))\n",
"    except ValueError:\n",
"        pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89078cb1-9679-4eb0-b295-599b8586bcd1",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of lengths\n",
"\n",
"plt.figure(figsize=(15, 6))\n",
"plt.title(f\"Lengths: Avg {sum(lengths)/len(lengths):,.0f} and highest {max(lengths):,}\\n\")\n",
"plt.xlabel('Length (chars)')\n",
"plt.ylabel('Count')\n",
"plt.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 6000, 100))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c38e0c43-9f7a-450e-a911-c94d37d9b9c3",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"\n",
"plt.figure(figsize=(15, 6))\n",
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.2f} and highest {max(prices):,}\\n\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count')\n",
"plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 10))\n",
"plt.show()"
]
},
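{
"cell_type": "markdown",
"id": "sketch-summary-md",
"metadata": {},
"source": [
"The plot title reports the mean price, which a handful of very expensive items can pull far upward. A hedged sketch using only the standard library `statistics` module puts the mean next to the median; `summarize` is a hypothetical helper name, and in this notebook it could be called as `summarize(prices)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-summary-code",
"metadata": {},
"outputs": [],
"source": [
"import statistics\n",
"\n",
"def summarize(values):\n",
"    # The mean is sensitive to outliers; the median is not\n",
"    return {\n",
"        'count': len(values),\n",
"        'mean': statistics.mean(values),\n",
"        'median': statistics.median(values),\n",
"        'max': max(values),\n",
"    }\n",
"\n",
"# Illustrative values only: one outlier drags the mean far above the median\n",
"print(summarize([10, 20, 30, 1000]))"
]
},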
{
"cell_type": "code",
"execution_count": null,
"id": "eabc7c61-0cd2-41f4-baa1-b85400bbf87f",
"metadata": {},
"outputs": [],
"source": [
"# So what is this item??\n",
"\n",
"for datapoint in dataset:\n",
"    try:\n",
"        price = float(datapoint[\"price\"])\n",
"        if price > 21000:\n",
"            print(datapoint['title'])\n",
"    except ValueError:\n",
"        pass"
]
},
{
"cell_type": "markdown",
"id": "3668ae25-3461-4e6e-9ccb-221c1925a497",
"metadata": {},
"source": [
"This is the closest I can find - looks like it's going at a bargain price!!\n",
"\n",
"https://www.amazon.com/TurboChef-Electric-Countertop-Microwave-Convection/dp/B01D05U9NO/"
]
},
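{
"cell_type": "markdown",
"id": "sketch-most-expensive-md",
"metadata": {},
"source": [
"Hard-coding the `21000` threshold only works once you already know roughly how high the top price is. A small sketch that finds the most expensive record in a single pass instead; `most_expensive` is a hypothetical helper, and it assumes each record is a dict with `price` and `title` keys, as the rows of `dataset` are. In this notebook it could be called as `most_expensive(dataset)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-most-expensive-code",
"metadata": {},
"outputs": [],
"source": [
"def most_expensive(records):\n",
"    # Return (price, title) for the highest-priced record, or None if no price parses\n",
"    best = None\n",
"    for record in records:\n",
"        try:\n",
"            price = float(record['price'])\n",
"        except (TypeError, ValueError):\n",
"            continue  # missing or unparseable price\n",
"        if not price > 0:\n",
"            continue  # skip zero, negative and NaN prices\n",
"        if best is None or price > best[0]:\n",
"            best = (price, record['title'])\n",
"    return best\n",
"\n",
"# Illustrative records only\n",
"sample = [\n",
"    {'title': 'Filter', 'price': '12.50'},\n",
"    {'title': 'Oven', 'price': '899.00'},\n",
"    {'title': 'Mystery part', 'price': 'None'},\n",
"]\n",
"print(most_expensive(sample))  # (899.0, 'Oven')"
]
},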
{
"cell_type": "markdown",
"id": "a0d02f58-23f6-4f81-a779-7c0555afd13d",
"metadata": {},
"source": [
"## Now it's time to curate our dataset\n",
"\n",
"We select items that cost between 1 and 999 USD.\n",
"\n",
"We will create Item instances, which truncate the text to fit within 180 tokens using the right tokenizer, and build a prompt to be used during training.\n",
"\n",
"Items will be rejected if they don't have enough characters."
]
},
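{
"cell_type": "markdown",
"id": "sketch-item-md",
"metadata": {},
"source": [
"The curation steps described above can be sketched with a simplified, hypothetical `SketchItem` class. This is not the real `Item` in `items.py`: it assumes a plain whitespace split in place of the model's tokenizer, and the character threshold and prompt wording are illustrative guesses. It keeps only items priced between 1 and 999, rejects text that is too short, truncates the rest to at most 180 pseudo-tokens, and builds a training prompt plus a test prompt with the answer removed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-item-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical, simplified stand-in for the Item class (not the real items.py)\n",
"# Assumptions: a whitespace split replaces the real tokenizer, and the thresholds\n",
"# and prompt wording are illustrative rather than the actual values\n",
"MIN_PRICE, MAX_PRICE = 1, 999\n",
"MIN_CHARS = 300\n",
"MAX_TOKENS = 180\n",
"QUESTION = 'How much does this cost to the nearest dollar?'\n",
"ANSWER_PREFIX = 'Price is $'\n",
"\n",
"class SketchItem:\n",
"    def __init__(self, text, price):\n",
"        self.price = price\n",
"        self.include = False\n",
"        if not MIN_PRICE <= price <= MAX_PRICE:\n",
"            return  # outside the price range we keep\n",
"        if len(text) < MIN_CHARS:\n",
"            return  # not enough characters to describe the product\n",
"        tokens = text.split()[:MAX_TOKENS]  # truncate to at most MAX_TOKENS pseudo-tokens\n",
"        self.token_count = len(tokens)\n",
"        body = ' '.join(tokens)\n",
"        # Training prompt: question, description, then the answer the model learns to produce\n",
"        self.prompt = f'{QUESTION}\\n\\n{body}\\n\\n{ANSWER_PREFIX}{round(price)}.00'\n",
"        self.include = True\n",
"\n",
"    def test_prompt(self):\n",
"        # Same prompt with the price removed, so the model has to complete it\n",
"        cut = self.prompt.rindex(ANSWER_PREFIX) + len(ANSWER_PREFIX)\n",
"        return self.prompt[:cut]\n",
"\n",
"# Illustrative inputs only\n",
"item = SketchItem('word ' * 400, 49.99)\n",
"print(item.include, item.token_count)  # True 180\n",
"print(item.test_prompt().endswith(ANSWER_PREFIX))  # True\n",
"print(SketchItem('too short', 49.99).include)  # False\n",
"print(SketchItem('word ' * 400, 5000).include)  # False"
]
},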
{
"cell_type": "code",
"execution_count": null,
"id": "430b432f-b769-41da-9506-a238cb5cf1b6",
"metadata": {},
"outputs": [],
"source": [
"# Create an Item object for each with a price\n",
"\n",
"items = []\n",
"for datapoint in dataset:\n",
"    try:\n",
"        price = float(datapoint[\"price\"])\n",
"        if price > 0:\n",
"            item = Item(datapoint, price)\n",
"            if item.include:\n",
"                items.append(item)\n",
"    except ValueError:\n",
"        pass\n",
"\n",
"print(f\"There are {len(items):,} items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d570794-6f1d-462e-b567-a46bae3556a1",
"metadata": {},
"outputs": [],
"source": [
"# Look at the first item\n",
"\n",
"items[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70219e99-22cc-4e08-9121-51f9707caef0",
"metadata": {},
"outputs": [],
"source": [
"# Investigate the prompt that will be used during training - the model learns to complete this\n",
"\n",
"print(items[100].prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d9998b8d-d746-4541-9ac2-701108e0e8fb",
"metadata": {},
"outputs": [],
"source": [
"# Investigate the prompt that will be used during testing - the model has to complete this\n",
"\n",
"print(items[100].test_prompt())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a116369-335a-412b-b70c-2add6675c2e3",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of token counts\n",
"\n",
"tokens = [item.token_count for item in items]\n",
"plt.figure(figsize=(15, 6))\n",
"plt.title(f\"Token counts: Avg {sum(tokens)/len(tokens):,.1f} and highest {max(tokens):,}\\n\")\n",
"plt.xlabel('Length (tokens)')\n",
"plt.ylabel('Count')\n",
"plt.hist(tokens, rwidth=0.7, color=\"green\", bins=range(0, 300, 10))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d1744aa-71e7-435e-876e-91f06583211a",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"\n",
"prices = [item.price for item in items]\n",
"plt.figure(figsize=(15, 6))\n",
"plt.title(f\"Prices: Avg {sum(prices)/len(prices):,.1f} and highest {max(prices):,}\\n\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count')\n",
"plt.hist(prices, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "2b58dc61-747f-46f7-b9e0-c205db4f3e5e",
"metadata": {},
"source": [
"## Sidenote\n",
"\n",
"If you like the variety of colors that matplotlib can use in its charts, you should bookmark this:\n",
"\n",
"https://matplotlib.org/stable/gallery/color/named_colors.html\n",
"\n",
"## Todos for you:\n",
"\n",
"- Review the Item class and check you're comfortable with it\n",
"- Examine some Item objects, look at the training prompt with `item.prompt` and test prompt with `item.test_prompt()`\n",
"- Make some more histograms to better understand the data\n",
"\n",
"## Next time we will combine with many other types of products\n",
"\n",
"Categories such as Electronics and Automotive will give us a massive dataset, and we can then be picky about choosing the subset most suitable for training."
]
},
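{
"cell_type": "markdown",
"id": "sketch-buckets-md",
"metadata": {},
"source": [
"For the todo on understanding the data better, a histogram can also be summarised as plain counts per price band. A minimal sketch using `collections.Counter`; `bucket_counts` and the default band width of 100 are illustrative choices, and in this notebook it could be applied to `[item.price for item in items]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-buckets-code",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def bucket_counts(values, width=100):\n",
"    # Map each value to the lower edge of its band, e.g. 150 -> 100 when width=100\n",
"    counts = Counter(int(v // width) * width for v in values)\n",
"    return dict(sorted(counts.items()))\n",
"\n",
"# Illustrative values only\n",
"print(bucket_counts([5, 50, 150, 999]))  # {0: 2, 100: 1, 900: 1}"
]
},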
{
"cell_type": "code",
"execution_count": null,
"id": "01401283-d111-40a7-96e5-0ca05bb20857",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}