From the Udemy course on LLM engineering.
https://www.udemy.com/course/llm-engineering-master-ai-and-large-language-models
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "9491dd8f-8124-4a51-be3a-8f678c149dcf",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import re\n",
"import math\n",
"import random\n",
"import numpy as np\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import anthropic\n",
"from huggingface_hub import login\n",
"from tqdm import tqdm\n",
"import matplotlib.pyplot as plt\n",
"from datasets import load_dataset, Dataset, DatasetDict\n",
"from transformers import AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9cd394a2-d8e6-4e8f-a120-50c0ee12620d",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"load_dotenv()\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "846ded5d-b7f5-4581-8f56-d9650ff329c1",
"metadata": {},
"outputs": [],
"source": [
"# initialize\n",
"\n",
"openai = OpenAI()\n",
"claude = anthropic.Anthropic()\n",
"OPENAI_MODEL = \"gpt-4o-mini\"\n",
"CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n",
"hf_token = os.environ['HF_TOKEN']\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e81b23f7-8aa3-4590-ae5c-2d1bebd2f7c9",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a45e4f9-4fcf-4f72-8db2-54cbb1889901",
"metadata": {},
"outputs": [],
"source": [
"# Constants\n",
"\n",
"BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
"\n",
"# Used for writing to output in color\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b606ea85-4171-449d-8eda-a8f1a9b01464",
"metadata": {},
"outputs": [],
"source": [
"#datasets = [\"raw_meta_Electronics\", \"raw_meta_Appliances\", \"raw_meta_Cell_Phones_and_Accessories\", \"raw_meta_Home_and_Kitchen\"]\n",
"# datasets = [\"Electronics\", \"Appliances\", \"Cell_Phones_and_Accessories\", \"Home_and_Kitchen\", \"Tools_and_Home_Improvement\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51af18a2-4122-4753-8f5d-622da2976cb5",
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"McAuley-Lab/Amazon-Reviews-2023\", \"raw_meta_Electronics\", split=\"full\", trust_remote_code=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "141ddcdd-bd60-44d4-8c63-1c6717f5bafc",
"metadata": {},
"outputs": [],
"source": [
"print(f\"There are {len(dataset):,} items in the dataset\")\n",
"print(\"Here is the first:\")\n",
"item = dataset[0]\n",
"print(item['title'])\n",
"print(item['description'])\n",
"print(item['features'])\n",
"print(item['price'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f36c948d-e14d-44a0-9704-c11c589a26ee",
"metadata": {},
"outputs": [],
"source": [
"class Item:\n",
"    \"\"\"A cleaned-up datapoint: product text plus price, packaged as a training prompt\"\"\"\n",
"\n",
"    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
"\n",
"    def __init__(self, data):\n",
"        self.title = data['title']\n",
"        self.description = self.clean(data['description'])\n",
"        self.features = self.clean(data['features'])\n",
"        self.price = float(data['price'])\n",
"        self.price_str = str(round(self.price))\n",
"        self._token_count = None\n",
"        self.full_prompt = self.make_full_prompt()\n",
"        self.prompt = self.full_prompt.split('Price is $')[0] + 'Price is $'\n",
"        self.label = self.full_prompt.split('Price is $')[1]\n",
"\n",
"    def clean(self, details):\n",
"        result = ' '.join(details)\n",
"        return re.sub(r'[\\[\\]【】\\s]+', ' ', result).strip()\n",
"\n",
"    def question(self):\n",
"        prompt = \"How much does this cost?\\n\"\n",
"        prompt += f\"Title: {self.title}\\n\"\n",
"        prompt += f\"Description: {self.description}\\n\"\n",
"        prompt += f\"Features: {self.features}\\n\"\n",
"        return prompt\n",
"\n",
"    def messages(self):\n",
"        return [\n",
"            {\"role\":\"system\", \"content\": \"You estimate product prices. Reply only with the price to the nearest dollar\"},\n",
"            {\"role\":\"user\", \"content\": self.question()},\n",
"            {\"role\":\"assistant\", \"content\": f\"Price is ${self.price_str}.00\"}\n",
"        ]\n",
"\n",
"    def make_full_prompt(self):\n",
"        # Apply the chat template, then drop the template's second '\\n\\n' group\n",
"        prompt = self.tokenizer.apply_chat_template(self.messages(), tokenize=False, add_generation_prompt=False)\n",
"        groups = prompt.split('\\n\\n')\n",
"        return groups[0]+'\\n\\n'+'\\n\\n'.join(groups[2:])\n",
"\n",
"    def token_count(self):\n",
"        if self._token_count is None:\n",
"            self._token_count = len(self.tokenizer.encode(self.full_prompt))\n",
"        return self._token_count\n",
"\n",
"    def tokens_between(self, low, high):\n",
"        token_count = self.token_count()\n",
"        return token_count >= low and token_count < high"
]
},
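{
"cell_type": "code",
"execution_count": null,
"id": "3f1c2a77-0a1b-4c2d-8e3f-aaaa00000001",
"metadata": {},
"outputs": [],
"source": [
"# A quick sanity check (added for illustration, not from the original notebook):\n",
"# build an Item from a hand-made, purely hypothetical datapoint to see how the\n",
"# prompt / label split works\n",
"\n",
"example_data = {\n",
"    'title': 'Example Widget',\n",
"    'description': ['A made-up product used only to illustrate the Item class'],\n",
"    'features': ['Feature A', 'Feature B'],\n",
"    'price': '19.99'\n",
"}\n",
"example = Item(example_data)\n",
"print(example.prompt)   # everything up to and including 'Price is $'\n",
"print(example.label)    # the remainder, e.g. '20.00' plus end-of-turn tokens\n",
"print(example.token_count())"
]
},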
{
"cell_type": "code",
"execution_count": null,
"id": "059152d0-a68a-4e93-b759-45f3c6baf31e",
"metadata": {},
"outputs": [],
"source": [
"# Create a list called \"items\" with all our datapoints that have a valid price\n",
"\n",
"from collections import Counter\n",
"counts = Counter()\n",
"items = []\n",
"for data in tqdm(dataset):\n",
"    try:\n",
"        price_str = data['price']\n",
"        if float(price_str) > 0:\n",
"            items.append(Item(data))\n",
"    except ValueError:\n",
"        counts[data['price']]+=1"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8752310a-ca69-4d43-b8bd-fd98aebbc805",
"metadata": {},
"outputs": [],
"source": [
"counts.most_common(10)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "011bffcf-03f8-4f0d-8999-b53d1ac88624",
"metadata": {},
"outputs": [],
"source": [
"# Let's investigate:\n",
"\n",
"print(f\"There are {len(items):,} out of {len(dataset):,} with prices\\n\")\n",
"print(f\"Item 0 has {items[0].token_count()} tokens:\\n\")\n",
"print(items[0].full_prompt)\n",
"print(f\"\\nItem 1 has {items[1].token_count()} tokens:\\n\")\n",
"print(items[1].full_prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcf74830-1e97-4543-b454-eefd314fc106",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of character count\n",
"\n",
"lengths = [len(item.full_prompt) for item in items]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Length')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 5000, 250))\n",
"\n",
"print(f\"Average length is {sum(lengths)/len(lengths):,.1f} and highest length is {max(lengths):,}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af1d6c8b-f2ae-4691-9306-989b1bd45233",
"metadata": {},
"outputs": [],
"source": [
"print(f\"There are a total of {len(items):,} items\")\n",
"cutoff = 1500\n",
"selection = [item for item in items if len(item.full_prompt) < cutoff]\n",
"print(f\"There are a total of {len(selection):,} with a training prompt under {cutoff:,} characters\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42231dc7-66fb-4437-ba08-7689514a8b19",
"metadata": {},
"outputs": [],
"source": [
"# Calculate token sizes in selection\n",
"\n",
"token_counts = [item.token_count() for item in tqdm(selection)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5dde349-610a-4e96-a2ea-9178a9c1fa2a",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of tokens\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"orange\", bins=range(0, 500, 25))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "da0a20b4-8926-4eff-bf83-11c4f6b40455",
"metadata": {},
"outputs": [],
"source": [
"def report(item):\n",
"    prompt = item.full_prompt\n",
"    tokens = Item.tokenizer.encode(item.full_prompt)\n",
"    print(prompt)\n",
"    print(tokens[-8:])\n",
"    print(Item.tokenizer.batch_decode(tokens[-8:]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2378cb92-305a-49d0-8193-4ae09a0cccf8",
"metadata": {},
"outputs": [],
"source": [
"report(items[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1232004a-ff9b-486a-a14b-70f21c217c8d",
"metadata": {},
"outputs": [],
"source": [
"# Let's limit our dataset to documents with 80-240 tokens\n",
"\n",
"low_cutoff = 80\n",
"high_cutoff = 240\n",
"subset = [item for item in tqdm(selection) if item.tokens_between(low_cutoff, high_cutoff)]\n",
"subset_count = len(subset)\n",
"count = len(items)\n",
"print(f\"\\nBetween {low_cutoff} and {high_cutoff}, we get {subset_count:,} out of {count:,} which is {subset_count/count*100:.1f}%\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bc11e4f-5a15-48fd-b571-92e2e10b0323",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution again to check it looks as expected\n",
"\n",
"token_counts = [item.token_count() for item in subset]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50d88feb-d0ee-4abf-a013-7d11a7e4e2cd",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"\n",
"prices = [float(item.price) for item in subset]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Price ($)')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(prices, rwidth=0.7, color=\"darkblue\", bins=range(0, 500, 20))\n",
"\n",
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3718a8e6-6c87-4351-8c27-9e61745b0991",
"metadata": {},
"outputs": [],
"source": [
"# Take the most expensive 62,000 items (an earlier variant that mixed in 12,000 of the next 20,000 is left commented out)\n",
"\n",
"random.seed(42)\n",
"sorted_subset = sorted(subset, key=lambda item: item.price, reverse=True)\n",
"top_30k = sorted_subset[:62000]\n",
"# other_12k = random.sample(sorted_subset[30000:50000], k=12000)\n",
"# sample = top_30k + other_12k\n",
"sample = top_30k\n",
"print(len(sample))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cd1c4d3-b6e4-4f28-8ad4-709c4637626c",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution of prices\n",
"\n",
"prices = [float(item.price) for item in sample]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Price ($)')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 500, 20))\n",
"\n",
"print(f\"Average price is ${sum(prices)/len(prices):.2f} and highest price is ${max(prices):,.2f}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38d31aa3-8a3a-4626-9c50-f55635ca6d18",
"metadata": {},
"outputs": [],
"source": [
"sizes = [len(item.full_prompt) for item in sample]\n",
"prices = [item.price for item in sample]\n",
"\n",
"# Create the scatter plot\n",
"plt.figure(figsize=(10, 6))\n",
"plt.scatter(sizes, prices, s=2, color=\"red\")\n",
"\n",
"# Add labels and title\n",
"plt.xlabel('Size')\n",
"plt.ylabel('Price')\n",
"plt.title('Is there a simple correlation?')\n",
"\n",
"# Display the plot\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8cfa1af-aadd-416b-b0f9-2bb5fd4d2263",
"metadata": {},
"outputs": [],
"source": [
"# Plot the distribution again to check it looks as expected\n",
"\n",
"token_counts = [item.token_count() for item in sample]\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.set_xlabel('Number of tokens')\n",
"ax.set_ylabel('Count of items');\n",
"_ = ax.hist(token_counts, rwidth=0.7, color=\"purple\", bins=range(0, 300, 10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59ef7aef-b6f6-4042-a2af-ddd5ae1c9999",
"metadata": {},
"outputs": [],
"source": [
"report(sample[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cacb9059-5f44-4601-860a-30860cebe9c2",
"metadata": {},
"outputs": [],
"source": [
"random.seed(42)\n",
"random.shuffle(sample)\n",
"train = sample[:60000]\n",
"test = sample[60000:]\n",
"print(f\"Divided into a training set of {len(train):,} items and test set of {len(test):,} items\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd7c5db1-4510-4768-bef1-bdac2a7b392f",
"metadata": {},
"outputs": [],
"source": [
"average = sum(t.price for t in train)/len(train)\n",
"average"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95353e68-07ac-4f57-8d57-dd48cacb0e04",
"metadata": {},
"outputs": [],
"source": [
"class TestRunner:\n",
"\n",
"    def __init__(self, predictor, data, title, size=None):\n",
"        self.predictor = predictor\n",
"        self.data = data\n",
"        self.size = size or len(data)\n",
"        self.guesses = []\n",
"        self.truths = []\n",
"        self.errors = []\n",
"        self.title = title\n",
"\n",
"    def run_datapoint(self, i):\n",
"        datapoint = self.data[i]\n",
"        guess = self.predictor(datapoint)\n",
"        truth = datapoint.price\n",
"        error = abs(guess - truth)\n",
"        color = RED if error>=20 else YELLOW if error>=10 else GREEN\n",
"        title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n",
"        self.guesses.append(guess)\n",
"        self.truths.append(truth)\n",
"        self.errors.append(error)\n",
"        print(f\"{color}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} Item: {title}{RESET}\")\n",
"\n",
"    def chart(self):\n",
"        max_error = max(self.errors)\n",
"        colors = [(max_error - error)**3 for error in self.errors]\n",
"        plt.figure(figsize=(10, 6))\n",
"        plt.scatter(self.truths, self.guesses, s=3, c=colors, cmap='RdYlGn')\n",
"        plt.xlabel('Truth')\n",
"        plt.ylabel('Guess')\n",
"        plt.title(self.title)\n",
"        plt.show()\n",
"\n",
"    def run(self):\n",
"        for i in range(self.size):\n",
"            self.run_datapoint(i)\n",
"        average_error = sum(self.errors) / self.size\n",
"        print(f\"Average Error = ${average_error:,.2f}\")\n",
"        hits = [e for e in self.errors if e<10]\n",
"        print(f\"Hit rate = {len(hits)/self.size*100:.1f}%\")\n",
"        self.chart()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a8519f-c139-4c72-8d9c-39ccedda2f7b",
"metadata": {},
"outputs": [],
"source": [
"train_average = sum(t.price for t in train)/len(train)\n",
"\n",
"def flat_predictor(item):\n",
"    return train_average"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "739d2e33-55d4-4892-b42c-771131159c8d",
"metadata": {},
"outputs": [],
"source": [
"TestRunner(flat_predictor, test, \"Flat Predictor Accuracy\", 100).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6a6c4a5-e817-46b8-99d2-9c4ecf9c8685",
"metadata": {},
"outputs": [],
"source": [
"stop = set(['the', 'and', 'for', 'is', 'to', 'this', 'with', 'a', 'of', 'your', 'are', 'in','from', 'you', 'or', 'an'])\n",
"\n",
"def words(item):\n",
"    text = f\"{item.title} {item.description} {item.features}\"\n",
"    text = re.sub(r'[()\\[\\]{},\\'\"-]', ' ', text)\n",
"    text = re.sub(r'\\s+', ' ', text)\n",
"    words = text.strip().lower().split(' ')\n",
"    filtered = [word for word in words if word not in stop]\n",
"    return \" \".join(filtered)"
]
},
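{
"cell_type": "code",
"execution_count": null,
"id": "3f1c2a77-0a1b-4c2d-8e3f-aaaa00000002",
"metadata": {},
"outputs": [],
"source": [
"# Quick check (added for illustration, not from the original notebook): preview\n",
"# the stop-word-filtered text for the first training item before vectorizing\n",
"\n",
"print(words(train[0]))"
]
},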
{
"cell_type": "code",
"execution_count": null,
"id": "262fc576-7606-426c-8aea-5799b3952d2c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.linear_model import LinearRegression\n",
"import numpy as np\n",
"\n",
"np.random.seed(42)\n",
"\n",
"documents = [words(item) for item in train]\n",
"labels = np.array([float(item.price) for item in train])\n",
"\n",
"vectorizer = CountVectorizer()\n",
"X = vectorizer.fit_transform(documents)\n",
"\n",
"regressor = LinearRegression()\n",
"regressor.fit(X, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd782b21-8e44-409d-a7b6-f136974958b4",
"metadata": {},
"outputs": [],
"source": [
"def linear_regression_predictor(item):\n",
"    np.random.seed(42)\n",
"    x = vectorizer.transform([words(item)])\n",
"    return max(regressor.predict(x)[0], 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80e77aae-0071-42e9-8e24-d3aec5256015",
"metadata": {},
"outputs": [],
"source": [
"TestRunner(linear_regression_predictor, test, \"Linear Accuracy\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a70d16ce-bdf1-4071-8c5a-5bddc2aa37e4",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"from sklearn.svm import SVR\n",
"\n",
"np.random.seed(42)\n",
"\n",
"documents = [words(item) for item in train]\n",
"labels = np.array([float(item.price) for item in train])\n",
"\n",
"vectorizer = TfidfVectorizer()\n",
"X = vectorizer.fit_transform(documents)\n",
"\n",
"regressor = SVR(kernel='linear')\n",
"regressor.fit(X, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64560112-3bfb-45cc-b489-de619a2eca20",
"metadata": {},
"outputs": [],
"source": [
"def svr_predictor(item):\n",
"    np.random.seed(42)\n",
"    x = vectorizer.transform([words(item)])\n",
"    return max(regressor.predict(x)[0], 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "392598d4-2deb-4935-9175-fd111616b13c",
"metadata": {},
"outputs": [],
"source": [
"TestRunner(svr_predictor, test, \"SVR Accuracy\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60010699-d26b-4f93-a959-50272ada6a57",
"metadata": {},
"outputs": [],
"source": [
"def messages_for(item):\n",
"    system_message = \"You estimate product prices. Reply only with the price, no explanation\"\n",
"    user_prompt = item.question()\n",
"    return [\n",
"        {\"role\": \"system\", \"content\": system_message},\n",
"        {\"role\": \"user\", \"content\": user_prompt}\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d5c1a62-9c6e-4c1c-b051-95a78e6e32a7",
"metadata": {},
"outputs": [],
"source": [
"def get_price(s):\n",
"    s = s.replace('$','').replace(',','')\n",
"    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)\n",
"    return float(match.group()) if match else 0"
]
},
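{
"cell_type": "code",
"execution_count": null,
"id": "3f1c2a77-0a1b-4c2d-8e3f-aaaa00000003",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (added for illustration) with made-up reply strings --\n",
"# get_price should pull the first number out of a model reply\n",
"\n",
"for reply in [\"$99.99\", \"Price is $1,250\", \"around 45 dollars\", \"no idea\"]:\n",
"    print(f\"{reply!r} -> {get_price(reply)}\")"
]
},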
{
"cell_type": "code",
"execution_count": null,
"id": "9c845d34-1c73-4636-a6ec-cc6666bb39fa",
"metadata": {},
"outputs": [],
"source": [
"def gpt_predictor(item):\n",
"    response = openai.chat.completions.create(\n",
"        model=OPENAI_MODEL,\n",
"        messages=messages_for(item),\n",
"        seed=42,\n",
"        max_tokens=8\n",
"    )\n",
"    reply = response.choices[0].message.content\n",
"    return get_price(reply)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b3eb3ef-90a8-4642-b503-c22e72c457f5",
"metadata": {},
"outputs": [],
"source": [
"TestRunner(gpt_predictor, test, \"GPT-4o-mini Prediction Accuracy\", 200).run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7e24d6b-59a2-464a-95a9-14a9fbfadd4d",
"metadata": {},
"outputs": [],
"source": [
"train[0].full_prompt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "059b6c74-917f-4cb1-b810-ce70735a57be",
"metadata": {},
"outputs": [],
"source": [
"train_prompts = [item.full_prompt for item in train]\n",
"train_prices = [item.price for item in train]\n",
"test_prompts = [item.prompt for item in test]\n",
"test_prices = [item.price for item in test]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8ba48cb-da5e-4ddb-8955-8a94e62ea8e0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9ee2e90-79b6-4232-b955-b1c67bc3d600",
"metadata": {},
"outputs": [],
"source": [
"# Create a Dataset from the lists\n",
"train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
"test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
"dataset = DatasetDict({\n",
"    \"train\": train_dataset,\n",
"    \"test\": test_dataset\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e69e26a5-4b24-4e0f-8944-731c534b285b",
"metadata": {},
"outputs": [],
"source": [
"DATASET_NAME = \"ed-donner/electronics-instruct\"\n",
"dataset.push_to_hub(DATASET_NAME, private=True)"
]
},
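{
"cell_type": "code",
"execution_count": null,
"id": "3f1c2a77-0a1b-4c2d-8e3f-aaaa00000004",
"metadata": {},
"outputs": [],
"source": [
"# Optional check (added for illustration): reload the pushed dataset from the Hub\n",
"# to confirm the upload worked -- this assumes the push above succeeded and that\n",
"# HF_TOKEN has read access to the private repo\n",
"\n",
"reloaded = load_dataset(DATASET_NAME)\n",
"print(reloaded)"
]
},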
{
"cell_type": "code",
"execution_count": null,
"id": "0282b9c5-019b-4e1c-910c-3f86b46b35dd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|
|
|