{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9491dd8f-8124-4a51-be3a-8f678c149dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import re\n",
    "import math\n",
    "import random\n",
    "from datetime import datetime\n",
    "import numpy as np\n",
    "from typing import Optional\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import anthropic\n",
    "from huggingface_hub import login\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from datasets import load_dataset, Dataset, DatasetDict\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "from transformers import AutoTokenizer\n",
    "from items_qwen import Item, read_dataset\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "\n",
    "load_dotenv()\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
    "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
    "os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846ded5d-b7f5-4581-8f56-d9650ff329c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize\n",
    "\n",
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()\n",
    "OPENAI_MODEL = \"gpt-4o-mini\"\n",
    "CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n",
    "hf_token = os.environ['HF_TOKEN']\n",
    "login(hf_token, add_to_git_credential=True)\n",
    "PHI3 = \"microsoft/Phi-3-medium-4k-instruct\"\n",
    "GEMMA = \"google/gemma-2-9b-it\"\n",
    "QWEN = \"Qwen/Qwen2-7B\"\n",
    "LLAMA = \"meta-llama/Meta-Llama-3.1-8B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028e21ff-4f57-42c3-81de-cb0e83eb0d25",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants\n",
    "\n",
    "BASE_MODEL = GEMMA\n",
    "\n",
    "# Used for writing to output in color\n",
    "\n",
    "GREEN = \"\\033[92m\"\n",
    "YELLOW = \"\\033[93m\"\n",
    "RED = \"\\033[91m\"\n",
    "RESET = \"\\033[0m\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3135d11e-1ab5-4cf9-a15f-bdcec7cba5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "\n",
    "# Show how the base model tokenizes the integers 1..999\n",
    "for i in range(1,1000):\n",
    "    text = str(i)\n",
    "    tok = tokenizer.encode(text, add_special_tokens=False)\n",
    "    print(f\"{text}={tok}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2ed609-a00a-4ff8-9f4d-8f2ff8ea26dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_names = [\n",
    "    \"Automotive\",\n",
    "    \"Electronics\",\n",
    "    \"Office_Products\",\n",
    "    \"Tools_and_Home_Improvement\",\n",
    "    \"Cell_Phones_and_Accessories\",\n",
    "    \"Toys_and_Games\",\n",
    "    \"Appliances\",\n",
    "    \"Musical_Instruments\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bc061b-1e8d-4df6-ad49-c6148f527d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# start = datetime.now()\n",
    "# items = read_dataset(\"Appliances\")\n",
    "# finish = datetime.now()\n",
    "# print(f\"Completed in {(finish-start).total_seconds()/60:.1f} mins\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80316d4-a84e-4b0d-8f36-988fe1bd2e2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(items[10000].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd11853b-9e21-4b14-9a08-9d9f63636e1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = datetime.now()\n",
    "items = []\n",
    "# Load each category in its own worker process and concatenate the results\n",
    "with ProcessPoolExecutor(max_workers=6) as pool:\n",
    "    for results in pool.map(read_dataset, dataset_names):\n",
    "        items.extend(results)\n",
    "finish = datetime.now()\n",
    "print(f\"Completed in {(finish-start).total_seconds()/60:.1f} mins\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7af95a63-2116-4b62-9d99-1808721410c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d7ab348-3708-4357-b96f-65839f897223",
   "metadata": {},
   "outputs": [],
   "source": [
    "max(item.token_count for item in items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91171f6b-4624-401a-9af5-9c9a1ce434c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b01322eb-d84d-46de-a12b-c2b027bac66d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('items2.pkl', 'wb') as file:\n",
    "    pickle.dump(items, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cecfb43d-3692-475d-b80d-512e6b9b55c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: only unpickle a file this notebook wrote itself - pickle.load can execute arbitrary code\n",
    "with open('items2.pkl', 'rb') as file:\n",
    "    items = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9949f7f-81d6-45d9-8f8f-d3690a8ffa85",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011bffcf-03f8-4f0d-8999-b53d1ac88624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's investigate:\n",
    "\n",
    "print(f\"There are {len(items):,} items with prices\\n\")\n",
    "print(items[2000000].prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf74830-1e97-4543-b454-eefd314fc106",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of character count\n",
    "\n",
    "lengths = [len(item.prompt) for item in items]\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.set_xlabel('Length')\n",
    "ax.set_ylabel('Count of items');\n",
    "_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 1000, 50))\n",
    "\n",
    "print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of tokens\n",
    "\n",
    "token_counts = [item.token_count for item in items]\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.set_xlabel('Number of tokens')\n",
    "ax.set_ylabel('Count of items');\n",
    "_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 200, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0a20b4-8926-4eff-bf83-11c4f6b40455",
   "metadata": {},
   "outputs": [],
   "source": [
    "def report(item):\n",
    "    \"\"\"Print an item's prompt, then the ids and decoded text of its last 10 tokens (via Item.tokenizer).\"\"\"\n",
    "    prompt = item.prompt\n",
    "    tokens = Item.tokenizer.encode(item.prompt)\n",
    "    print(prompt)\n",
    "    print(tokens[-10:])\n",
    "    print(Item.tokenizer.batch_decode(tokens[-10:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2378cb92-305a-49d0-8193-4ae09a0cccf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "report(items[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of prices\n",
    "\n",
    "prices = [float(item.price) for item in items]\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.set_xlabel('Price ($)')\n",
    "ax.set_ylabel('Count of items');\n",
    "_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 20))\n",
    "\n",
    "print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3718a8e6-6c87-4351-8c27-9e61745b0991",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "category_counts = Counter()\n",
    "for item in items:\n",
    "    category_counts[item.category]+=1\n",
    "\n",
    "categories = category_counts.keys()\n",
    "counts = [category_counts[category] for category in categories]\n",
    "\n",
    "# Create bar chart\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.bar(categories, counts, color=\"purple\")\n",
    "\n",
    "# Customize the chart\n",
    "plt.title('How many in each category')\n",
    "plt.xlabel('Categories')\n",
    "plt.ylabel('Count')\n",
    "\n",
    "plt.xticks(rotation=30, ha='right')\n",
    "\n",
    "# Add value labels on top of each bar\n",
    "for i, v in enumerate(counts):\n",
    "    plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
    "\n",
    "# Display the chart\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e741d816-25d8-4372-9caa-d006b85818be",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "# Bucket items by price rounded to the nearest dollar\n",
    "slots = defaultdict(list)\n",
    "for item in items:\n",
    "    slots[round(item.price)].append(item)\n",
    "\n",
    "print(f\"\\nMinimum: {min([len(slot) for slot in slots.values()]):,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75541d34-37e9-4e59-bac8-87fd1d147d62",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "sample = []\n",
    "for i in range(1, 1000):\n",
    "    slot = slots[i]\n",
    "    if i>=240:\n",
    "        # Keep every item priced at $240 or more\n",
    "        sample.extend(slot)\n",
    "    elif len(slot) <= 1200:\n",
    "        sample.extend(slot)\n",
    "    else:\n",
    "        # Crowded slot: draw 1200 items, weighting Automotive 1 vs 5 for other categories\n",
    "        weights = np.array([1 if item.category=='Automotive' else 5 for item in slot])\n",
    "        weights = weights / np.sum(weights)\n",
    "        selected_indices = np.random.choice(len(slot), size=1200, replace=False, p=weights)\n",
    "        selected = [slot[idx] for idx in selected_indices]\n",
    "        sample.extend(selected)\n",
    "len(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f56ae0c-c802-436d-be42-6439143b177f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of prices\n",
    "\n",
    "prices = [float(item.price) for item in sample]\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
    "plt.xlabel('Price ($)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9bdc6d5-3766-4251-81bb-a1bcebe80ce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_counts = Counter()\n",
    "for item in sample:\n",
    "    category_counts[item.category]+=1\n",
    "\n",
    "categories = category_counts.keys()\n",
    "counts = [category_counts[category] for category in categories]\n",
    "\n",
    "# Create bar chart\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.bar(categories, counts, color=\"lightgreen\")\n",
    "\n",
    "# Customize the chart\n",
    "plt.title('How many in each category')\n",
    "plt.xlabel('Categories')\n",
    "plt.ylabel('Count')\n",
    "\n",
    "plt.xticks(rotation=30, ha='right')\n",
    "\n",
    "# Add value labels on top of each bar\n",
    "for i, v in enumerate(counts):\n",
    "    plt.text(i, v, f\"{v:,}\", ha='center', va='bottom')\n",
    "\n",
    "# Display the chart\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d692349-6814-4b02-aa20-705279ee5295",
   "metadata": {},
   "outputs": [],
   "source": [
    "report(sample[400000])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a28fa68-392a-4d8d-8c72-eeea4615d937",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_counts = Counter()\n",
    "for item in sample:\n",
    "    category_counts[item.category]+=1\n",
    "\n",
    "categories = category_counts.keys()\n",
    "counts = [category_counts[category] for category in categories]\n",
    "\n",
    "plt.figure(figsize=(12, 10))\n",
    "plt.pie(counts, labels=categories, autopct='%1.0f%%', startangle=90)\n",
    "\n",
    "# Add a circle at the center to create a donut chart (optional)\n",
    "centre_circle = plt.Circle((0,0), 0.70, fc='white')\n",
    "fig = plt.gcf()\n",
    "fig.gca().add_artist(centre_circle)\n",
    "\n",
    "# Customize the chart\n",
    "plt.title('Categories')\n",
    "\n",
    "# Equal aspect ratio ensures that pie is drawn as a circle\n",
    "plt.axis('equal')\n",
    "\n",
    "# Display the chart\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "sizes = [len(item.prompt) for item in sample]\n",
    "prices = [item.price for item in sample]\n",
    "\n",
    "# Create the scatter plot\n",
    "plt.figure(figsize=(15, 8))\n",
    "plt.scatter(sizes, prices, s=0.2, color=\"red\")\n",
    "\n",
    "# Add labels and title\n",
    "plt.xlabel('Size')\n",
    "plt.ylabel('Price')\n",
    "plt.title('Is there a simple correlation?')\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution again to check it looks as expected\n",
    "\n",
    "token_counts = [item.token_count for item in sample]\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "ax.set_xlabel('Number of tokens')\n",
    "ax.set_ylabel('Count of items');\n",
    "_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ef7aef-b6f6-4042-a2af-ddd5ae1c9999",
   "metadata": {},
   "outputs": [],
   "source": [
    "report(sample[-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cacb9059-5f44-4601-860a-30860cebe9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)\n",
    "random.shuffle(sample)\n",
    "train = sample[:400_000]\n",
    "test = sample[400_000:402_000]\n",
    "print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf435bcd-accf-427c-82d5-02b33a56737c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the distribution of prices\n",
    "\n",
    "prices = [float(item.price) for item in test[:250]]\n",
    "plt.figure(figsize=(15, 6))\n",
    "plt.title(f\"Avg {sum(prices)/len(prices):.2f} and highest {max(prices):,.2f}\\n\")\n",
    "plt.xlabel('Price ($)')\n",
    "plt.ylabel('Count')\n",
    "plt.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 1000, 10))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b26000a-e5a9-4ab7-83fc-8eb44cb12f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# del items, slots\n",
    "# import gc\n",
    "# gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3615bfdd-f23e-4005-96d8-7b52a1a439be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "# Export the first 200 test items (price column zeroed) for a human to estimate.\n",
    "# newline='' is required by the csv module, otherwise rows get blank lines between them on Windows.\n",
    "with open('test.csv', 'w', newline='') as csvfile:\n",
    "    writer = csv.writer(csvfile)\n",
    "    for t in test[:200]:\n",
    "        writer.writerow([t.title, t.details, 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb6907e9-37d7-4283-b1a9-8124f9f3439b",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_predictions = []\n",
    "with open('human.csv', 'r', newline='') as csvfile:\n",
    "    reader = csv.reader(csvfile)\n",
    "    for row in reader:\n",
    "        human_predictions.append(float(row[2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f",
   "metadata": {},
   "outputs": [],
   "source": [
    "average = sum(t.price for t in train)/len(train)\n",
    "average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95353e68-07ac-4f57-8d57-dd48cacb0e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestRunner:\n",
    "    \"\"\"Evaluate a price predictor on the first `size` datapoints and chart guesses vs. truths.\n",
    "\n",
    "    predictor: callable taking a datapoint and returning a numeric price guess.\n",
    "    data: indexable sequence of datapoints exposing `.price` and `.title`.\n",
    "    title: label used in the chart title.\n",
    "    size: number of datapoints to evaluate (capped at len(data)).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, predictor, data, title, size=250):\n",
    "        self.predictor = predictor\n",
    "        self.data = data\n",
    "        self.title = title\n",
    "        # Cap at the data length so a short dataset can't raise IndexError mid-run\n",
    "        self.size = min(size, len(data))\n",
    "        self.guesses = []\n",
    "        self.truths = []\n",
    "        self.errors = []\n",
    "        self.sles = []\n",
    "        self.colors = []\n",
    "\n",
    "    def run_datapoint(self, i):\n",
    "        \"\"\"Predict datapoint i, record its absolute and squared-log errors, and print a colored summary line.\"\"\"\n",
    "        datapoint = self.data[i]\n",
    "        guess = self.predictor(datapoint)\n",
    "        truth = datapoint.price\n",
    "        error = abs(guess - truth)\n",
    "        # log(x+1) keeps zero prices/guesses finite; math.log raises if guess <= -1\n",
    "        log_error = math.log(truth+1) - math.log(guess+1)\n",
    "        sle = log_error ** 2\n",
    "        color = RED if error>=40 else YELLOW if error>=20 else GREEN\n",
    "        color_str = \"red\" if error>=40 else \"yellow\" if error>=20 else \"green\"\n",
    "        title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n",
    "        self.guesses.append(guess)\n",
    "        self.truths.append(truth)\n",
    "        self.errors.append(error)\n",
    "        self.sles.append(sle)\n",
    "        self.colors.append(color_str)\n",
    "        print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
    "\n",
    "    def chart(self, title):\n",
    "        \"\"\"Scatter-plot guesses against truths, colored by error band.\"\"\"\n",
    "        plt.figure(figsize=(12, 8))\n",
    "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
    "        plt.xlabel('Ground Truth')\n",
    "        plt.ylabel('Model Estimate')\n",
    "        plt.title(title)\n",
    "        plt.show()\n",
    "\n",
    "    def report(self):\n",
    "        \"\"\"Compute mean absolute error, RMSLE and hit rate (error < $20), then chart them.\"\"\"\n",
    "        average_error = sum(self.errors) / self.size\n",
    "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
    "        hits = [e for e in self.errors if e<20]\n",
    "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={len(hits)/self.size*100:.1f}%\"\n",
    "        self.chart(title)\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Evaluate the first `size` datapoints, show the report, and return self for chaining.\"\"\"\n",
    "        for i in range(self.size):\n",
    "            self.run_datapoint(i)\n",
    "        self.report()\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_average = sum(t.price for t in train)/len(train)\n",
    "\n",
    "def flat_predictor(item):\n",
    "    \"\"\"Baseline: ignore the item and always guess the mean training-set price.\"\"\"\n",
    "    return train_average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "739d2e33-55d4-4892-b42c-771131159c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner = TestRunner(flat_predictor, test, \"Flat Predictor\").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3a5d83-d90b-40af-979f-85aa21816578",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_predictions = []\n",
    "with open('human.csv', 'r', newline='') as csvfile:\n",
    "    reader = csv.reader(csvfile)\n",
    "    for row in reader:\n",
    "        human_predictions.append(float(row[2]))\n",
    "\n",
    "def human_predictor(item):\n",
    "    \"\"\"Return the human's guess for `item`, matched by its position in `test`.\n",
    "\n",
    "    Raises ValueError if `item` is not in `test` (list.index raises; it never returns -1),\n",
    "    and IndexError if its position is beyond the rows loaded from human.csv.\n",
    "    \"\"\"\n",
    "    index = test.index(item)\n",
    "    return human_predictions[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c87c385-d6b7-4a4f-89eb-f4e250337d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# human.csv only covers the rows exported to test.csv, so don't evaluate past them\n",
    "runner = TestRunner(human_predictor, test, \"Human Predictor\", size=min(250, len(human_predictions))).run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop = set(['the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in','from', 'you', 'or', 'an', 'on', 'by'])\n",
    "\n",
    "def words(item):\n",
    "    \"\"\"Lower-case title + details, strip brackets/punctuation, and drop stop words.\"\"\"\n",
    "    text = f\"{item.title} {item.details}\"\n",
    "    text = re.sub(r'[\\(\\)\\[\\]\\{\\},\\'\"\\- \\s]+', ' ', text)\n",
    "    words = text.strip().lower().split(' ')\n",
    "    filtered = [word for word in words if word not in stop]\n",
    "    return \" \".join(filtered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56682e9c-46c9-48f2-baea-e943804290f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "documents = [words(item) for item in train]\n",
    "from collections import Counter\n",
    "count = Counter()\n",
    "for doc in documents:\n",
    "    ws = doc.split(\" \")\n",
    "    for w in ws:\n",
    "        count[w]+=1\n",
    "count.most_common(30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "262fc576-7606-426c-8aea-5799b3952d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "labels = np.array([float(item.price) for item in train])\n",
    "\n",
    "vectorizer = CountVectorizer()\n",
    "X = vectorizer.fit_transform(documents)\n",
    "\n",
    "regressor = LinearRegression()\n",
    "regressor.fit(X, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd782b21-8e44-409d-a7b6-f136974958b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_regression_predictor(item):\n",
    "    \"\"\"Bag-of-words linear regression estimate, clamped at 0.\"\"\"\n",
    "    np.random.seed(42)\n",
    "    x = vectorizer.transform([words(item)])\n",
    "    return max(regressor.predict(x)[0], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80e77aae-0071-42e9-8e24-d3aec5256015",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner1 = TestRunner(linear_regression_predictor, test, \"Linear Regression\").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.svm import SVR\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "labels = np.array([float(item.price) for item in train])\n",
    "\n",
    "# Distinct names so this cell doesn't clobber the linear-regression vectorizer/regressor\n",
    "# that linear_regression_predictor reads\n",
    "svr_vectorizer = TfidfVectorizer(max_features=20)\n",
    "X = svr_vectorizer.fit_transform(documents)\n",
    "\n",
    "svr_regressor = SVR(kernel='linear')\n",
    "svr_regressor.fit(X, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64560112-3bfb-45cc-b489-de619a2eca20",
   "metadata": {},
   "outputs": [],
   "source": [
    "def svr_predictor(item):\n",
    "    \"\"\"TF-IDF + linear-kernel SVR estimate, clamped at 0.\"\"\"\n",
    "    np.random.seed(42)\n",
    "    x = svr_vectorizer.transform([words(item)])\n",
    "    return max(svr_regressor.predict(x)[0], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "392598d4-2deb-4935-9175-fd111616b13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner2 = TestRunner(svr_predictor, test, \"SVR Accuracy\").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60010699-d26b-4f93-a959-50272ada6a57",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(item):\n",
    "    \"\"\"Build the chat messages asking a model to price `item` (the 'nearest dollar' phrase is removed).\"\"\"\n",
    "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
    "    user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\")\n",
    "    return [\n",
    "        {\"role\": \"system\", \"content\": system_message},\n",
    "        {\"role\": \"user\", \"content\": user_prompt}\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_price(s):\n",
    "    \"\"\"Extract the first number from a model reply, ignoring '$' and ',' separators.\n",
    "\n",
    "    Returns 0.0 when no number is found. Any leading sign is ignored (as it already was for\n",
    "    integers) so a stray '-' can't produce a negative price, which would make TestRunner's\n",
    "    math.log(guess+1) fail.\n",
    "    \"\"\"\n",
    "    s = s.replace('$','').replace(',','')\n",
    "    match = re.search(r\"\\d*\\.\\d+|\\d+\", s)\n",
    "    return float(match.group()) if match else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gpt_predictor(item):\n",
    "    \"\"\"Ask OPENAI_MODEL for a price and parse the number from its short reply.\"\"\"\n",
    "    response = openai.chat.completions.create(\n",
    "        model=OPENAI_MODEL, #\"gpt-4o-2024-08-06\",\n",
    "        messages=messages_for(item),\n",
    "        seed=42,\n",
    "        max_tokens=6\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return get_price(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner = TestRunner(gpt_predictor, test, \"GPT-4o-mini Prediction Accuracy\").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c43e2865-ac54-4ea8-9d2c-bbbeedf89029",
   "metadata": {},
   "outputs": [],
   "source": [
    "def frontier_predictor(item):\n",
    "    \"\"\"Ask gpt-4o-2024-08-06 for a price and parse the number from its short reply.\"\"\"\n",
    "    response = openai.chat.completions.create(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        messages=messages_for(item),\n",
    "        seed=42,\n",
    "        max_tokens=6\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return get_price(reply)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf517b2-9a7e-4def-93ba-f728f16d91e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "runner = TestRunner(frontier_predictor, test, \"GPT-4o Frontier Prediction\").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "059b6c74-917f-4cb1-b810-ce70735a57be",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_prompts = [item.prompt for item in train]\n",
    "train_prices = [item.price for item in train]\n",
    "test_prompts = [item.test_prompt() for item in test]\n",
    "test_prices = [item.price for item in test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8ba48cb-da5e-4ddb-8955-8a94e62ea8e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_prompts[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a Dataset from the lists\n",
    "train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
    "test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
    "dataset = DatasetDict({\n",
    "    \"train\": train_dataset,\n",
    "    \"test\": test_dataset\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e69e26a5-4b24-4e0f-8944-731c534b285b",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = \"ed-donner/pricer-data\"\n",
    "dataset.push_to_hub(DATASET_NAME, private=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}