You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

942 lines
27 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9491dd8f-8124-4a51-be3a-8f678c149dcf",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"import random\n",
"import numpy as np\n",
"from typing import Optional\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import anthropic\n",
"from huggingface_hub import login\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from transformers import AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"# Load variables from a .env file, then make sure each key exists in os.environ.\n",
"# When a variable is not set, the placeholder string 'your-key-if-not-using-env' is stored instead.\n",
"load_dotenv()\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "535addd2-9590-42dd-81d8-9dbe06e0194a",
"metadata": {},
"outputs": [],
"source": [
"# constants\n",
"\n",
"# Token-count window for a usable prompt: lower bound inclusive, upper bound exclusive (see Item.tokens_between)\n",
"MIN_TOKENS = 80\n",
"MAX_TOKENS = 180\n",
"# Character-length pre-filter applied before token counting: 7 characters per allowed token\n",
"CUTOFF_CHARS = MAX_TOKENS * 7"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "846ded5d-b7f5-4581-8f56-d9650ff329c1",
"metadata": {},
"outputs": [],
"source": [
"# initialize the API clients, name the models, and log in to the Hugging Face Hub\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"OPENAI_MODEL = \"gpt-4o-mini\"\n",
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n",
"# HF_TOKEN was placed in os.environ by the environment cell\n",
"hf_token = os.environ['HF_TOKEN']\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9",
"metadata": {},
"outputs": [],
"source": [
"# Show matplotlib figures inline in the cell output\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"# Hugging Face model id; the Item class loads its tokenizer from this model\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"\n",
"# ANSI escape sequences used for writing to output in color (RESET restores the default)\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb2ed609-a00a-4ff8-9f4d-8f2ff8ea26dd",
"metadata": {},
"outputs": [],
"source": [
"# Category names; read_dataset appends each one to 'raw_meta_' to form the dataset config name\n",
"datasets = [\"Electronics\", \"Appliances\", \"Cell_Phones_and_Accessories\", \"Home_and_Kitchen\", \"Tools_and_Home_Improvement\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51af18a2-4122-4753-8f5d-622da2976cb5",
"metadata": {},
"outputs": [],
"source": [
"# Load the Electronics metadata on its own so the raw records can be inspected in the next cell\n",
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Electronics\", split=\"full\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "141ddcdd-bd60-44d4-8c63-1c6717f5bafc",
"metadata": {},
"outputs": [],
"source": [
"# Show the dataset size and the four fields of the first record that the rest of the notebook uses\n",
"print(f\"There are {len(dataset):,} items in the dataset\")\n",
"print(\"Here is the first:\")\n",
"item = dataset[0]\n",
"print(item['title'])\n",
"print(item['description'])\n",
"print(item['features'])\n",
"print(item['price'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f36c948d-e14d-44a0-9704-c11c589a26ee",
"metadata": {},
"outputs": [],
"source": [
"class Item:\n",
"    \"\"\"\n",
"    A priced product plus the chat-formatted prompt built from its title and details.\n",
"\n",
"    Class attributes:\n",
"        tokenizer: loaded from BASE_MODEL when the class is defined and shared by all instances\n",
"        PREFIX: the text that precedes the price in the assistant reply\n",
"    \"\"\"\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"    PREFIX = \"Price is $\"\n",
"\n",
"    title: str\n",
"    price: float\n",
"    token_count: int = 0  # remains 0 until count_tokens() is called\n",
"    details: Optional[str]\n",
"    prompt: Optional[str]\n",
"\n",
"    def __init__(self, data, price):\n",
"        \"\"\"\n",
"        Build an item from one dataset record.\n",
"\n",
"        data: mapping that provides 'title', 'description' and 'features'\n",
"        price: the price to attach to the item (already parsed by the caller)\n",
"        \"\"\"\n",
"        self.title = data['title']\n",
"        self.price = price\n",
"        self.create_details(data)\n",
"\n",
"    def create_details(self, data):\n",
"        \"\"\"Join description and features into one cleaned string, then build the prompt.\"\"\"\n",
"        self.details = '\\n'.join(data['description'])\n",
"        features = '\\n'.join(data['features'])\n",
"        if features:\n",
"            self.details += '\\n' + features\n",
"        # collapse every run of square brackets, lenticular brackets and whitespace into one space\n",
"        self.details = re.sub(r'[\\[\\]【】\\s]+', ' ', self.details).strip()\n",
"        self.make_prompt()\n",
"\n",
"    def question(self):\n",
"        \"\"\"Return the user-turn text: a fixed question followed by the title and details.\"\"\"\n",
"        prompt = \"How much does this cost?\\n\"\n",
"        prompt += f\"Title: {self.title}\\n\"\n",
"        prompt += f\"Details: {self.details}\\n\"\n",
"        return prompt\n",
"\n",
"    def messages(self):\n",
"        \"\"\"Return the system/user/assistant messages; the assistant turn holds the price rounded to whole dollars.\"\"\"\n",
"        return [\n",
"            {\"role\":\"system\", \"content\": \"You estimate product prices. Reply only with the price to the nearest dollar\"},\n",
"            {\"role\":\"user\", \"content\": self.question()},\n",
"            {\"role\":\"assistant\", \"content\": f\"{self.PREFIX}{str(round(self.price))}.00\"}\n",
"        ]\n",
"\n",
"    def make_prompt(self):\n",
"        \"\"\"Render messages() through the tokenizer's chat template and store the text in self.prompt.\"\"\"\n",
"        prompt = self.tokenizer.apply_chat_template(self.messages(), tokenize=False, add_generation_prompt=False)\n",
"        # keep the first blank-line-separated section, drop the second, keep the rest\n",
"        # NOTE(review): presumably the dropped section is boilerplate inserted by the template - confirm\n",
"        groups = prompt.split('\\n\\n')\n",
"        self.prompt = groups[0]+'\\n\\n'+'\\n\\n'.join(groups[2:])\n",
"\n",
"    def count_tokens(self):\n",
"        \"\"\"Tokenize the prompt and store the number of tokens in self.token_count.\"\"\"\n",
"        self.token_count = len(self.tokenizer.encode(self.prompt))\n",
"\n",
"    def tokens_between(self, low, high):\n",
"        \"\"\"Return True when low <= token_count < high.\"\"\"\n",
"        return self.token_count >= low and self.token_count < high\n",
"\n",
"    def test_prompt(self):\n",
"        \"\"\"Return the prompt up to and including the final PREFIX, i.e. with the price removed.\"\"\"\n",
"        # Split on the LAST occurrence: the answer is appended at the end of the prompt, and the\n",
"        # title or details may themselves contain the PREFIX wording, which would otherwise\n",
"        # cut the prompt off in the middle of the product text.\n",
"        return self.prompt.rsplit(self.PREFIX, 1)[0] + self.PREFIX"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20d97009-6b35-4fdf-baae-59dbd1bf6f77",
"metadata": {},
"outputs": [],
"source": [
"def read_dataset(name):\n",
"    \"\"\"\n",
"    Load the 'raw_meta_<name>' split of the Amazon reviews dataset and wrap each priced record in an Item.\n",
"\n",
"    Records whose price is empty, cannot be parsed as a float, or is not greater than 0 are left out.\n",
"    Returns the list of Item objects.\n",
"    \"\"\"\n",
"    print(f\"Loading dataset {name}\")\n",
"    raw_records = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", f\"raw_meta_{name}\", split=\"full\", trust_remote_code=True)\n",
"    priced_items = []\n",
"    for record in tqdm(raw_records):\n",
"        try:\n",
"            raw_price = record['price']\n",
"            if not raw_price:\n",
"                continue\n",
"            value = float(raw_price)\n",
"            if value > 0:\n",
"                priced_items.append(Item(record, value))\n",
"        except ValueError:\n",
"            # unparseable price text: skip the record\n",
"            continue\n",
"    print(f\"Completed loading {name} with {len(priced_items):,} datapoints\")\n",
"    return priced_items"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd11853b-9e21-4b14-9a08-9d9f63636e1a",
"metadata": {},
"outputs": [],
"source": [
"# Load every category and pool the priced items into a single list\n",
"items = []\n",
"for dataset_name in datasets:\n",
"    # a dedicated loop name keeps the `dataset` object loaded earlier from being overwritten by a string\n",
"    items.extend(read_dataset(dataset_name))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "011bffcf-03f8-4f0d-8999-b53d1ac88624",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate: the total number of priced items and the full prompts of the first two\n",
"\n",
"print(f\"There are {len(items):,} items with prices\\n\")\n",
"print(items[0].prompt)\n",
"print(items[1].prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcf74830-1e97-4543-b454-eefd314fc106",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of character count\n",
"# (histogram bins are 250 characters wide and span 0 to 5,000)\n",
"\n",
"lengths = [len(item.prompt) for item in items]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Length')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 5000, 250))\n",
"\n",
"print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af1d6c8b-f2ae-4691-9306-989b1bd45233",
"metadata": {},
"outputs": [],
"source": [
"# Pre-filter on prompt length in characters (CUTOFF_CHARS) before counting tokens in the next cell\n",
"print(f\"There are total {len(items):,} items\")\n",
"selection = [item for item in items if len(item.prompt) < CUTOFF_CHARS]\n",
"print(f\"There are total {len(selection):,} with under {CUTOFF_CHARS:,} character training prompt\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42231dc7-66fb-4437-ba08-7689514a8b19",
"metadata": {},
"outputs": [],
"source": [
"# Calculate token sizes in selection\n",
"# (count_tokens stores the result on each item as token_count)\n",
"\n",
"for item in tqdm(selection):\n",
" item.count_tokens()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of tokens\n",
"# (histogram bins are 25 tokens wide and span 0 to 500)\n",
"\n",
"token_counts = [item.token_count for item in tqdm(selection)]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 500, 25))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da0a20b4-8926-4eff-bf83-11c4f6b40455",
"metadata": {},
"outputs": [],
"source": [
"def report(item):\n",
"    \"\"\"Print an item's prompt, then the ids and the decoded text of its last 10 tokens.\"\"\"\n",
"    text = item.prompt\n",
"    token_ids = Item.tokenizer.encode(text)\n",
"    tail = token_ids[-10:]\n",
"    print(text)\n",
"    print(tail)\n",
"    print(Item.tokenizer.batch_decode(tail))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2378cb92-305a-49d0-8193-4ae09a0cccf8",
"metadata": {},
"outputs": [],
"source": [
"# Show the prompt and trailing tokens of the first loaded item\n",
"report(items[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1232004a-ff9b-486a-a14b-70f21c217c8d",
"metadata": {},
"outputs": [],
"source": [
"# Let's limit our dataset to documents with at least MIN_TOKENS and fewer than MAX_TOKENS tokens\n",
"# (the percentage printed below is relative to all items, not to the length-filtered selection)\n",
"\n",
"subset = [item for item in tqdm(selection) if item.tokens_between(MIN_TOKENS, MAX_TOKENS)]\n",
"subset_count = len(subset)\n",
"count = len(items)\n",
"print(f\"\\nBetween {MIN_TOKENS} and {MAX_TOKENS}, we get {subset_count:,} out of {count:,} which is {subset_count/count*100:.1f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bc11e4f-5a15-48fd-b571-92e2e10b0323",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution again to check it looks as expected\n",
"# (token counts of the filtered subset; bins are 10 tokens wide and span 0 to 300)\n",
"\n",
"token_counts = [item.token_count for item in subset]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"# (bins are $20 wide and span $0 to $500; the printed maximum is computed from all prices)\n",
"\n",
"prices = [float(item.price) for item in subset]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Price ($)')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 500, 20))\n",
"\n",
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3718a8e6-6c87-4351-8c27-9e61745b0991",
"metadata": {},
"outputs": [],
"source": [
"# Earlier sampling experiment, disabled and kept for reference: take the 90,000 most expensive items priced at $999 or less, then 15,000 chosen at random from the next 40,000\n",
"\n",
"# random.seed(42)\n",
"# subset2 = [item for item in subset if item.price <= 999]\n",
"# sorted_subset2 = sorted(subset2, key=lambda item: item.price, reverse=True)\n",
"# sample = sorted_subset2[:90_000]\n",
"# other_12k = random.sample(sorted_subset2[90_000:130_000], k=15000)\n",
"# sample += other_12k\n",
"# print(f\"Created a sample of {len(sample):,} with prices ranging from ${sample[-1].price} to ${sample[0].price}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f50917db-ab22-4ecd-a7f1-a2cd45ceb7e6",
"metadata": {},
"outputs": [],
"source": [
"# Build the sample from items priced at $999 or less, sorted from most to least expensive:\n",
"# all of the first 150,000, then 50,000 drawn at random from the next 150,000,\n",
"# then 5,000 drawn at random from whatever remains.\n",
"random.seed(42)\n",
"subset = [item for item in subset if item.price <= 999]\n",
"sorted_subset = sorted(subset, key=lambda item: item.price, reverse=True)\n",
"sample = sorted_subset[:150_000]\n",
"sample += random.sample(sorted_subset[150_000:300_000], k=50000)\n",
"sample += random.sample(sorted_subset[300_000:], k=5000)\n",
"# The randomly drawn tail is unordered, so sample[-1] is not necessarily the cheapest item;\n",
"# compute the price range explicitly.\n",
"print(f\"Created a sample of {len(sample):,} with prices ranging from ${min(item.price for item in sample)} to ${max(item.price for item in sample)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cd1c4d3-b6e4-4f28-8ad4-709c4637626c",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"# (prices of the sample; bins are $20 wide and span $0 to $1,000, and the title shows the mean price)\n",
"\n",
"plt.figure(figsize=(10, 6))\n",
"prices = [float(item.price) for item in sample]\n",
"plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 1000, 20))\n",
"\n",
"plt.title(f\"Avg price ${sum(prices)/len(prices):.2f}\")\n",
"plt.xlabel('Price ($)')\n",
"plt.ylabel('Count of items')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18",
"metadata": {},
"outputs": [],
"source": [
"# Compare prompt length in characters with price for every item in the sample\n",
"sizes = [len(item.prompt) for item in sample]\n",
"prices = [item.price for item in sample]\n",
"\n",
"# Create the scatter plot\n",
"plt.figure(figsize=(10, 6))\n",
"plt.scatter(sizes, prices, s=1, color=\"red\")\n",
"\n",
"# Add labels and title\n",
"plt.xlabel('Size')\n",
"plt.ylabel('Price')\n",
"plt.title('Is there a simple correlation?')\n",
"\n",
"# Display the plot\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution again to check it looks as expected\n",
"# (token counts of the sample; bins are 10 tokens wide and span 0 to 300)\n",
"\n",
"token_counts = [item.token_count for item in sample]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59ef7aef-b6f6-4042-a2af-ddd5ae1c9999",
"metadata": {},
"outputs": [],
"source": [
"# Show the prompt and trailing tokens of the last item in the sample\n",
"report(sample[-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cacb9059-5f44-4601-860a-30860cebe9c2",
"metadata": {},
"outputs": [],
"source": [
"# Shuffle a copy rather than `sample` itself so that re-running this cell gives the same split;\n",
"# an in-place shuffle would reshuffle the already shuffled list on a second run.\n",
"random.seed(42)\n",
"shuffled = list(sample)\n",
"random.shuffle(shuffled)\n",
"# the first 200,000 items form the training set, everything after that the test set\n",
"train = shuffled[:200_000]\n",
"test = shuffled[200_000:]\n",
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf435bcd-accf-427c-82d5-02b33a56737c",
"metadata": {},
"outputs": [],
"source": [
"# Remove the names of the intermediate lists; train and test are kept\n",
"del items, subset, sorted_subset, selection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b26000a-e5a9-4ab7-83fc-8eb44cb12f94",
"metadata": {},
"outputs": [],
"source": [
"# Peek at the title of the first test item\n",
"test[0].title"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3615bfdd-f23e-4005-96d8-7b52a1a439be",
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"\n",
"# Write title and details of the first 200 test items to test.csv, with 0 as a placeholder third column.\n",
"# newline='' is how the csv module expects the file to be opened (it prevents extra blank rows on\n",
"# Windows), and an explicit utf-8 encoding keeps non-ASCII product text writable on every platform.\n",
"with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:\n",
"    writer = csv.writer(csvfile)\n",
"    for t in test[:200]:\n",
"        writer.writerow([t.title, t.details, 0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb6907e9-37d7-4283-b1a9-8124f9f3439b",
"metadata": {},
"outputs": [],
"source": [
"# Read the human price guesses from the third column of human.csv\n",
"human_predictions = []\n",
"# newline='' is how the csv module expects the file to be opened\n",
"with open('human.csv', 'r', newline='') as csvfile:\n",
"    reader = csv.reader(csvfile)\n",
"    for row in reader:\n",
"        # csv.reader yields an empty list for a blank line; skip it instead of failing on row[2]\n",
"        if row:\n",
"            human_predictions.append(float(row[2]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f",
"metadata": {},
"outputs": [],
"source": [
"# Mean price over the training set, displayed as the cell output\n",
"average = sum(t.price for t in train)/len(train)\n",
"average"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95353e68-07ac-4f57-8d57-dd48cacb0e04",
"metadata": {},
"outputs": [],
"source": [
"class TestRunner:\n",
"    \"\"\"\n",
"    Run a price predictor over the first `size` datapoints, print one line per datapoint and chart the result.\n",
"\n",
"    predictor: callable that takes a datapoint and returns a numeric price guess\n",
"    data: indexable, sized collection of datapoints exposing .price and .title\n",
"    title: text placed at the start of the chart title\n",
"    size: number of datapoints to evaluate (capped at len(data))\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, predictor, data, title, size=200):\n",
"        self.predictor = predictor\n",
"        self.data = data\n",
"        self.title = title\n",
"        # never ask for more datapoints than exist, otherwise run() fails with IndexError\n",
"        self.size = min(size, len(data))\n",
"        self.guesses = []\n",
"        self.truths = []\n",
"        self.errors = []\n",
"        self.sles = []\n",
"        self.colors = []\n",
"\n",
"    def run_datapoint(self, i):\n",
"        \"\"\"Score datapoint i, record guess/truth/error/squared-log-error/colour and print a coloured line.\"\"\"\n",
"        datapoint = self.data[i]\n",
"        guess = self.predictor(datapoint)\n",
"        truth = datapoint.price\n",
"        error = abs(guess - truth)\n",
"        # a guess of -1 or lower would hand math.log a non-positive value and raise ValueError,\n",
"        # so negative guesses are floored at 0 for the log term only\n",
"        log_error = math.log(truth+1) - math.log(max(guess, 0)+1)\n",
"        sle = log_error ** 2\n",
"        # red: off by $20 or more, yellow: off by $10 or more, green: closer than $10\n",
"        if error >= 20:\n",
"            color, color_str = RED, \"red\"\n",
"        elif error >= 10:\n",
"            color, color_str = YELLOW, \"yellow\"\n",
"        else:\n",
"            color, color_str = GREEN, \"green\"\n",
"        title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n",
"        self.guesses.append(guess)\n",
"        self.truths.append(truth)\n",
"        self.errors.append(error)\n",
"        self.sles.append(sle)\n",
"        self.colors.append(color_str)\n",
"        print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
"    def chart(self, title):\n",
"        \"\"\"Scatter-plot the guesses against the true prices, coloured per datapoint, under the given title.\"\"\"\n",
"        plt.figure(figsize=(12, 8))\n",
"        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
"        plt.xlabel('Ground Truth')\n",
"        plt.ylabel('Model Estimate')\n",
"        plt.title(title)\n",
"        plt.show()\n",
"\n",
"    def report(self):\n",
"        \"\"\"Compute mean absolute error, RMSLE and the share of errors under $10, then draw the chart.\"\"\"\n",
"        # average over the results actually recorded so the figures stay correct\n",
"        # even when the runner has been run more than once\n",
"        evaluated = len(self.errors)\n",
"        if evaluated == 0:\n",
"            print(f\"{self.title}: no datapoints were evaluated\")\n",
"            return\n",
"        average_error = sum(self.errors) / evaluated\n",
"        rmsle = math.sqrt(sum(self.sles) / evaluated)\n",
"        hits = [e for e in self.errors if e<10]\n",
"        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={len(hits)/evaluated*100:.1f}%\"\n",
"        self.chart(title)\n",
"\n",
"    def run(self):\n",
"        \"\"\"Evaluate the first `size` datapoints, report, and return this runner.\"\"\"\n",
"        self.error = 0  # kept for backward compatibility; nothing in this class reads it\n",
"        for i in range(self.size):\n",
"            self.run_datapoint(i)\n",
"        self.report()\n",
"        return self"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b",
"metadata": {},
"outputs": [],
"source": [
"# Baseline: always guess the mean price of the training set\n",
"train_average = sum(t.price for t in train)/len(train)\n",
"\n",
"def flat_predictor(item):\n",
" # the item is ignored; every prediction is the same constant\n",
" return train_average"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "739d2e33-55d4-4892-b42c-771131159c8d",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the constant baseline with the TestRunner default size\n",
"runner = TestRunner(flat_predictor, test, \"Flat Predictor\").run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d3a5d83-d90b-40af-979f-85aa21816578",
"metadata": {},
"outputs": [],
"source": [
"# Read the human price guesses from the third column of human.csv\n",
"human_predictions = []\n",
"# newline='' is how the csv module expects the file to be opened\n",
"with open('human.csv', 'r', newline='') as csvfile:\n",
"    reader = csv.reader(csvfile)\n",
"    for row in reader:\n",
"        # csv.reader yields an empty list for a blank line; skip it instead of failing on row[2]\n",
"        if row:\n",
"            human_predictions.append(float(row[2]))"
"\n",
"def human_predictor(item):\n",
"    \"\"\"\n",
"    Return the human guess for `item`, matched by the item's position in `test`.\n",
"\n",
"    Raises ValueError when the item is not in `test`, and IndexError when no guess\n",
"    was recorded for that position.\n",
"    \"\"\"\n",
"    # list.index never returns -1; it raises ValueError when the item is absent\n",
"    try:\n",
"        index = test.index(item)\n",
"    except ValueError:\n",
"        raise ValueError(\"Index not found\") from None\n",
"    if index >= len(human_predictions):\n",
"        raise IndexError(f\"No human prediction recorded for test item {index}\")\n",
"    return human_predictions[index]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c87c385-d6b7-4a4f-89eb-f4e250337d03",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the human guesses with the TestRunner default size\n",
"runner = TestRunner(human_predictor, test, \"Human Predictor\").run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685",
"metadata": {},
"outputs": [],
"source": [
"# Words removed from every document before counting or vectorizing\n",
"stop = {'the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in', 'from', 'you', 'or', 'an', 'on', 'by'}\n",
"\n",
"def words(item):\n",
"    \"\"\"\n",
"    Return the item's title and details as one lower-cased, space-separated string.\n",
"\n",
"    Runs of brackets, braces, parentheses, commas, quotes, hyphens and whitespace become a\n",
"    single space, and words found in `stop` are dropped.\n",
"    \"\"\"\n",
"    combined = f\"{item.title} {item.details}\"\n",
"    cleaned = re.sub(r'[\\(\\)\\[\\]\\{\\},\\'\"\\- \\s]+', ' ', combined).strip().lower()\n",
"    return \" \".join(token for token in cleaned.split(' ') if token not in stop)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56682e9c-46c9-48f2-baea-e943804290f6",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Tally how often each word occurs across all training documents and show the 30 most frequent\n",
"documents = [words(item) for item in train]\n",
"count = Counter(word for doc in documents for word in doc.split(\" \"))\n",
"count.most_common(30)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "262fc576-7606-426c-8aea-5799b3952d2c",
"metadata": {},
"outputs": [],
"source": [
"# Bag-of-words counts of the training documents, fitted against price with linear regression.\n",
"# NOTE: `vectorizer` and `regressor` are rebound by the later SVR cell, and\n",
"# linear_regression_predictor reads them at call time.\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.linear_model import LinearRegression\n",
"import numpy as np\n",
"\n",
"np.random.seed(42)\n",
"\n",
"labels = np.array([float(item.price) for item in train])\n",
"\n",
"vectorizer = CountVectorizer()\n",
"X = vectorizer.fit_transform(documents)\n",
"\n",
"regressor = LinearRegression()\n",
"regressor.fit(X, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd782b21-8e44-409d-a7b6-f136974958b4",
"metadata": {},
"outputs": [],
"source": [
"def linear_regression_predictor(item):\n",
"    \"\"\"\n",
"    Predict a price for `item` with the global `vectorizer` and `regressor`, floored at 0.\n",
"\n",
"    Both globals are looked up when the function is called, so it uses whichever\n",
"    objects those names are bound to at that moment.\n",
"    \"\"\"\n",
"    np.random.seed(42)\n",
"    features = vectorizer.transform([words(item)])\n",
"    estimate = regressor.predict(features)[0]\n",
"    return max(estimate, 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80e77aae-0071-42e9-8e24-d3aec5256015",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the linear regression predictor on the first 200 test items\n",
"TestRunner(linear_regression_predictor, test, \"Linear Regression\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4",
"metadata": {},
"outputs": [],
"source": [
"# TF-IDF features of the training documents, fitted against price with a linear-kernel SVR.\n",
"# NOTE: this rebinds documents, labels, vectorizer, X and regressor from the linear regression cell.\n",
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.svm import SVR\n",
"\n",
"np.random.seed(42)\n",
"\n",
"documents = [words(item) for item in train]\n",
"labels = np.array([float(item.price) for item in train])\n",
"\n",
"vectorizer = TfidfVectorizer()\n",
"X = vectorizer.fit_transform(documents)\n",
"\n",
"regressor = SVR(kernel='linear')\n",
"regressor.fit(X, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64560112-3bfb-45cc-b489-de619a2eca20",
"metadata": {},
"outputs": [],
"source": [
"def svr_predictor(item):\n",
"    \"\"\"\n",
"    Predict a price for `item` with the global `vectorizer` and `regressor`, floored at 0.\n",
"\n",
"    Both globals are looked up when the function is called.\n",
"    \"\"\"\n",
"    np.random.seed(42)\n",
"    features = vectorizer.transform([words(item)])\n",
"    estimate = regressor.predict(features)[0]\n",
"    return max(estimate, 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "392598d4-2deb-4935-9175-fd111616b13c",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the SVR predictor on the first 200 test items\n",
"TestRunner(svr_predictor, test, \"SVR Accuracy\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60010699-d26b-4f93-a959-50272ada6a57",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(item):\n",
"    \"\"\"Return the system and user chat messages that ask for the price of `item`.\"\"\"\n",
"    system_turn = {\"role\": \"system\", \"content\": \"You estimate product prices. Reply only with the price, no explanation\"}\n",
"    user_turn = {\"role\": \"user\", \"content\": item.question()}\n",
"    return [system_turn, user_turn]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7",
"metadata": {},
"outputs": [],
"source": [
"def get_price(s):\n",
"    \"\"\"\n",
"    Return the first number found in a reply as a float, or 0.0 when there is none.\n",
"\n",
"    Dollar signs and thousands separators are removed before searching. A leading '-' or '+'\n",
"    is not treated as a sign: a price is never negative, and a hyphen next to the number\n",
"    (as in 'Item-12.99') is punctuation. An empty or None reply gives 0.0.\n",
"    \"\"\"\n",
"    if not s:\n",
"        return 0.0\n",
"    s = s.replace('$','').replace(',','')\n",
"    # a decimal number (digits optional before the point) or, failing that, a run of digits\n",
"    match = re.search(r\"\\d*\\.\\d+|\\d+\", s)\n",
"    return float(match.group()) if match else 0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa",
"metadata": {},
"outputs": [],
"source": [
"def gpt_predictor(item):\n",
"    \"\"\"Ask the OPENAI_MODEL chat model for the price of `item` and parse the number out of its reply.\"\"\"\n",
"    completion = openai.chat.completions.create(\n",
"        model=OPENAI_MODEL,\n",
"        messages=messages_for(item),\n",
"        seed=42,\n",
"        max_tokens=8\n",
"    )\n",
"    return get_price(completion.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5",
"metadata": {},
"outputs": [],
"source": [
"# Title the chart with the model that is actually called (OPENAI_MODEL) rather than a hard-coded name\n",
"runner = TestRunner(gpt_predictor, test, f\"{OPENAI_MODEL} Prediction Accuracy\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7e24d6b-59a2-464a-95a9-14a9fbfadd4d",
"metadata": {},
"outputs": [],
"source": [
"# Show the prompt and trailing tokens of the second training item\n",
"report(train[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "059b6c74-917f-4cb1-b810-ce70735a57be",
"metadata": {},
"outputs": [],
"source": [
"# Training text is the full prompt; test text comes from Item.test_prompt(), which ends at the PREFIX\n",
"train_prompts = [item.prompt for item in train]\n",
"train_prices = [item.price for item in train]\n",
"test_prompts = [item.test_prompt() for item in test]\n",
"test_prices = [item.price for item in test]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8ba48cb-da5e-4ddb-8955-8a94e62ea8e0",
"metadata": {},
"outputs": [],
"source": [
"# Peek at the second test prompt\n",
"test_prompts[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600",
"metadata": {},
"outputs": [],
"source": [
"# Create a Dataset from the lists\n",
"# Each split has a \"text\" column and a \"price\" column. NOTE: this rebinds the name `dataset`.\n",
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
"dataset = DatasetDict({\n",
" \"train\": train_dataset,\n",
" \"test\": test_dataset\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e69e26a5-4b24-4e0f-8944-731c534b285b",
"metadata": {},
"outputs": [],
"source": [
"# Upload both splits to the Hugging Face Hub under DATASET_NAME as a private dataset\n",
"DATASET_NAME = \"ed-donner/multi-instruct\"\n",
"dataset.push_to_hub(DATASET_NAME, private=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}