19 changed files with 157 additions and 14126 deletions
File diff suppressed because one or more lines are too long
@ -1,799 +0,0 @@ |
|||||||
{ |
|
||||||
"cells": [ |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9491dd8f-8124-4a51-be3a-8f678c149dcf", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# imports\n", |
|
||||||
"\n", |
|
||||||
"import os\n", |
|
||||||
"import re\n", |
|
||||||
"import math\n", |
|
||||||
"import random\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"from dotenv import load_dotenv\n", |
|
||||||
"from openai import OpenAI\n", |
|
||||||
"import anthropic\n", |
|
||||||
"from huggingface_hub import login\n", |
|
||||||
"from tqdm import tqdm\n", |
|
||||||
"import matplotlib.pyplot as plt\n", |
|
||||||
"from datasets import load_dataset, Dataset, DatasetDict\n", |
|
||||||
"from transformers import AutoTokenizer" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# environment\n", |
|
||||||
"\n", |
|
||||||
"load_dotenv()\n", |
|
||||||
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "846ded5d-b7f5-4581-8f56-d9650ff329c1", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# initialize\n", |
|
||||||
"\n", |
|
||||||
"openai = OpenAI()\n", |
|
||||||
"claude = anthropic.Anthropic()\n", |
|
||||||
"OPENAI_MODEL = \"gpt-4o-mini\"\n", |
|
||||||
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n", |
|
||||||
"hf_token = os.environ['HF_TOKEN']\n", |
|
||||||
"login(hf_token, add_to_git_credential=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"%matplotlib inline" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Constants\n", |
|
||||||
"\n", |
|
||||||
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n", |
|
||||||
"\n", |
|
||||||
"# Used for writing to output in color\n", |
|
||||||
"\n", |
|
||||||
"GREEN = \"\\033[92m\"\n", |
|
||||||
"YELLOW = \"\\033[93m\"\n", |
|
||||||
"RED = \"\\033[91m\"\n", |
|
||||||
"RESET = \"\\033[0m\"" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "b606ea85-4171-449d-8eda-a8f1a9b01464", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"#datasets = [\"raw_meta_Electronics\", \"raw_meta_Appliances\", \"raw_meta_Cell_Phones_and_Accessories\", \"raw_meta_Home_and_Kitchen\"]\n", |
|
||||||
"# datasets = [\"Electronics\", \"Appliances\", \"Cell_Phones_and_Accessories\", \"Home_and_Kitchen\", \"Tools_and_Home_Improvement\"]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "51af18a2-4122-4753-8f5d-622da2976cb5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Electronics\", split=\"full\", trust_remote_code=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "141ddcdd-bd60-44d4-8c63-1c6717f5bafc", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are {len(dataset):,} items in the dataset\")\n", |
|
||||||
"print(\"Here is the first:\")\n", |
|
||||||
"item = dataset[0]\n", |
|
||||||
"print(item['title'])\n", |
|
||||||
"print(item['description'])\n", |
|
||||||
"print(item['features'])\n", |
|
||||||
"print(item['price'])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f36c948d-e14d-44a0-9704-c11c589a26ee", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class Item:\n", |
|
||||||
"\n", |
|
||||||
" tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, data):\n", |
|
||||||
" self.title = data['title']\n", |
|
||||||
" self.description = self.clean(data['description'])\n", |
|
||||||
" self.features = self.clean(data['features'])\n", |
|
||||||
" self.price = float(data['price'])\n", |
|
||||||
" self.price_str = str(round(self.price))\n", |
|
||||||
" self._token_count = None\n", |
|
||||||
" self.full_prompt = self.make_full_prompt()\n", |
|
||||||
" self.prompt = self.full_prompt.split('Price is $')[0] + 'Price is $'\n", |
|
||||||
" self.label = self.full_prompt.split('Price is $')[1]\n", |
|
||||||
"\n", |
|
||||||
" def clean(self, details):\n", |
|
||||||
" result = ' '.join(details)\n", |
|
||||||
" return re.sub(r'[\\[\\]【】\\s]+', ' ', result).strip()\n", |
|
||||||
"\n", |
|
||||||
" def question(self):\n", |
|
||||||
" prompt = \"How much does this cost?\\n\"\n", |
|
||||||
" prompt += f\"Title: {self.title}\\n\"\n", |
|
||||||
" prompt += f\"Description: {self.description}\\n\"\n", |
|
||||||
" prompt += f\"Features: {self.features}\\n\"\n", |
|
||||||
" return prompt\n", |
|
||||||
"\n", |
|
||||||
" def messages(self):\n", |
|
||||||
" return [\n", |
|
||||||
" {\"role\":\"system\", \"content\": \"You estimate product prices. Reply only with the price to the nearest dollar\"},\n", |
|
||||||
" {\"role\":\"user\", \"content\": self.question()},\n", |
|
||||||
" {\"role\":\"assistant\", \"content\": f\"Price is ${self.price_str}.00\"}\n", |
|
||||||
" ]\n", |
|
||||||
"\n", |
|
||||||
" def make_full_prompt(self):\n", |
|
||||||
" prompt = self.tokenizer.apply_chat_template(self.messages(), tokenize=False, add_generation_prompt=False)\n", |
|
||||||
" groups = prompt.split('\\n\\n')\n", |
|
||||||
" return groups[0]+'\\n\\n'+'\\n\\n'.join(groups[2:])\n", |
|
||||||
"\n", |
|
||||||
" def token_count(self):\n", |
|
||||||
" if self._token_count == None:\n", |
|
||||||
" self._token_count = len(self.tokenizer.encode(self.full_prompt))\n", |
|
||||||
" return self._token_count\n", |
|
||||||
"\n", |
|
||||||
" def tokens_between(self, low, high):\n", |
|
||||||
" token_count = self.token_count()\n", |
|
||||||
" return token_count >= low and token_count < high" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "059152d0-a68a-4e93-b759-45f3c6baf31e", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Create a list called \"items\" with all our datapoints that have a valid price\n", |
|
||||||
"\n", |
|
||||||
"from collections import Counter\n", |
|
||||||
"counts = Counter()\n", |
|
||||||
"items = []\n", |
|
||||||
"for data in tqdm(dataset):\n", |
|
||||||
" try:\n", |
|
||||||
" price_str = data['price']\n", |
|
||||||
" if float(price_str) > 0:\n", |
|
||||||
" items.append(Item(data))\n", |
|
||||||
" except ValueError:\n", |
|
||||||
" counts[data['price']]+=1" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "8752310a-ca69-4d43-b8bd-fd98aebbc805", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"counts.most_common(10)\n" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "011bffcf-03f8-4f0d-8999-b53d1ac88624", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's investigate:\n", |
|
||||||
"\n", |
|
||||||
"print(f\"There are {len(items):,} out of {len(dataset):,} with prices\\n\")\n", |
|
||||||
"print(f\"Item 0 has {items[0].token_count()} tokens:\\n\")\n", |
|
||||||
"print(items[0].full_prompt)\n", |
|
||||||
"print(f\"\\nItem 1 has {items[1].token_count()} tokens:\\n\")\n", |
|
||||||
"print(items[1].full_prompt)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "fcf74830-1e97-4543-b454-eefd314fc106", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of character count\n", |
|
||||||
"\n", |
|
||||||
"lengths = [len(item.full_prompt) for item in items]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Length')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 5000, 250))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "af1d6c8b-f2ae-4691-9306-989b1bd45233", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are total {len(items):,} items\")\n", |
|
||||||
"cutoff = 1500\n", |
|
||||||
"selection = [item for item in items if len(item.full_prompt) < cutoff]\n", |
|
||||||
"print(f\"There are total {len(selection):,} with under {cutoff:,} character training prompt\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "42231dc7-66fb-4437-ba08-7689514a8b19", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Calculate token sizes in selection\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in tqdm(selection)]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of tokens\n", |
|
||||||
"\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 500, 25))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "da0a20b4-8926-4eff-bf83-11c4f6b40455", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def report(item):\n", |
|
||||||
" prompt = item.full_prompt\n", |
|
||||||
" tokens = Item.tokenizer.encode(item.full_prompt)\n", |
|
||||||
" print(prompt)\n", |
|
||||||
" print(tokens[-8:])\n", |
|
||||||
" print(Item.tokenizer.batch_decode(tokens[-8:]))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "2378cb92-305a-49d0-8193-4ae09a0cccf8", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"report(items[0])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1232004a-ff9b-486a-a14b-70f21c217c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's limit our dataset to documents with 60-180 tokens\n", |
|
||||||
"\n", |
|
||||||
"low_cutoff = 80\n", |
|
||||||
"high_cutoff = 240\n", |
|
||||||
"subset = [item for item in tqdm(selection) if item.tokens_between(low_cutoff, high_cutoff)]\n", |
|
||||||
"subset_count = len(subset)\n", |
|
||||||
"count = len(items)\n", |
|
||||||
"print(f\"\\nBetween {low_cutoff} and {high_cutoff}, we get {subset_count:,} out of {count:,} which is {subset_count/count*100:.1f}%\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "7bc11e4f-5a15-48fd-b571-92e2e10b0323", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"prices = [float(item.price) for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Price ($)')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 500, 20))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3718a8e6-6c87-4351-8c27-9e61745b0991", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Pick the most expensive 52,000 items, then pick 12,000 of the next 20,000\n", |
|
||||||
"\n", |
|
||||||
"random.seed(42)\n", |
|
||||||
"sorted_subset = sorted(subset, key=lambda item: item.price, reverse=True)\n", |
|
||||||
"top_30k = sorted_subset[:62000]\n", |
|
||||||
"# other_12k = random.sample(sorted_subset[30000:50000], k=12000)\n", |
|
||||||
"# sample = top_30k + other_12k\n", |
|
||||||
"sample = top_30k\n", |
|
||||||
"print(len(sample))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3cd1c4d3-b6e4-4f28-8ad4-709c4637626c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"prices = [float(item.price) for item in sample]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Price ($)')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 500, 20))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"sizes = [len(item.full_prompt) for item in sample]\n", |
|
||||||
"prices = [item.price for item in sample]\n", |
|
||||||
"\n", |
|
||||||
"# Create the scatter plot\n", |
|
||||||
"plt.figure(figsize=(10, 6))\n", |
|
||||||
"plt.scatter(sizes, prices, s=2, color=\"red\")\n", |
|
||||||
"\n", |
|
||||||
"# Add labels and title\n", |
|
||||||
"plt.xlabel('Size')\n", |
|
||||||
"plt.ylabel('Price')\n", |
|
||||||
"plt.title('Is there a simple correlation?')\n", |
|
||||||
"\n", |
|
||||||
"# Display the plot\n", |
|
||||||
"plt.show()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in sample]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "59ef7aef-b6f6-4042-a2af-ddd5ae1c9999", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"report(sample[0])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "cacb9059-5f44-4601-860a-30860cebe9c2", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"random.seed(42)\n", |
|
||||||
"random.shuffle(sample)\n", |
|
||||||
"train = sample[:60000]\n", |
|
||||||
"test = sample[60000:]\n", |
|
||||||
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"average = sum(t.price for t in train)/len(train)\n", |
|
||||||
"average" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "95353e68-07ac-4f57-8d57-dd48cacb0e04", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class TestRunner:\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, predictor, data, title, size=None):\n", |
|
||||||
" self.predictor = predictor\n", |
|
||||||
" self.data = data\n", |
|
||||||
" self.size = size or len(data)\n", |
|
||||||
" self.guesses = []\n", |
|
||||||
" self.truths = []\n", |
|
||||||
" self.errors = []\n", |
|
||||||
" self.title = title\n", |
|
||||||
"\n", |
|
||||||
" def run_datapoint(self, i):\n", |
|
||||||
" datapoint = self.data[i]\n", |
|
||||||
" guess = self.predictor(datapoint)\n", |
|
||||||
" truth = datapoint.price\n", |
|
||||||
" error = abs(guess - truth)\n", |
|
||||||
" color = RED if error>=20 else YELLOW if error>=10 else GREEN\n", |
|
||||||
" title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n", |
|
||||||
" self.guesses.append(guess)\n", |
|
||||||
" self.truths.append(truth)\n", |
|
||||||
" self.errors.append(error)\n", |
|
||||||
" print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} Item: {title}{RESET}\")\n", |
|
||||||
"\n", |
|
||||||
" def chart(self):\n", |
|
||||||
" max_error = max(self.errors)\n", |
|
||||||
" colors = [(max_error - error)**3 for error in self.errors]\n", |
|
||||||
" plt.figure(figsize=(10, 6))\n", |
|
||||||
" plt.scatter(self.truths, self.guesses, s=3, c=colors, cmap='RdYlGn')\n", |
|
||||||
" plt.xlabel('Truth')\n", |
|
||||||
" plt.ylabel('Guess')\n", |
|
||||||
" plt.title(self.title)\n", |
|
||||||
" plt.show()\n", |
|
||||||
"\n", |
|
||||||
" def run(self):\n", |
|
||||||
" self.error = 0\n", |
|
||||||
" for i in range(self.size):\n", |
|
||||||
" self.run_datapoint(i)\n", |
|
||||||
" average_error = sum(self.errors) / self.size\n", |
|
||||||
" print(f\"Average Error = ${average_error:,.2f}\")\n", |
|
||||||
" hits = [e for e in self.errors if e<10]\n", |
|
||||||
" print(f\"Hit rate = {len(hits)/self.size*100:.1f}%\")\n", |
|
||||||
" self.chart()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train_average = sum(t.price for t in train)/len(train)\n", |
|
||||||
"\n", |
|
||||||
"def flat_predictor(item):\n", |
|
||||||
" return train_average" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "739d2e33-55d4-4892-b42c-771131159c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(flat_predictor, test, \"Flat Predictor Accuracy\", 100).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"stop = set(['the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in','from', 'you', 'or', 'an'])\n", |
|
||||||
"\n", |
|
||||||
"def words(item):\n", |
|
||||||
" text = f\"{item.title} {item.description} {item.features}\"\n", |
|
||||||
" text = re.sub(r'[()\\[\\]{},\\'\"-]', ' ', text)\n", |
|
||||||
" text = re.sub(r'\\s+', ' ', text)\n", |
|
||||||
" words = text.strip().lower().split(' ')\n", |
|
||||||
" filtered = [word for word in words if word not in stop]\n", |
|
||||||
" return \" \".join(filtered)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "262fc576-7606-426c-8aea-5799b3952d2c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import CountVectorizer\n", |
|
||||||
"from sklearn.linear_model import LinearRegression\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = CountVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = LinearRegression()\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "bd782b21-8e44-409d-a7b6-f136974958b4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def linear_regression_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "80e77aae-0071-42e9-8e24-d3aec5256015", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(linear_regression_predictor, test, \"Linear Accuracy\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n", |
|
||||||
"from sklearn.svm import SVR\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = TfidfVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = SVR(kernel='linear')\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "64560112-3bfb-45cc-b489-de619a2eca20", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def svr_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "392598d4-2deb-4935-9175-fd111616b13c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(svr_predictor, test, \"SVR Accuracy\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "60010699-d26b-4f93-a959-50272ada6a57", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def messages_for(item):\n", |
|
||||||
" system_message = \"You estimate product prices. Reply only with the price, no explanation\"\n", |
|
||||||
" user_prompt = item.question()\n", |
|
||||||
" return [\n", |
|
||||||
" {\"role\": \"system\", \"content\": system_message},\n", |
|
||||||
" {\"role\": \"user\", \"content\": user_prompt}\n", |
|
||||||
" ]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def get_price(s):\n", |
|
||||||
" s = s.replace('$','').replace(',','')\n", |
|
||||||
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n", |
|
||||||
" return float(match.group()) if match else 0" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def gpt_predictor(item):\n", |
|
||||||
" response = openai.chat.completions.create(\n", |
|
||||||
" model=OPENAI_MODEL,\n", |
|
||||||
" messages=messages_for(item),\n", |
|
||||||
" seed=42,\n", |
|
||||||
" max_tokens=8\n", |
|
||||||
" )\n", |
|
||||||
" reply = response.choices[0].message.content\n", |
|
||||||
" return get_price(reply)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(gpt_predictor, test, \"GPT-4o-mini Prediction Accuracy\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f7e24d6b-59a2-464a-95a9-14a9fbfadd4d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train[0].full_prompt" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "059b6c74-917f-4cb1-b810-ce70735a57be", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train_prompts = [item.full_prompt for item in train]\n", |
|
||||||
"train_prices = [item.price for item in train]\n", |
|
||||||
"test_prompts = [item.prompt for item in test]\n", |
|
||||||
"test_prices = [item.price for item in test]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "b8ba48cb-da5e-4ddb-8955-8a94e62ea8e0", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Create a Dataset from the lists\n", |
|
||||||
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n", |
|
||||||
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n", |
|
||||||
"dataset = DatasetDict({\n", |
|
||||||
" \"train\": train_dataset,\n", |
|
||||||
" \"test\": test_dataset\n", |
|
||||||
"})" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e69e26a5-4b24-4e0f-8944-731c534b285b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"DATASET_NAME = \"ed-donner/electronics-instruct\"\n", |
|
||||||
"dataset.push_to_hub(DATASET_NAME, private=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [] |
|
||||||
} |
|
||||||
], |
|
||||||
"metadata": { |
|
||||||
"kernelspec": { |
|
||||||
"display_name": "Python 3 (ipykernel)", |
|
||||||
"language": "python", |
|
||||||
"name": "python3" |
|
||||||
}, |
|
||||||
"language_info": { |
|
||||||
"codemirror_mode": { |
|
||||||
"name": "ipython", |
|
||||||
"version": 3 |
|
||||||
}, |
|
||||||
"file_extension": ".py", |
|
||||||
"mimetype": "text/x-python", |
|
||||||
"name": "python", |
|
||||||
"nbconvert_exporter": "python", |
|
||||||
"pygments_lexer": "ipython3", |
|
||||||
"version": "3.11.10" |
|
||||||
} |
|
||||||
}, |
|
||||||
"nbformat": 4, |
|
||||||
"nbformat_minor": 5 |
|
||||||
} |
|
@ -1,942 +0,0 @@ |
|||||||
{ |
|
||||||
"cells": [ |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9491dd8f-8124-4a51-be3a-8f678c149dcf", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# imports\n", |
|
||||||
"\n", |
|
||||||
"import os\n", |
|
||||||
"import re\n", |
|
||||||
"import math\n", |
|
||||||
"import random\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"from typing import Optional\n", |
|
||||||
"from dotenv import load_dotenv\n", |
|
||||||
"from openai import OpenAI\n", |
|
||||||
"import anthropic\n", |
|
||||||
"from huggingface_hub import login\n", |
|
||||||
"from tqdm import tqdm\n", |
|
||||||
"import matplotlib.pyplot as plt\n", |
|
||||||
"from datasets import load_dataset, Dataset, DatasetDict\n", |
|
||||||
"from transformers import AutoTokenizer" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# environment\n", |
|
||||||
"\n", |
|
||||||
"load_dotenv()\n", |
|
||||||
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "535addd2-9590-42dd-81d8-9dbe06e0194a", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# constants\n", |
|
||||||
"\n", |
|
||||||
"MIN_TOKENS = 80\n", |
|
||||||
"MAX_TOKENS = 180\n", |
|
||||||
"CUTOFF_CHARS = MAX_TOKENS * 7" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "846ded5d-b7f5-4581-8f56-d9650ff329c1", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# initialize\n", |
|
||||||
"\n", |
|
||||||
"openai = OpenAI()\n", |
|
||||||
"claude = anthropic.Anthropic()\n", |
|
||||||
"OPENAI_MODEL = \"gpt-4o-mini\"\n", |
|
||||||
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n", |
|
||||||
"hf_token = os.environ['HF_TOKEN']\n", |
|
||||||
"login(hf_token, add_to_git_credential=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"%matplotlib inline" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Constants\n", |
|
||||||
"\n", |
|
||||||
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n", |
|
||||||
"\n", |
|
||||||
"# Used for writing to output in color\n", |
|
||||||
"\n", |
|
||||||
"GREEN = \"\\033[92m\"\n", |
|
||||||
"YELLOW = \"\\033[93m\"\n", |
|
||||||
"RED = \"\\033[91m\"\n", |
|
||||||
"RESET = \"\\033[0m\"" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "fb2ed609-a00a-4ff8-9f4d-8f2ff8ea26dd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"datasets = [\"Electronics\", \"Appliances\", \"Cell_Phones_and_Accessories\", \"Home_and_Kitchen\", \"Tools_and_Home_Improvement\"]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "51af18a2-4122-4753-8f5d-622da2976cb5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Electronics\", split=\"full\", trust_remote_code=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "141ddcdd-bd60-44d4-8c63-1c6717f5bafc", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are {len(dataset):,} items in the dataset\")\n", |
|
||||||
"print(\"Here is the first:\")\n", |
|
||||||
"item = dataset[0]\n", |
|
||||||
"print(item['title'])\n", |
|
||||||
"print(item['description'])\n", |
|
||||||
"print(item['features'])\n", |
|
||||||
"print(item['price'])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f36c948d-e14d-44a0-9704-c11c589a26ee", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class Item:\n", |
|
||||||
" \n", |
|
||||||
" tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", |
|
||||||
" PREFIX = \"Price is $\"\n", |
|
||||||
"\n", |
|
||||||
" title: str\n", |
|
||||||
" price: float\n", |
|
||||||
" token_count: int = 0\n", |
|
||||||
" details: Optional[str]\n", |
|
||||||
" prompt: Optional[str]\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, data, price):\n", |
|
||||||
" self.title = data['title']\n", |
|
||||||
" self.price = price\n", |
|
||||||
" self.create_details(data)\n", |
|
||||||
" \n", |
|
||||||
" def create_details(self, data):\n", |
|
||||||
" self.details = '\\n'.join(data['description'])\n", |
|
||||||
" features = '\\n'.join(data['features'])\n", |
|
||||||
" if features:\n", |
|
||||||
" self.details += '\\n' + features\n", |
|
||||||
" self.details = re.sub(r'[\\[\\]【】\\s]+', ' ', self.details).strip()\n", |
|
||||||
" self.make_prompt()\n", |
|
||||||
"\n", |
|
||||||
" def question(self):\n", |
|
||||||
" prompt = \"How much does this cost?\\n\"\n", |
|
||||||
" prompt += f\"Title: {self.title}\\n\"\n", |
|
||||||
" prompt += f\"Details: {self.details}\\n\"\n", |
|
||||||
" return prompt\n", |
|
||||||
"\n", |
|
||||||
" def messages(self):\n", |
|
||||||
" return [\n", |
|
||||||
" {\"role\":\"system\", \"content\": \"You estimate product prices. Reply only with the price to the nearest dollar\"},\n", |
|
||||||
" {\"role\":\"user\", \"content\": self.question()},\n", |
|
||||||
" {\"role\":\"assistant\", \"content\": f\"{self.PREFIX}{str(round(self.price))}.00\"}\n", |
|
||||||
" ]\n", |
|
||||||
"\n", |
|
||||||
" def make_prompt(self):\n", |
|
||||||
" prompt = self.tokenizer.apply_chat_template(self.messages(), tokenize=False, add_generation_prompt=False)\n", |
|
||||||
" groups = prompt.split('\\n\\n')\n", |
|
||||||
" self.prompt = groups[0]+'\\n\\n'+'\\n\\n'.join(groups[2:])\n", |
|
||||||
"\n", |
|
||||||
" def count_tokens(self):\n", |
|
||||||
" self.token_count = len(self.tokenizer.encode(self.prompt))\n", |
|
||||||
"\n", |
|
||||||
" def tokens_between(self, low, high):\n", |
|
||||||
" return self.token_count >= low and self.token_count < high\n", |
|
||||||
"\n", |
|
||||||
" def test_prompt(self):\n", |
|
||||||
" return self.prompt.split(self.PREFIX)[0] + self.PREFIX" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "20d97009-6b35-4fdf-baae-59dbd1bf6f77", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def read_dataset(name):\n", |
|
||||||
" print(f\"Loading dataset {name}\")\n", |
|
||||||
" dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_{name}\", split=\"full\", trust_remote_code=True)\n", |
|
||||||
" results = []\n", |
|
||||||
" for data in tqdm(dataset):\n", |
|
||||||
" try:\n", |
|
||||||
" price_str = data['price']\n", |
|
||||||
" if price_str:\n", |
|
||||||
" price = float(price_str)\n", |
|
||||||
" if price > 0:\n", |
|
||||||
" results.append(Item(data, price))\n", |
|
||||||
" except ValueError:\n", |
|
||||||
" pass\n", |
|
||||||
" print(f\"Completed loading {name} with {len(results):,} datapoints\")\n", |
|
||||||
" return results" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "dd11853b-9e21-4b14-9a08-9d9f63636e1a", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"items = []\n", |
|
||||||
"for dataset in datasets:\n", |
|
||||||
" items.extend(read_dataset(dataset))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "011bffcf-03f8-4f0d-8999-b53d1ac88624", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's investigate:\n", |
|
||||||
"\n", |
|
||||||
"print(f\"There are {len(items):,} items with prices\\n\")\n", |
|
||||||
"print(items[0].prompt)\n", |
|
||||||
"print(items[1].prompt)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "fcf74830-1e97-4543-b454-eefd314fc106", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of character count\n", |
|
||||||
"\n", |
|
||||||
"lengths = [len(item.prompt) for item in items]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Length')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 5000, 250))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "af1d6c8b-f2ae-4691-9306-989b1bd45233", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are total {len(items):,} items\")\n", |
|
||||||
"selection = [item for item in items if len(item.prompt) < CUTOFF_CHARS]\n", |
|
||||||
"print(f\"There are total {len(selection):,} with under {CUTOFF_CHARS:,} character training prompt\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "42231dc7-66fb-4437-ba08-7689514a8b19", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Calculate token sizes in selection\n", |
|
||||||
"\n", |
|
||||||
"for item in tqdm(selection):\n", |
|
||||||
" item.count_tokens()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of tokens\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count for item in tqdm(selection)]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 500, 25))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "da0a20b4-8926-4eff-bf83-11c4f6b40455", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def report(item):\n", |
|
||||||
" prompt = item.prompt\n", |
|
||||||
" tokens = Item.tokenizer.encode(item.prompt)\n", |
|
||||||
" print(prompt)\n", |
|
||||||
" print(tokens[-10:])\n", |
|
||||||
" print(Item.tokenizer.batch_decode(tokens[-10:]))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "2378cb92-305a-49d0-8193-4ae09a0cccf8", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"report(items[0])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1232004a-ff9b-486a-a14b-70f21c217c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's limit our dataset to documents with 60-180 tokens\n", |
|
||||||
"\n", |
|
||||||
"subset = [item for item in tqdm(selection) if item.tokens_between(MIN_TOKENS, MAX_TOKENS)]\n", |
|
||||||
"subset_count = len(subset)\n", |
|
||||||
"count = len(items)\n", |
|
||||||
"print(f\"\\nBetween {MIN_TOKENS} and {MAX_TOKENS}, we get {subset_count:,} out of {count:,} which is {subset_count/count*100:.1f}%\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "7bc11e4f-5a15-48fd-b571-92e2e10b0323", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"prices = [float(item.price) for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Price ($)')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 500, 20))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3718a8e6-6c87-4351-8c27-9e61745b0991", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Pick the most expensive 52,000 items, then pick 12,000 of the next 20,000\n", |
|
||||||
"\n", |
|
||||||
"# random.seed(42)\n", |
|
||||||
"# subset2 = [item for item in subset if item.price <= 999]\n", |
|
||||||
"# sorted_subset2 = sorted(subset2, key=lambda item: item.price, reverse=True)\n", |
|
||||||
"# sample = sorted_subset2[:90_000]\n", |
|
||||||
"# other_12k = random.sample(sorted_subset2[90_000:130_000], k=15000)\n", |
|
||||||
"# sample += other_12k\n", |
|
||||||
"# print(f\"Created a sample of {len(sample):,} with prices ranging from ${sample[-1].price} to ${sample[0].price}\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f50917db-ab22-4ecd-a7f1-a2cd45ceb7e6", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"random.seed(42)\n", |
|
||||||
"subset = [item for item in subset if item.price <= 999]\n", |
|
||||||
"sorted_subset = sorted(subset, key=lambda item: item.price, reverse=True)\n", |
|
||||||
"sample = sorted_subset[:150_000]\n", |
|
||||||
"sample += random.sample(sorted_subset[150_000:300_000], k=50000)\n", |
|
||||||
"sample += random.sample(sorted_subset[300_000:], k=5000)\n", |
|
||||||
"print(f\"Created a sample of {len(sample):,} with prices ranging from ${sample[-1].price} to ${sample[0].price}\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3cd1c4d3-b6e4-4f28-8ad4-709c4637626c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"plt.figure(figsize=(10, 6))\n", |
|
||||||
"prices = [float(item.price) for item in sample]\n", |
|
||||||
"plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 20))\n", |
|
||||||
"\n", |
|
||||||
"plt.title(f\"Avg price ${sum(prices)/len(prices):.2f}\")\n", |
|
||||||
"plt.xlabel('Price ($)')\n", |
|
||||||
"plt.ylabel('Count of items')\n", |
|
||||||
"\n", |
|
||||||
"plt.show()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"sizes = [len(item.prompt) for item in sample]\n", |
|
||||||
"prices = [item.price for item in sample]\n", |
|
||||||
"\n", |
|
||||||
"# Create the scatter plot\n", |
|
||||||
"plt.figure(figsize=(10, 6))\n", |
|
||||||
"plt.scatter(sizes, prices, s=1, color=\"red\")\n", |
|
||||||
"\n", |
|
||||||
"# Add labels and title\n", |
|
||||||
"plt.xlabel('Size')\n", |
|
||||||
"plt.ylabel('Price')\n", |
|
||||||
"plt.title('Is there a simple correlation?')\n", |
|
||||||
"\n", |
|
||||||
"# Display the plot\n", |
|
||||||
"plt.show()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count for item in sample]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "59ef7aef-b6f6-4042-a2af-ddd5ae1c9999", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"report(sample[-1])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "cacb9059-5f44-4601-860a-30860cebe9c2", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"random.seed(42)\n", |
|
||||||
"random.shuffle(sample)\n", |
|
||||||
"train = sample[:200_000]\n", |
|
||||||
"test = sample[200_000:]\n", |
|
||||||
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "bf435bcd-accf-427c-82d5-02b33a56737c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"del items, subset, sorted_subset, selection" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "6b26000a-e5a9-4ab7-83fc-8eb44cb12f94", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"test[0].title" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3615bfdd-f23e-4005-96d8-7b52a1a439be", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"import csv\n", |
|
||||||
"with open('test.csv', 'w') as csvfile:\n", |
|
||||||
" writer = csv.writer(csvfile)\n", |
|
||||||
" for t in test[:200]:\n", |
|
||||||
" writer.writerow([t.title, t.details, 0])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "cb6907e9-37d7-4283-b1a9-8124f9f3439b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"human_predictions = []\n", |
|
||||||
"with open('human.csv', 'r') as csvfile:\n", |
|
||||||
" reader = csv.reader(csvfile)\n", |
|
||||||
" for row in reader:\n", |
|
||||||
" human_predictions.append(float(row[2]))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"average = sum(t.price for t in train)/len(train)\n", |
|
||||||
"average" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "95353e68-07ac-4f57-8d57-dd48cacb0e04", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class TestRunner:\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, predictor, data, title, size=200):\n", |
|
||||||
" self.predictor = predictor\n", |
|
||||||
" self.data = data\n", |
|
||||||
" self.title = title\n", |
|
||||||
" self.size = size\n", |
|
||||||
" self.guesses = []\n", |
|
||||||
" self.truths = []\n", |
|
||||||
" self.errors = []\n", |
|
||||||
" self.sles = []\n", |
|
||||||
" self.colors = []\n", |
|
||||||
"\n", |
|
||||||
" def run_datapoint(self, i):\n", |
|
||||||
" datapoint = self.data[i]\n", |
|
||||||
" guess = self.predictor(datapoint)\n", |
|
||||||
" truth = datapoint.price\n", |
|
||||||
" error = abs(guess - truth)\n", |
|
||||||
" log_error = math.log(truth+1) - math.log(guess+1)\n", |
|
||||||
" sle = log_error ** 2\n", |
|
||||||
" color = RED if error>=20 else YELLOW if error>=10 else GREEN\n", |
|
||||||
" color_str = \"red\" if error>=20 else \"yellow\" if error>=10 else \"green\"\n", |
|
||||||
" title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n", |
|
||||||
" self.guesses.append(guess)\n", |
|
||||||
" self.truths.append(truth)\n", |
|
||||||
" self.errors.append(error)\n", |
|
||||||
" self.sles.append(sle)\n", |
|
||||||
" self.colors.append(color_str)\n", |
|
||||||
" print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n", |
|
||||||
"\n", |
|
||||||
" def chart(self, title):\n", |
|
||||||
" max_error = max(self.errors)\n", |
|
||||||
" plt.figure(figsize=(12, 8))\n", |
|
||||||
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n", |
|
||||||
" plt.xlabel('Ground Truth')\n", |
|
||||||
" plt.ylabel('Model Estimate')\n", |
|
||||||
" plt.title(title)\n", |
|
||||||
" plt.show()\n", |
|
||||||
"\n", |
|
||||||
" def report(self):\n", |
|
||||||
" average_error = sum(self.errors) / self.size\n", |
|
||||||
" rmsle = math.sqrt(sum(self.sles) / self.size)\n", |
|
||||||
" hits = [e for e in self.errors if e<10]\n", |
|
||||||
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={len(hits)/self.size*100:.1f}%\"\n", |
|
||||||
" self.chart(title)\n", |
|
||||||
"\n", |
|
||||||
" def run(self):\n", |
|
||||||
" self.error = 0\n", |
|
||||||
" for i in range(self.size):\n", |
|
||||||
" self.run_datapoint(i)\n", |
|
||||||
" self.report()\n", |
|
||||||
" return self" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train_average = sum(t.price for t in train)/len(train)\n", |
|
||||||
"\n", |
|
||||||
"def flat_predictor(item):\n", |
|
||||||
" return train_average" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "739d2e33-55d4-4892-b42c-771131159c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"runner = TestRunner(flat_predictor, test, \"Flat Predictor\").run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9d3a5d83-d90b-40af-979f-85aa21816578", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"human_predictions = []\n", |
|
||||||
"with open('human.csv', 'r') as csvfile:\n", |
|
||||||
" reader = csv.reader(csvfile)\n", |
|
||||||
" for row in reader:\n", |
|
||||||
" human_predictions.append(float(row[2]))\n", |
|
||||||
"\n", |
|
||||||
"def human_predictor(item):\n", |
|
||||||
" index = test.index(item)\n", |
|
||||||
" if index==-1:\n", |
|
||||||
" raise ValueError(\"Index not found\")\n", |
|
||||||
" return human_predictions[index]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "6c87c385-d6b7-4a4f-89eb-f4e250337d03", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"runner = TestRunner(human_predictor, test, \"Human Predictor\").run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"stop = set(['the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in','from', 'you', 'or', 'an', 'on', 'by'])\n", |
|
||||||
"\n", |
|
||||||
"def words(item):\n", |
|
||||||
" text = f\"{item.title} {item.details}\"\n", |
|
||||||
" text = re.sub(r'[\\(\\)\\[\\]\\{\\},\\'\"\\- \\s]+', ' ', text)\n", |
|
||||||
" words = text.strip().lower().split(' ')\n", |
|
||||||
" filtered = [word for word in words if word not in stop]\n", |
|
||||||
" return \" \".join(filtered)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "56682e9c-46c9-48f2-baea-e943804290f6", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"from collections import Counter\n", |
|
||||||
"count = Counter()\n", |
|
||||||
"for doc in documents:\n", |
|
||||||
" ws = doc.split(\" \")\n", |
|
||||||
" for w in ws:\n", |
|
||||||
" count[w]+=1\n", |
|
||||||
"count.most_common(30)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "262fc576-7606-426c-8aea-5799b3952d2c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import CountVectorizer\n", |
|
||||||
"from sklearn.linear_model import LinearRegression\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = CountVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = LinearRegression()\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "bd782b21-8e44-409d-a7b6-f136974958b4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def linear_regression_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "80e77aae-0071-42e9-8e24-d3aec5256015", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(linear_regression_predictor, test, \"Linear Regression\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n", |
|
||||||
"from sklearn.svm import SVR\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = TfidfVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = SVR(kernel='linear')\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "64560112-3bfb-45cc-b489-de619a2eca20", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def svr_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "392598d4-2deb-4935-9175-fd111616b13c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(svr_predictor, test, \"SVR Accuracy\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "60010699-d26b-4f93-a959-50272ada6a57", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def messages_for(item):\n", |
|
||||||
" system_message = \"You estimate product prices. Reply only with the price, no explanation\"\n", |
|
||||||
" user_prompt = item.question()\n", |
|
||||||
" return [\n", |
|
||||||
" {\"role\": \"system\", \"content\": system_message},\n", |
|
||||||
" {\"role\": \"user\", \"content\": user_prompt}\n", |
|
||||||
" ]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def get_price(s):\n", |
|
||||||
" s = s.replace('$','').replace(',','')\n", |
|
||||||
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n", |
|
||||||
" return float(match.group()) if match else 0" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def gpt_predictor(item):\n", |
|
||||||
" response = openai.chat.completions.create(\n", |
|
||||||
" model=OPENAI_MODEL,\n", |
|
||||||
" messages=messages_for(item),\n", |
|
||||||
" seed=42,\n", |
|
||||||
" max_tokens=8\n", |
|
||||||
" )\n", |
|
||||||
" reply = response.choices[0].message.content\n", |
|
||||||
" return get_price(reply)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"runner = TestRunner(gpt_predictor, test, \"GPT-4o Prediction Accuracy\", 200).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f7e24d6b-59a2-464a-95a9-14a9fbfadd4d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"report(train[1])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "059b6c74-917f-4cb1-b810-ce70735a57be", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train_prompts = [item.prompt for item in train]\n", |
|
||||||
"train_prices = [item.price for item in train]\n", |
|
||||||
"test_prompts = [item.test_prompt() for item in test]\n", |
|
||||||
"test_prices = [item.price for item in test]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "b8ba48cb-da5e-4ddb-8955-8a94e62ea8e0", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"test_prompts[1]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Create a Dataset from the lists\n", |
|
||||||
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n", |
|
||||||
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n", |
|
||||||
"dataset = DatasetDict({\n", |
|
||||||
" \"train\": train_dataset,\n", |
|
||||||
" \"test\": test_dataset\n", |
|
||||||
"})" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e69e26a5-4b24-4e0f-8944-731c534b285b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"DATASET_NAME = \"ed-donner/multi-instruct\"\n", |
|
||||||
"dataset.push_to_hub(DATASET_NAME, private=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [] |
|
||||||
} |
|
||||||
], |
|
||||||
"metadata": { |
|
||||||
"kernelspec": { |
|
||||||
"display_name": "Python 3 (ipykernel)", |
|
||||||
"language": "python", |
|
||||||
"name": "python3" |
|
||||||
}, |
|
||||||
"language_info": { |
|
||||||
"codemirror_mode": { |
|
||||||
"name": "ipython", |
|
||||||
"version": 3 |
|
||||||
}, |
|
||||||
"file_extension": ".py", |
|
||||||
"mimetype": "text/x-python", |
|
||||||
"name": "python", |
|
||||||
"nbconvert_exporter": "python", |
|
||||||
"pygments_lexer": "ipython3", |
|
||||||
"version": "3.11.10" |
|
||||||
} |
|
||||||
}, |
|
||||||
"nbformat": 4, |
|
||||||
"nbformat_minor": 5 |
|
||||||
} |
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,718 +0,0 @@ |
|||||||
{ |
|
||||||
"cells": [ |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9491dd8f-8124-4a51-be3a-8f678c149dcf", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# imports\n", |
|
||||||
"\n", |
|
||||||
"import os\n", |
|
||||||
"import re\n", |
|
||||||
"import math\n", |
|
||||||
"import random\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"from dotenv import load_dotenv\n", |
|
||||||
"from openai import OpenAI\n", |
|
||||||
"import anthropic\n", |
|
||||||
"from huggingface_hub import login\n", |
|
||||||
"from tqdm import tqdm\n", |
|
||||||
"import matplotlib.pyplot as plt\n", |
|
||||||
"from datasets import load_dataset, Dataset, DatasetDict\n", |
|
||||||
"from transformers import AutoTokenizer" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# environment\n", |
|
||||||
"\n", |
|
||||||
"load_dotenv()\n", |
|
||||||
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", |
|
||||||
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "846ded5d-b7f5-4581-8f56-d9650ff329c1", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# initialize\n", |
|
||||||
"\n", |
|
||||||
"openai = OpenAI()\n", |
|
||||||
"claude = anthropic.Anthropic()\n", |
|
||||||
"OPENAI_MODEL = \"gpt-4o-mini\"\n", |
|
||||||
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n", |
|
||||||
"hf_token = os.environ['HF_TOKEN']\n", |
|
||||||
"login(hf_token, add_to_git_credential=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"%matplotlib inline" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Constants\n", |
|
||||||
"\n", |
|
||||||
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n", |
|
||||||
"\n", |
|
||||||
"# Used for writing to output in color\n", |
|
||||||
"\n", |
|
||||||
"GREEN = \"\\033[92m\"\n", |
|
||||||
"YELLOW = \"\\033[93m\"\n", |
|
||||||
"RED = \"\\033[91m\"\n", |
|
||||||
"RESET = \"\\033[0m\"" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "51af18a2-4122-4753-8f5d-622da2976cb5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Electronics\", split=\"full\", trust_remote_code=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "141ddcdd-bd60-44d4-8c63-1c6717f5bafc", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are {len(dataset):,} items in the dataset\")\n", |
|
||||||
"print(\"Here is the first:\")\n", |
|
||||||
"item = dataset[0]\n", |
|
||||||
"print(item['title'])\n", |
|
||||||
"print(item['description'])\n", |
|
||||||
"print(item['features'])\n", |
|
||||||
"print(item['price'])" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f36c948d-e14d-44a0-9704-c11c589a26ee", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class Item:\n", |
|
||||||
"\n", |
|
||||||
" tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, data):\n", |
|
||||||
" self.title = data['title']\n", |
|
||||||
" self.description = self.clean(data['description'])\n", |
|
||||||
" self.features = self.clean(data['features'])\n", |
|
||||||
" self.price_str = data['price']\n", |
|
||||||
" self.price = float(self.price_str)\n", |
|
||||||
" self._token_count = None\n", |
|
||||||
"\n", |
|
||||||
" def clean(self, details):\n", |
|
||||||
" result = ' '.join(details)\n", |
|
||||||
" return re.sub(r'[\\[\\]【】\\s]+', ' ', result).strip()\n", |
|
||||||
"\n", |
|
||||||
" def question(self):\n", |
|
||||||
" prompt = \"How much does this cost?\\n\"\n", |
|
||||||
" prompt += f\"Title: {self.title}\\n\"\n", |
|
||||||
" prompt += f\"Description: {self.description}\\n\"\n", |
|
||||||
" prompt += f\"Features: {self.features}\\n\"\n", |
|
||||||
" return prompt\n", |
|
||||||
"\n", |
|
||||||
" def inference_prompt(self):\n", |
|
||||||
" return f\"{self.question()}Answer: $\"\n", |
|
||||||
"\n", |
|
||||||
" def train_prompt(self):\n", |
|
||||||
" return f\"{self.inference_prompt()}{self.price_str}\"\n", |
|
||||||
"\n", |
|
||||||
" def token_count(self):\n", |
|
||||||
" if self._token_count == None:\n", |
|
||||||
" self._token_count = len(self.tokenizer.encode(self.train_prompt()))\n", |
|
||||||
" return self._token_count\n", |
|
||||||
"\n", |
|
||||||
" def tokens_between(self, low, high):\n", |
|
||||||
" token_count = self.token_count()\n", |
|
||||||
" return token_count >= low and token_count < high" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "059152d0-a68a-4e93-b759-45f3c6baf31e", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Create a list called \"items\" with all our datapoints that have a valid price\n", |
|
||||||
"\n", |
|
||||||
"items = []\n", |
|
||||||
"for data in tqdm(dataset):\n", |
|
||||||
" try:\n", |
|
||||||
" if float(data['price']) > 0:\n", |
|
||||||
" items.append(Item(data))\n", |
|
||||||
" except ValueError:\n", |
|
||||||
" pass" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "011bffcf-03f8-4f0d-8999-b53d1ac88624", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's investigate:\n", |
|
||||||
"\n", |
|
||||||
"print(f\"There are {len(items):,} out of {len(dataset):,} with prices\\n\")\n", |
|
||||||
"print(f\"Item 0 has {items[0].token_count()} tokens:\\n\")\n", |
|
||||||
"print(items[0].train_prompt())\n", |
|
||||||
"print(f\"\\nItem 1 has {items[1].token_count()} tokens:\\n\")\n", |
|
||||||
"print(items[1].train_prompt())" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "fcf74830-1e97-4543-b454-eefd314fc106", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of character count\n", |
|
||||||
"\n", |
|
||||||
"lengths = [len(item.train_prompt()) for item in items]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Length')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 5000, 250))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "af1d6c8b-f2ae-4691-9306-989b1bd45233", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"print(f\"There are total {len(items):,} items\")\n", |
|
||||||
"cutoff = 1200\n", |
|
||||||
"selection = [item for item in items if len(item.train_prompt()) < cutoff]\n", |
|
||||||
"print(f\"There are total {len(selection):,} with under {cutoff:,} character training prompt\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "42231dc7-66fb-4437-ba08-7689514a8b19", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Calculate token sizes in selection\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in tqdm(selection)]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of tokens\n", |
|
||||||
"\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 500, 25))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1232004a-ff9b-486a-a14b-70f21c217c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Let's limit our dataset to documents with 60-180 tokens\n", |
|
||||||
"\n", |
|
||||||
"low_cutoff = 60\n", |
|
||||||
"high_cutoff = 180\n", |
|
||||||
"subset = [item for item in tqdm(selection) if item.tokens_between(low_cutoff, high_cutoff)]\n", |
|
||||||
"subset_count = len(subset)\n", |
|
||||||
"count = len(items)\n", |
|
||||||
"print(f\"\\nBetween {low_cutoff} and {high_cutoff}, we get {subset_count:,} out of {count:,} which is {subset_count/count*100:.1f}%\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "7bc11e4f-5a15-48fd-b571-92e2e10b0323", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"prices = [float(item.price) for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Price ($)')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 500, 20))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3718a8e6-6c87-4351-8c27-9e61745b0991", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Pick the most expensive 30,000 items, then pick 12,000 of the next 20,000\n", |
|
||||||
"\n", |
|
||||||
"random.seed(42)\n", |
|
||||||
"sorted_subset = sorted(subset, key=lambda item: item.price, reverse=True)\n", |
|
||||||
"top_30k = sorted_subset[:30000]\n", |
|
||||||
"other_12k = random.sample(sorted_subset[30000:50000], k=12000)\n", |
|
||||||
"sample = top_30k + other_12k\n", |
|
||||||
"print(len(sample))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "3cd1c4d3-b6e4-4f28-8ad4-709c4637626c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution of prices\n", |
|
||||||
"\n", |
|
||||||
"prices = [float(item.price) for item in sample]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Price ($)')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 500, 20))\n", |
|
||||||
"\n", |
|
||||||
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"sizes = [len(item.train_prompt()) for item in sample]\n", |
|
||||||
"prices = [item.price for item in sample]\n", |
|
||||||
"\n", |
|
||||||
"# Create the scatter plot\n", |
|
||||||
"plt.figure(figsize=(10, 6))\n", |
|
||||||
"plt.scatter(sizes, prices, s=2, color=\"red\")\n", |
|
||||||
"\n", |
|
||||||
"# Add labels and title\n", |
|
||||||
"plt.xlabel('Size')\n", |
|
||||||
"plt.ylabel('Price')\n", |
|
||||||
"plt.title('Item Price vs Size')\n", |
|
||||||
"\n", |
|
||||||
"# Display the plot\n", |
|
||||||
"plt.show()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Plot the distribution again to check it looks as expected\n", |
|
||||||
"\n", |
|
||||||
"token_counts = [item.token_count() for item in subset]\n", |
|
||||||
"fig, ax = plt.subplots(1, 1)\n", |
|
||||||
"ax.set_xlabel('Number of tokens')\n", |
|
||||||
"ax.set_ylabel('Count of items');\n", |
|
||||||
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "cacb9059-5f44-4601-860a-30860cebe9c2", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"random.seed(42)\n", |
|
||||||
"random.shuffle(sample)\n", |
|
||||||
"train = sample[:40000]\n", |
|
||||||
"test = sample[40000:]\n", |
|
||||||
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"average = sum(t.price for t in train)/len(train)\n", |
|
||||||
"average" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "95353e68-07ac-4f57-8d57-dd48cacb0e04", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"class TestRunner:\n", |
|
||||||
"\n", |
|
||||||
" def __init__(self, predictor, data, size=None):\n", |
|
||||||
" self.predictor = predictor\n", |
|
||||||
" self.data = data\n", |
|
||||||
" self.size = size or len(data)\n", |
|
||||||
" self.guesses = []\n", |
|
||||||
" self.truths = []\n", |
|
||||||
" self.errors = []\n", |
|
||||||
"\n", |
|
||||||
" def run_datapoint(self, i):\n", |
|
||||||
" datapoint = self.data[i]\n", |
|
||||||
" guess = self.predictor(datapoint)\n", |
|
||||||
" truth = datapoint.price\n", |
|
||||||
" error = abs(guess - truth)\n", |
|
||||||
" color = RED if error>=20 else YELLOW if error>=10 else GREEN\n", |
|
||||||
" title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n", |
|
||||||
" self.guesses.append(guess)\n", |
|
||||||
" self.truths.append(truth)\n", |
|
||||||
" self.errors.append(error)\n", |
|
||||||
" print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} Item: {title}{RESET}\")\n", |
|
||||||
"\n", |
|
||||||
" def chart(self):\n", |
|
||||||
" max_error = max(self.errors)\n", |
|
||||||
" colors = [(max_error - error)**3 for error in self.errors]\n", |
|
||||||
" plt.figure(figsize=(10, 6))\n", |
|
||||||
" plt.scatter(self.truths, self.guesses, s=2, c=colors, cmap='RdYlGn')\n", |
|
||||||
" plt.xlabel('Truth')\n", |
|
||||||
" plt.ylabel('Guess')\n", |
|
||||||
" plt.title('Guess vs Truth')\n", |
|
||||||
" plt.show()\n", |
|
||||||
"\n", |
|
||||||
" def run(self):\n", |
|
||||||
" self.error = 0\n", |
|
||||||
" for i in range(self.size):\n", |
|
||||||
" self.run_datapoint(i)\n", |
|
||||||
" average_error = sum(self.errors) / self.size\n", |
|
||||||
" print(f\"Average Error = ${average_error:,.2f}\")\n", |
|
||||||
" hits = [e for e in self.errors if e<10]\n", |
|
||||||
" print(f\"Hit rate = {len(hits)/self.size*100:.1f}%\")\n", |
|
||||||
" self.chart()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def flat_predictor(item):\n", |
|
||||||
" return 218.28366025006002" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "739d2e33-55d4-4892-b42c-771131159c8d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(flat_predictor, test, 100).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"stop = set(['the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in','from', 'you', 'or', 'an'])\n", |
|
||||||
"\n", |
|
||||||
"def words(item):\n", |
|
||||||
" text = f\"{item.title} {item.description} {item.features}\"\n", |
|
||||||
" text = re.sub(r'[()\\[\\]{},\\'\"-]', ' ', text)\n", |
|
||||||
" text = re.sub(r'\\s+', ' ', text)\n", |
|
||||||
" words = text.strip().lower().split(' ')\n", |
|
||||||
" filtered = [word for word in words if word not in stop]\n", |
|
||||||
" return \" \".join(filtered)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "262fc576-7606-426c-8aea-5799b3952d2c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import CountVectorizer\n", |
|
||||||
"from sklearn.linear_model import LinearRegression\n", |
|
||||||
"import numpy as np\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = CountVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = LinearRegression()\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "bd782b21-8e44-409d-a7b6-f136974958b4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def linear_regression_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "80e77aae-0071-42e9-8e24-d3aec5256015", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(linear_regression_predictor, test, 100).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"from sklearn.feature_extraction.text import TfidfVectorizer\n", |
|
||||||
"from sklearn.svm import SVR\n", |
|
||||||
"\n", |
|
||||||
"np.random.seed(42)\n", |
|
||||||
"\n", |
|
||||||
"documents = [words(item) for item in train]\n", |
|
||||||
"labels = np.array([float(item.price) for item in train])\n", |
|
||||||
"\n", |
|
||||||
"vectorizer = TfidfVectorizer()\n", |
|
||||||
"X = vectorizer.fit_transform(documents)\n", |
|
||||||
"\n", |
|
||||||
"regressor = SVR(kernel='linear')\n", |
|
||||||
"regressor.fit(X, labels)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "64560112-3bfb-45cc-b489-de619a2eca20", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def svr_predictor(item):\n", |
|
||||||
" np.random.seed(42)\n", |
|
||||||
" x = vectorizer.transform([words(item)])\n", |
|
||||||
" return max(regressor.predict(x)[0], 0)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "392598d4-2deb-4935-9175-fd111616b13c", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(svr_predictor, test, 100).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "60010699-d26b-4f93-a959-50272ada6a57", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def messages_for(item):\n", |
|
||||||
" system_message = \"You predict prices based on a description. Reply only with the price in $, no explanation or comments\"\n", |
|
||||||
" user_prompt = item.question()\n", |
|
||||||
" return [\n", |
|
||||||
" {\"role\": \"system\", \"content\": system_message},\n", |
|
||||||
" {\"role\": \"user\", \"content\": user_prompt}\n", |
|
||||||
" ]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def get_price(s):\n", |
|
||||||
" s = s.replace('$','').replace(',','')\n", |
|
||||||
" match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n", |
|
||||||
" return float(match.group()) if match else 0" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"def gpt_predictor(item):\n", |
|
||||||
" response = openai.chat.completions.create(\n", |
|
||||||
" model=OPENAI_MODEL,\n", |
|
||||||
" messages=messages_for(item),\n", |
|
||||||
" seed=42\n", |
|
||||||
" )\n", |
|
||||||
" reply = response.choices[0].message.content\n", |
|
||||||
" return get_price(reply)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"TestRunner(gpt_predictor, test, 100).run()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f7e24d6b-59a2-464a-95a9-14a9fbfadd4d", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"test[0].train_prompt()" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "059b6c74-917f-4cb1-b810-ce70735a57be", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"train_prompts = [item.train_prompt() for item in train]\n", |
|
||||||
"train_prices = [item.price for item in train]\n", |
|
||||||
"test_prompts = [item.inference_prompt() for item in test]\n", |
|
||||||
"test_prices = [item.price for item in test]" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"# Create a Dataset from the lists\n", |
|
||||||
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n", |
|
||||||
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n", |
|
||||||
"dataset = DatasetDict({\n", |
|
||||||
" \"train\": train_dataset,\n", |
|
||||||
" \"test\": test_dataset\n", |
|
||||||
"})" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "e69e26a5-4b24-4e0f-8944-731c534b285b", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [ |
|
||||||
"DATASET_NAME = \"ed-donner/electronics\"\n", |
|
||||||
"dataset.push_to_hub(DATASET_NAME, private=True)" |
|
||||||
] |
|
||||||
}, |
|
||||||
{ |
|
||||||
"cell_type": "code", |
|
||||||
"execution_count": null, |
|
||||||
"id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd", |
|
||||||
"metadata": {}, |
|
||||||
"outputs": [], |
|
||||||
"source": [] |
|
||||||
} |
|
||||||
], |
|
||||||
"metadata": { |
|
||||||
"kernelspec": { |
|
||||||
"display_name": "Python 3 (ipykernel)", |
|
||||||
"language": "python", |
|
||||||
"name": "python3" |
|
||||||
}, |
|
||||||
"language_info": { |
|
||||||
"codemirror_mode": { |
|
||||||
"name": "ipython", |
|
||||||
"version": 3 |
|
||||||
}, |
|
||||||
"file_extension": ".py", |
|
||||||
"mimetype": "text/x-python", |
|
||||||
"name": "python", |
|
||||||
"nbconvert_exporter": "python", |
|
||||||
"pygments_lexer": "ipython3", |
|
||||||
"version": "3.11.10" |
|
||||||
} |
|
||||||
}, |
|
||||||
"nbformat": 4, |
|
||||||
"nbformat_minor": 5 |
|
||||||
} |
|
|
@ -1,101 +0,0 @@ |
|||||||
from typing import Optional |
|
||||||
from tqdm import tqdm |
|
||||||
from datasets import load_dataset |
|
||||||
from transformers import AutoTokenizer |
|
||||||
import re |
|
||||||
|
|
||||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" |
|
||||||
MIN_TOKENS = 100 |
|
||||||
MAX_TOKENS = 141 |
|
||||||
|
|
||||||
class Item: |
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
|
||||||
PREFIX = "Price is $" |
|
||||||
|
|
||||||
title: str |
|
||||||
price: float |
|
||||||
category: str |
|
||||||
token_count: int = 0 |
|
||||||
text: Optional[str] |
|
||||||
details: Optional[str] |
|
||||||
prompt: Optional[str] = None |
|
||||||
include = False |
|
||||||
|
|
||||||
def __init__(self, data, price, category): |
|
||||||
self.title = data['title'] |
|
||||||
self.price = price |
|
||||||
self.category = category |
|
||||||
self.parse(data) |
|
||||||
|
|
||||||
def scrub_details(self): |
|
||||||
details = self.details |
|
||||||
removals = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "] |
|
||||||
for remove in removals: |
|
||||||
details = details.replace(remove, "") |
|
||||||
return details |
|
||||||
|
|
||||||
|
|
||||||
def parse(self, data): |
|
||||||
self.text = self.title + '\n' |
|
||||||
self.text += '\n'.join(data['description'])+ '\n' |
|
||||||
self.details = data['details'] |
|
||||||
if self.details: |
|
||||||
self.text += self.scrub_details() + '\n' |
|
||||||
features = '\n'.join(data['features']) |
|
||||||
if features: |
|
||||||
self.text += '\n' + features |
|
||||||
self.text = re.sub(r'[:\[\]"{}【】\s]+', ' ', self.text).strip() |
|
||||||
self.text = self.text.replace(" ,", ",").replace(",,,",",").replace(",,",",") |
|
||||||
tokens = self.tokenizer.encode(self.text, add_special_tokens=False) |
|
||||||
if len(tokens) > MIN_TOKENS: |
|
||||||
tokens = tokens[:MAX_TOKENS] |
|
||||||
self.text = self.tokenizer.decode(tokens) |
|
||||||
self.make_prompt() |
|
||||||
self.count_tokens() |
|
||||||
self.include = True |
|
||||||
|
|
||||||
def question(self): |
|
||||||
prompt = "How much is this?\n" |
|
||||||
prompt += f"{self.text}\n" |
|
||||||
return prompt |
|
||||||
|
|
||||||
def messages(self): |
|
||||||
return [ |
|
||||||
{"role":"system", "content": "You estimate prices to the nearest dollar"}, |
|
||||||
{"role":"user", "content": self.question()}, |
|
||||||
{"role":"assistant", "content": f"{self.PREFIX}{str(round(self.price))}.00"} |
|
||||||
] |
|
||||||
|
|
||||||
def make_prompt(self): |
|
||||||
prompt = self.tokenizer.apply_chat_template(self.messages(), tokenize=False, add_generation_prompt=False) |
|
||||||
groups = prompt.split('\n\n') |
|
||||||
self.prompt = groups[0]+'\n\n'+'\n\n'.join(groups[2:]) |
|
||||||
|
|
||||||
def count_tokens(self): |
|
||||||
self.token_count = len(self.tokenizer.encode(self.prompt)) |
|
||||||
|
|
||||||
def tokens_between(self, low, high): |
|
||||||
return self.token_count >= low and self.token_count < high |
|
||||||
|
|
||||||
def test_prompt(self): |
|
||||||
return self.prompt.split(self.PREFIX)[0] + self.PREFIX |
|
||||||
|
|
||||||
def read_dataset(name): |
|
||||||
print(f"Loading dataset {name}", flush=True) |
|
||||||
dataset = load_dataset("McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{name}", split="full", trust_remote_code=True) |
|
||||||
results = [] |
|
||||||
for data in dataset: |
|
||||||
try: |
|
||||||
price_str = data['price'] |
|
||||||
if price_str: |
|
||||||
price = float(price_str) |
|
||||||
if price >= 0.5 and price <= 999.49: |
|
||||||
item = Item(data, price, name) |
|
||||||
if item.include: |
|
||||||
results.append(item) |
|
||||||
except ValueError: |
|
||||||
pass |
|
||||||
print(f"Completed loading {name} with {len(results):,} datapoints", flush=True) |
|
||||||
del dataset |
|
||||||
return results |
|
@ -1,94 +0,0 @@ |
|||||||
from typing import Optional |
|
||||||
from tqdm import tqdm |
|
||||||
from datasets import load_dataset |
|
||||||
from transformers import AutoTokenizer |
|
||||||
import re |
|
||||||
|
|
||||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B" |
|
||||||
MIN_TOKENS = 150 |
|
||||||
MAX_TOKENS = 160 |
|
||||||
MIN_CHARS = 300 |
|
||||||
CEILING_CHARS = MAX_TOKENS * 7 |
|
||||||
|
|
||||||
class Item: |
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
|
||||||
eos = tokenizer.eos_token |
|
||||||
bos = tokenizer.bos_token |
|
||||||
PREFIX = "Price is $" |
|
||||||
QUESTION = "How much does this cost to the nearest dollar?" |
|
||||||
|
|
||||||
title: str |
|
||||||
price: float |
|
||||||
category: str |
|
||||||
token_count: int = 0 |
|
||||||
text: Optional[str] |
|
||||||
details: Optional[str] |
|
||||||
prompt: Optional[str] = None |
|
||||||
include = False |
|
||||||
|
|
||||||
def __init__(self, data, price, category): |
|
||||||
self.title = data['title'] |
|
||||||
self.price = price |
|
||||||
self.category = category |
|
||||||
self.parse(data) |
|
||||||
|
|
||||||
def scrub_details(self): |
|
||||||
details = self.details |
|
||||||
removals = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "] |
|
||||||
for remove in removals: |
|
||||||
details = details.replace(remove, "") |
|
||||||
return details |
|
||||||
|
|
||||||
def scrub(self, stuff): |
|
||||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip() |
|
||||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",") |
|
||||||
words = stuff.split(' ') |
|
||||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)] |
|
||||||
return " ".join(select) |
|
||||||
|
|
||||||
def parse(self, data): |
|
||||||
contents = '\n'.join(data['description']) |
|
||||||
if contents: |
|
||||||
contents += '\n' |
|
||||||
features = '\n'.join(data['features']) |
|
||||||
if features: |
|
||||||
contents += features + '\n' |
|
||||||
self.details = data['details'] |
|
||||||
if self.details: |
|
||||||
contents += self.scrub_details() + '\n' |
|
||||||
if len(contents) > MIN_CHARS: |
|
||||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents[:CEILING_CHARS])}" |
|
||||||
tokens = self.tokenizer.encode(text, add_special_tokens=False) |
|
||||||
if len(tokens) > MIN_TOKENS: |
|
||||||
tokens = tokens[:MAX_TOKENS] |
|
||||||
text = self.tokenizer.decode(tokens) |
|
||||||
self.make_prompt(text) |
|
||||||
self.include = True |
|
||||||
|
|
||||||
def make_prompt(self, text): |
|
||||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n" |
|
||||||
self.prompt += f"{self.PREFIX}{str(round(self.price))}.00" |
|
||||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False)) |
|
||||||
|
|
||||||
def test_prompt(self): |
|
||||||
return self.prompt.split(self.PREFIX)[0] + self.PREFIX |
|
||||||
|
|
||||||
def read_dataset(name): |
|
||||||
print(f"Loading dataset {name}", flush=True) |
|
||||||
dataset = load_dataset("McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{name}", split="full", trust_remote_code=True) |
|
||||||
results = [] |
|
||||||
for data in dataset: |
|
||||||
try: |
|
||||||
price_str = data['price'] |
|
||||||
if price_str: |
|
||||||
price = float(price_str) |
|
||||||
if price >= 0.5 and price <= 999.49: |
|
||||||
item = Item(data, price, name) |
|
||||||
if item.include: |
|
||||||
results.append(item) |
|
||||||
except ValueError: |
|
||||||
pass |
|
||||||
print(f"Completed loading {name} with {len(results):,} datapoints", flush=True) |
|
||||||
del dataset |
|
||||||
return results |
|
@ -1,133 +0,0 @@ |
|||||||
from typing import Optional |
|
||||||
from datetime import datetime |
|
||||||
from tqdm import tqdm |
|
||||||
from datasets import load_dataset |
|
||||||
from transformers import AutoTokenizer |
|
||||||
import re |
|
||||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor |
|
||||||
|
|
||||||
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B" |
|
||||||
MIN_TOKENS = 150 |
|
||||||
MAX_TOKENS = 160 |
|
||||||
MIN_CHARS = 300 |
|
||||||
CEILING_CHARS = MAX_TOKENS * 7 |
|
||||||
|
|
||||||
class Item: |
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
|
||||||
PREFIX = "Price is $" |
|
||||||
QUESTION = "How much does this cost to the nearest dollar?" |
|
||||||
|
|
||||||
title: str |
|
||||||
price: float |
|
||||||
category: str |
|
||||||
token_count: int = 0 |
|
||||||
details: Optional[str] |
|
||||||
prompt: Optional[str] = None |
|
||||||
include = False |
|
||||||
|
|
||||||
def __init__(self, data, price): |
|
||||||
self.title = data['title'] |
|
||||||
self.price = price |
|
||||||
self.parse(data) |
|
||||||
|
|
||||||
def scrub_details(self): |
|
||||||
details = self.details |
|
||||||
removals = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "] |
|
||||||
for remove in removals: |
|
||||||
details = details.replace(remove, "") |
|
||||||
return details |
|
||||||
|
|
||||||
def scrub(self, stuff): |
|
||||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip() |
|
||||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",") |
|
||||||
words = stuff.split(' ') |
|
||||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)] |
|
||||||
return " ".join(select) |
|
||||||
|
|
||||||
def parse(self, data): |
|
||||||
contents = '\n'.join(data['description']) |
|
||||||
if contents: |
|
||||||
contents += '\n' |
|
||||||
features = '\n'.join(data['features']) |
|
||||||
if features: |
|
||||||
contents += features + '\n' |
|
||||||
self.details = data['details'] |
|
||||||
if self.details: |
|
||||||
contents += self.scrub_details() + '\n' |
|
||||||
if len(contents) > MIN_CHARS: |
|
||||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents[:CEILING_CHARS])}" |
|
||||||
tokens = self.tokenizer.encode(text, add_special_tokens=False) |
|
||||||
if len(tokens) > MIN_TOKENS: |
|
||||||
tokens = tokens[:MAX_TOKENS] |
|
||||||
text = self.tokenizer.decode(tokens) |
|
||||||
self.make_prompt(text) |
|
||||||
self.include = True |
|
||||||
|
|
||||||
def make_prompt(self, text): |
|
||||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n" |
|
||||||
self.prompt += f"{self.PREFIX}{str(round(self.price))}.00" |
|
||||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False)) |
|
||||||
|
|
||||||
def test_prompt(self): |
|
||||||
return self.prompt.split(self.PREFIX)[0] + self.PREFIX |
|
||||||
|
|
||||||
|
|
||||||
class ItemLoader: |
|
||||||
|
|
||||||
def __init__(self, name): |
|
||||||
self.name = name |
|
||||||
self.dataset = None |
|
||||||
|
|
||||||
def from_datapoint(self, datapoint): |
|
||||||
try: |
|
||||||
price_str = datapoint['price'] |
|
||||||
if price_str: |
|
||||||
price = float(price_str) |
|
||||||
if price >= 0.5 and price <= 999.49: |
|
||||||
item = Item(datapoint, price) |
|
||||||
if item.include: |
|
||||||
return item |
|
||||||
except ValueError: |
|
||||||
pass |
|
||||||
return None |
|
||||||
|
|
||||||
def from_chunk(self, chunk): |
|
||||||
batch = [] |
|
||||||
for datapoint in chunk: |
|
||||||
result = self.from_datapoint(datapoint) |
|
||||||
if result: |
|
||||||
batch.append(result) |
|
||||||
return batch |
|
||||||
|
|
||||||
def make_chunks(self): |
|
||||||
print("Preparing data chunks...", end="", flush=True) |
|
||||||
size = len(self.dataset) |
|
||||||
chunks = [] |
|
||||||
for i in range(0, size, 1000): |
|
||||||
chunks.append(self.dataset.select(range(i, min(i + 1000, size)))) |
|
||||||
print(" done.", flush=True) |
|
||||||
return chunks |
|
||||||
|
|
||||||
def load_in_parallel(self, chunks, workers): |
|
||||||
results = [] |
|
||||||
with ProcessPoolExecutor(max_workers=6) as pool: |
|
||||||
for batch in tqdm(pool.map(self.from_chunk, chunks), total=len(chunks)): |
|
||||||
results.extend(batch) |
|
||||||
for result in results: |
|
||||||
result.category = self.name |
|
||||||
return results |
|
||||||
|
|
||||||
def load(self, workers=8): |
|
||||||
start = datetime.now() |
|
||||||
print(f"Loading dataset {self.name}", flush=True) |
|
||||||
self.dataset = load_dataset("McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{self.name}", split="full", trust_remote_code=True) |
|
||||||
chunks = self.make_chunks() |
|
||||||
results = self.load_in_parallel(chunks, workers) |
|
||||||
finish = datetime.now() |
|
||||||
print(f"Completed loading {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins", flush=True) |
|
||||||
return results |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,94 +0,0 @@ |
|||||||
from typing import Optional |
|
||||||
from tqdm import tqdm |
|
||||||
from datasets import load_dataset |
|
||||||
from transformers import AutoTokenizer |
|
||||||
import re |
|
||||||
|
|
||||||
BASE_MODEL = "Qwen/Qwen2-7B" |
|
||||||
MIN_TOKENS = 150 |
|
||||||
MAX_TOKENS = 160 |
|
||||||
MIN_CHARS = 300 |
|
||||||
CEILING_CHARS = MAX_TOKENS * 7 |
|
||||||
|
|
||||||
class Item: |
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) |
|
||||||
eos = tokenizer.eos_token |
|
||||||
bos = tokenizer.bos_token |
|
||||||
PREFIX = "Price is $" |
|
||||||
QUESTION = "How much does this cost to the nearest dollar?" |
|
||||||
|
|
||||||
title: str |
|
||||||
price: float |
|
||||||
category: str |
|
||||||
token_count: int = 0 |
|
||||||
text: Optional[str] |
|
||||||
details: Optional[str] |
|
||||||
prompt: Optional[str] = None |
|
||||||
include = False |
|
||||||
|
|
||||||
def __init__(self, data, price, category): |
|
||||||
self.title = data['title'] |
|
||||||
self.price = price |
|
||||||
self.category = category |
|
||||||
self.parse(data) |
|
||||||
|
|
||||||
def scrub_details(self): |
|
||||||
details = self.details |
|
||||||
removals = ['"Batteries Included?": "No"', '"Batteries Included?": "Yes"', '"Batteries Required?": "No"', '"Batteries Required?": "Yes"', "By Manufacturer", "Item", "Date First", "Package", ":", "Number of", "Best Sellers", "Number", "Product "] |
|
||||||
for remove in removals: |
|
||||||
details = details.replace(remove, "") |
|
||||||
return details |
|
||||||
|
|
||||||
def scrub(self, stuff): |
|
||||||
stuff = re.sub(r'[:\[\]"{}【】\s]+', ' ', stuff).strip() |
|
||||||
stuff = stuff.replace(" ,", ",").replace(",,,",",").replace(",,",",") |
|
||||||
words = stuff.split(' ') |
|
||||||
select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)] |
|
||||||
return " ".join(select) |
|
||||||
|
|
||||||
def parse(self, data): |
|
||||||
contents = '\n'.join(data['description']) |
|
||||||
if contents: |
|
||||||
contents += '\n' |
|
||||||
features = '\n'.join(data['features']) |
|
||||||
if features: |
|
||||||
contents += features + '\n' |
|
||||||
self.details = data['details'] |
|
||||||
if self.details: |
|
||||||
contents += self.scrub_details() + '\n' |
|
||||||
if len(contents) > MIN_CHARS: |
|
||||||
text = f"{self.scrub(self.title)}\n{self.scrub(contents[:CEILING_CHARS])}" |
|
||||||
tokens = self.tokenizer.encode(text, add_special_tokens=False) |
|
||||||
if len(tokens) > MIN_TOKENS: |
|
||||||
tokens = tokens[:MAX_TOKENS] |
|
||||||
text = self.tokenizer.decode(tokens) |
|
||||||
self.make_prompt(text) |
|
||||||
self.include = True |
|
||||||
|
|
||||||
def make_prompt(self, text): |
|
||||||
self.prompt = f"{self.QUESTION}\n\n{text}\n\n" |
|
||||||
self.prompt += f"{self.PREFIX}{str(round(self.price))}.00" |
|
||||||
self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False)) |
|
||||||
|
|
||||||
def test_prompt(self): |
|
||||||
return self.prompt.split(self.PREFIX)[0] + self.PREFIX |
|
||||||
|
|
||||||
def read_dataset(name): |
|
||||||
print(f"Loading dataset {name}", flush=True) |
|
||||||
dataset = load_dataset("McAuley-Lab/Amazon-Reviews-2023", f"raw_meta_{name}", split="full", trust_remote_code=True) |
|
||||||
results = [] |
|
||||||
for data in dataset: |
|
||||||
try: |
|
||||||
price_str = data['price'] |
|
||||||
if price_str: |
|
||||||
price = float(price_str) |
|
||||||
if price >= 0.5 and price <= 999.49: |
|
||||||
item = Item(data, price, name) |
|
||||||
if item.include: |
|
||||||
results.append(item) |
|
||||||
except ValueError: |
|
||||||
pass |
|
||||||
print(f"Completed loading {name} with {len(results):,} datapoints", flush=True) |
|
||||||
del dataset |
|
||||||
return results |
|
|
@ -0,0 +1,39 @@ |
|||||||
|
{ |
||||||
|
"cells": [ |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "c25e5705-7078-4d4a-9fa2-8aaa528ffced", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"# Week 7 Day 1\n", |
||||||
|
"\n", |
||||||
|
"Fine-tune an open-source model to Predict Product Prices\n", |
||||||
|
"\n", |
||||||
|
"Please see this notebook in Google Colab:\n", |
||||||
|
"\n", |
||||||
|
"https://colab.research.google.com/drive/15rqdMTJwK76icPBxNoqhI7Ww8UM-Y7ni?usp=sharing" |
||||||
|
] |
||||||
|
} |
||||||
|
], |
||||||
|
"metadata": { |
||||||
|
"kernelspec": { |
||||||
|
"display_name": "Python 3 (ipykernel)", |
||||||
|
"language": "python", |
||||||
|
"name": "python3" |
||||||
|
}, |
||||||
|
"language_info": { |
||||||
|
"codemirror_mode": { |
||||||
|
"name": "ipython", |
||||||
|
"version": 3 |
||||||
|
}, |
||||||
|
"file_extension": ".py", |
||||||
|
"mimetype": "text/x-python", |
||||||
|
"name": "python", |
||||||
|
"nbconvert_exporter": "python", |
||||||
|
"pygments_lexer": "ipython3", |
||||||
|
"version": "3.11.10" |
||||||
|
} |
||||||
|
}, |
||||||
|
"nbformat": 4, |
||||||
|
"nbformat_minor": 5 |
||||||
|
} |
@ -0,0 +1,39 @@ |
|||||||
|
{ |
||||||
|
"cells": [ |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "c25e5705-7078-4d4a-9fa2-8aaa528ffced", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"# Week 7 Day 2\n", |
||||||
|
"\n", |
||||||
|
"Fine-tune an open-source model to Predict Product Prices\n", |
||||||
|
"\n", |
||||||
|
"Please see this notebook in Google Colab:\n", |
||||||
|
"\n", |
||||||
|
"https://colab.research.google.com/drive/1T72pbfZw32fq-clQEp-p8YQ4_qFKv4TP?usp=sharing" |
||||||
|
] |
||||||
|
} |
||||||
|
], |
||||||
|
"metadata": { |
||||||
|
"kernelspec": { |
||||||
|
"display_name": "Python 3 (ipykernel)", |
||||||
|
"language": "python", |
||||||
|
"name": "python3" |
||||||
|
}, |
||||||
|
"language_info": { |
||||||
|
"codemirror_mode": { |
||||||
|
"name": "ipython", |
||||||
|
"version": 3 |
||||||
|
}, |
||||||
|
"file_extension": ".py", |
||||||
|
"mimetype": "text/x-python", |
||||||
|
"name": "python", |
||||||
|
"nbconvert_exporter": "python", |
||||||
|
"pygments_lexer": "ipython3", |
||||||
|
"version": "3.11.10" |
||||||
|
} |
||||||
|
}, |
||||||
|
"nbformat": 4, |
||||||
|
"nbformat_minor": 5 |
||||||
|
} |
@ -0,0 +1,39 @@ |
|||||||
|
{ |
||||||
|
"cells": [ |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "c25e5705-7078-4d4a-9fa2-8aaa528ffced", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"# Week 7 Days 3 and 4\n", |
||||||
|
"\n", |
||||||
|
"Fine-tune an open-source model to Predict Product Prices\n", |
||||||
|
"\n", |
||||||
|
"Please see this notebook in Google Colab:\n", |
||||||
|
"\n", |
||||||
|
"https://colab.research.google.com/drive/1csEdaECRtjV_1p9zMkaKKjCpYnltlN3M?usp=sharing" |
||||||
|
] |
||||||
|
} |
||||||
|
], |
||||||
|
"metadata": { |
||||||
|
"kernelspec": { |
||||||
|
"display_name": "Python 3 (ipykernel)", |
||||||
|
"language": "python", |
||||||
|
"name": "python3" |
||||||
|
}, |
||||||
|
"language_info": { |
||||||
|
"codemirror_mode": { |
||||||
|
"name": "ipython", |
||||||
|
"version": 3 |
||||||
|
}, |
||||||
|
"file_extension": ".py", |
||||||
|
"mimetype": "text/x-python", |
||||||
|
"name": "python", |
||||||
|
"nbconvert_exporter": "python", |
||||||
|
"pygments_lexer": "ipython3", |
||||||
|
"version": "3.11.10" |
||||||
|
} |
||||||
|
}, |
||||||
|
"nbformat": 4, |
||||||
|
"nbformat_minor": 5 |
||||||
|
} |
@ -0,0 +1,39 @@ |
|||||||
|
{ |
||||||
|
"cells": [ |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "c25e5705-7078-4d4a-9fa2-8aaa528ffced", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"# Week 7 Days 5\n", |
||||||
|
"\n", |
||||||
|
"Fine-tune an open-source model to Predict Product Prices\n", |
||||||
|
"\n", |
||||||
|
"Please see this notebook in Google Colab:\n", |
||||||
|
"\n", |
||||||
|
"https://colab.research.google.com/drive/1igA0HF0gvQqbdBD4GkcK3GpHtuDLijYn?usp=sharing" |
||||||
|
] |
||||||
|
} |
||||||
|
], |
||||||
|
"metadata": { |
||||||
|
"kernelspec": { |
||||||
|
"display_name": "Python 3 (ipykernel)", |
||||||
|
"language": "python", |
||||||
|
"name": "python3" |
||||||
|
}, |
||||||
|
"language_info": { |
||||||
|
"codemirror_mode": { |
||||||
|
"name": "ipython", |
||||||
|
"version": 3 |
||||||
|
}, |
||||||
|
"file_extension": ".py", |
||||||
|
"mimetype": "text/x-python", |
||||||
|
"name": "python", |
||||||
|
"nbconvert_exporter": "python", |
||||||
|
"pygments_lexer": "ipython3", |
||||||
|
"version": "3.11.10" |
||||||
|
} |
||||||
|
}, |
||||||
|
"nbformat": 4, |
||||||
|
"nbformat_minor": 5 |
||||||
|
} |
Loading…
Reference in new issue