{
"cells": [
{
"cell_type": "markdown",
"id": "db8736a7-ed94-441c-9556-831fa57b5a10",
"metadata": {},
"source": [
"# The Product Pricer Continued\n",
"\n",
"A model that can estimate how much something costs, from its description.\n",
"\n",
"## Baseline Models\n",
"\n",
"Today we work on the simplest models to act as a starting point that we will beat."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "681c717b-4c24-4ac3-a5f3-3c5881d6e70a",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import math\n",
"import json\n",
"import random\n",
"from dotenv import load_dotenv\n",
"from huggingface_hub import login\n",
"from items import Item\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pickle\n",
"from collections import Counter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "933b6e75-3661-4f30-b0b5-c28d04e3748e",
"metadata": {},
"outputs": [],
"source": [
"# More imports for our traditional machine learning\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"from sklearn.preprocessing import StandardScaler"
]
},
{
"cell_type": "markdown",
"id": "b3c87c11-8dbe-4b8c-8989-01e3d3a60026",
"metadata": {},
"source": [
"## NLP imports\n",
"\n",
"In the next cell, we have more imports for our NLP related machine learning. \n",
"If the gensim import gives you an error like \"Cannot import name 'triu' from 'scipy.linalg' then please run in another cell: \n",
"`!pip install \"scipy<1.13\"` \n",
"As described on StackOverflow [here](https://stackoverflow.com/questions/78279136/importerror-cannot-import-name-triu-from-scipy-linalg-when-importing-gens). \n",
"Many thanks to students Arnaldo G and Ard V for sorting this."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42cf33b7-7abd-44ba-9780-c156b70473b5",
"metadata": {},
"outputs": [],
"source": [
"# NLP related imports\n",
"\n",
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from gensim.models import Word2Vec\n",
"from gensim.utils import simple_preprocess"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1ac3ec0-183c-4a12-920b-b06397f86815",
"metadata": {},
"outputs": [],
"source": [
"# Finally, more imports for more advanced machine learning\n",
"\n",
"from sklearn.svm import LinearSVR\n",
"from sklearn.ensemble import RandomForestRegressor"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c01ee5f-c4fc-44fe-9d3a-907e8a0426d2",
"metadata": {},
"outputs": [],
"source": [
"# Constants - used for printing to stdout in color\n",
"\n",
"GREEN = \"\\033[92m\"\n",
"YELLOW = \"\\033[93m\"\n",
"RED = \"\\033[91m\"\n",
"RESET = \"\\033[0m\"\n",
"COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36d05bdc-0155-4c72-a7ee-aa4e614ffd3c",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"\n",
"load_dotenv(override=True)\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4dd3aad2-6f99-433c-8792-e461d2f06622",
"metadata": {},
"outputs": [],
"source": [
"# Log in to HuggingFace\n",
"\n",
"hf_token = os.environ['HF_TOKEN']\n",
"login(hf_token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c830ed3e-24ee-4af6-a07b-a1bfdcd39278",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "5105e13c-bca0-4c70-bfaa-649345f53322",
"metadata": {},
"source": [
"# Loading the pkl files\n",
"\n",
"Let's avoid curating all our data again! Load in the pickle files\n",
"\n",
"If you didn't already create these in Day 2, you can also download them from my google drive (you'll also find the slides here): \n",
"https://drive.google.com/drive/folders/1JwNorpRHdnf_pU0GE5yYtfKlyrKC3CoV?usp=sharing\n",
"\n",
"But note that the files are quite large - you might need to get a coffee!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c9b05f4-c9eb-462c-8d86-de9140a2d985",
"metadata": {},
"outputs": [],
"source": [
"with open('train.pkl', 'rb') as file:\n",
" train = pickle.load(file)\n",
"\n",
"with open('test.pkl', 'rb') as file:\n",
" test = pickle.load(file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a84638f7-5ff7-4f54-8751-3ef156264aee",
"metadata": {},
"outputs": [],
"source": [
"# Remind ourselves the training prompt\n",
"\n",
"print(train[0].prompt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7619c85-6e9e-48a1-8efe-c6a60471b87c",
"metadata": {},
"outputs": [],
"source": [
"# Remind a test prompt\n",
"\n",
"print(train[0].price)"
]
},
{
"cell_type": "markdown",
"id": "bcccf130-125a-4958-bac3-f46dfcb29b3f",
"metadata": {},
"source": [
"## Unveiling a mighty script that we will use a lot!\n",
"\n",
"A rather pleasing Test Harness that will evaluate any model against 250 items from the Test set\n",
"\n",
"And show us the results in a visually satisfying way.\n",
"\n",
"You write a function of this form:\n",
"\n",
"```\n",
"def my_prediction_function(item):\n",
" # my code here\n",
" return my_estimate\n",
"```\n",
"\n",
"And then you call:\n",
"\n",
"`Tester.test(my_prediction_function)`\n",
"\n",
"To evaluate your model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5793f5c-e23e-4a74-9496-1e30dd1e8935",
"metadata": {},
"outputs": [],
"source": [
"class Tester:\n",
"\n",
" def __init__(self, predictor, title=None, data=test, size=250):\n",
" self.predictor = predictor\n",
" self.data = data\n",
" self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
" self.size = size\n",
" self.guesses = []\n",
" self.truths = []\n",
" self.errors = []\n",
" self.sles = []\n",
" self.colors = []\n",
"\n",
" def color_for(self, error, truth):\n",
" if error<40 or error/truth < 0.2:\n",
" return \"green\"\n",
" elif error<80 or error/truth < 0.4:\n",
" return \"orange\"\n",
" else:\n",
" return \"red\"\n",
" \n",
" def run_datapoint(self, i):\n",
" datapoint = self.data[i]\n",
" guess = self.predictor(datapoint)\n",
" truth = datapoint.price\n",
" error = abs(guess - truth)\n",
" log_error = math.log(truth+1) - math.log(guess+1)\n",
" sle = log_error ** 2\n",
" color = self.color_for(error, truth)\n",
" title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n",
" self.guesses.append(guess)\n",
" self.truths.append(truth)\n",
" self.errors.append(error)\n",
" self.sles.append(sle)\n",
" self.colors.append(color)\n",
" print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
"\n",
" def chart(self, title):\n",
" max_error = max(self.errors)\n",
" plt.figure(figsize=(12, 8))\n",
" max_val = max(max(self.truths), max(self.guesses))\n",
" plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
" plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
" plt.xlabel('Ground Truth')\n",
" plt.ylabel('Model Estimate')\n",
" plt.xlim(0, max_val)\n",
" plt.ylim(0, max_val)\n",
" plt.title(title)\n",
" plt.show()\n",
"\n",
" def report(self):\n",
" average_error = sum(self.errors) / self.size\n",
" rmsle = math.sqrt(sum(self.sles) / self.size)\n",
" hits = sum(1 for color in self.colors if color==\"green\")\n",
" title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
" self.chart(title)\n",
"\n",
" def run(self):\n",
" self.error = 0\n",
" for i in range(self.size):\n",
" self.run_datapoint(i)\n",
" self.report()\n",
"\n",
" @classmethod\n",
" def test(cls, function):\n",
" cls(function).run()"
]
},
{
"cell_type": "markdown",
"id": "066fef03-8338-4526-9df3-89b649ad4f0a",
"metadata": {},
"source": [
"# Now for something basic\n",
"\n",
"What's the very simplest model you could imagine?\n",
"\n",
"Let's start with a random number generator!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "66ea68e8-ab1b-4f0d-aba4-a59574d8f85e",
"metadata": {},
"outputs": [],
"source": [
"def random_pricer(item):\n",
" return random.randrange(1,1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53d941cb-5b73-44ea-b893-3a0ce9997066",
"metadata": {},
"outputs": [],
"source": [
"# Set the random seed\n",
"\n",
"random.seed(42)\n",
"\n",
"# Run our TestRunner\n",
"Tester.test(random_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97451c73-9c1b-43a8-b3b9-9c41942e48a2",
"metadata": {},
"outputs": [],
"source": [
"# That was fun!\n",
"# We can do better - here's another rather trivial model\n",
"\n",
"training_prices = [item.price for item in train]\n",
"training_average = sum(training_prices) / len(training_prices)\n",
"\n",
"def constant_pricer(item):\n",
" return training_average"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cf384eb-30c2-40d8-b7e5-48942ac6a969",
"metadata": {},
"outputs": [],
"source": [
"# Run our constant predictor\n",
"Tester.test(constant_pricer)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6173a4b9-f0b9-407b-a1de-dc8f045adb8f",
"metadata": {},
"outputs": [],
"source": [
"train[0].details"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce16eee8-bb34-4914-9aa5-57e30a567842",
"metadata": {},
"outputs": [],
"source": [
"# Create a new \"features\" field on items, and populate it with json parsed from the details dict\n",
"\n",
"for item in train:\n",
" item.features = json.loads(item.details)\n",
"for item in test:\n",
" item.features = json.loads(item.details)\n",
"\n",
"# Look at one"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac702a10-dccb-43d4-887b-6f92a0fb298f",
"metadata": {},
"outputs": [],
"source": [
"train[0].features.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd7a41c5-0c51-41be-a61d-8e80c3e90930",
"metadata": {},
"outputs": [],
"source": [
"# Look at 20 most common features in training set\n",
"\n",
"feature_count = Counter()\n",
"for item in train:\n",
" for f in item.features.keys():\n",
" feature_count[f]+=1\n",
"\n",
"feature_count.most_common(40)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cef84a9-4932-48fd-9f7a-51cfc06e3216",
"metadata": {},
"outputs": [],
"source": [
"# Now some janky code to pluck out the Item Weight\n",
"# Don't worry too much about this: spoiler alert, it's not going to be much use in training!\n",
"\n",
"def get_weight(item):\n",
" weight_str = item.features.get('Item Weight')\n",
" if weight_str:\n",
" parts = weight_str.split(' ')\n",
" amount = float(parts[0])\n",
" unit = parts[1].lower()\n",
" if unit==\"pounds\":\n",
" return amount\n",
" elif unit==\"ounces\":\n",
" return amount / 16\n",
" elif unit==\"grams\":\n",
" return amount / 453.592\n",
" elif unit==\"milligrams\":\n",
" return amount / 453592\n",
" elif unit==\"kilograms\":\n",
" return amount / 0.453592\n",
" elif unit==\"hundredths\" and parts[2].lower()==\"pounds\":\n",
" return amount / 100\n",
" else:\n",
" print(weight_str)\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4848b4a-3c5a-4168-83a5-57a1f3ff270d",
"metadata": {},
"outputs": [],
"source": [
"weights = [get_weight(t) for t in train]\n",
"weights = [w for w in weights if w]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0cd11cc8-f16e-4991-b531-482189ddc4b6",
"metadata": {},
"outputs": [],
"source": [
"average_weight = sum(weights)/len(weights)\n",
"average_weight"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efe8ec7f-9777-464f-a809-b06b7033bdb2",
"metadata": {},
"outputs": [],
"source": [
"def get_weight_with_default(item):\n",
" weight = get_weight(item)\n",
" return weight or average_weight"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2659fef-a455-431a-9a0e-59342b80084b",
"metadata": {},
"outputs": [],
"source": [
"def get_rank(item):\n",
" rank_dict = item.features.get(\"Best Sellers Rank\")\n",
" if rank_dict:\n",
" ranks = rank_dict.values()\n",
" return sum(ranks)/len(ranks)\n",
" return None"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20b9b5be-30bc-4d3a-8492-fbae119421a0",
"metadata": {},
"outputs": [],
"source": [
"ranks = [get_rank(t) for t in train]\n",
"ranks = [r for r in ranks if r]\n",
"average_rank = sum(ranks)/len(ranks)\n",
"average_rank"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "081e646a-ea50-4ec3-9512-6d5f96f8aef6",
"metadata": {},
"outputs": [],
"source": [
"def get_rank_with_default(item):\n",
" rank = get_rank(item)\n",
" return rank or average_rank"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afd5daf7-cb2b-47af-bf17-dd71a9db65d0",
"metadata": {},
"outputs": [],
"source": [
"def get_text_length(item):\n",
" return len(item.test_prompt())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85c89012-a922-401b-8a3b-94af641bf27a",
"metadata": {},
"outputs": [],
"source": [
"# investigate the brands\n",
"\n",
"brands = Counter()\n",
"for t in train:\n",
" brand = t.features.get(\"Brand\")\n",
" if brand:\n",
" brands[brand]+=1\n",
"\n",
"# Look at most common 40 brands\n",
"\n",
"brands.most_common(40)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "386dde54-e028-4a6d-b291-cce889ac1fa3",
"metadata": {},
"outputs": [],
"source": [
"TOP_ELECTRONICS_BRANDS = [\"hp\", \"dell\", \"lenovo\", \"samsung\", \"asus\", \"sony\", \"canon\", \"apple\", \"intel\"]\n",
"def is_top_electronics_brand(item):\n",
" brand = item.features.get(\"Brand\")\n",
" return brand and brand.lower() in TOP_ELECTRONICS_BRANDS"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c31c9c59-9d0d-47a8-a046-f20ed8d38d4c",
"metadata": {},
"outputs": [],
"source": [
"def get_features(item):\n",
" return {\n",
" \"weight\": get_weight_with_default(item),\n",
" \"rank\": get_rank_with_default(item),\n",
" \"text_length\": get_text_length(item),\n",
" \"is_top_electronics_brand\": 1 if is_top_electronics_brand(item) else 0\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88850855-f5bd-4be2-9d7c-75bf8a21609b",
"metadata": {},
"outputs": [],
"source": [
"# Look at features in a training item\n",
"get_features(train[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee9b5298-68b7-497d-8b2e-875287bb25b2",
"metadata": {},
"outputs": [],
"source": [
"# A utility function to convert our features into a pandas dataframe\n",
"\n",
"def list_to_dataframe(items):\n",
" features = [get_features(item) for item in items]\n",
" df = pd.DataFrame(features)\n",
" df['price'] = [item.price for item in items]\n",
" return df\n",
"\n",
"train_df = list_to_dataframe(train)\n",
"test_df = list_to_dataframe(test[:250])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc1d68e0-ab33-40f4-9334-461d426af25c",
"metadata": {},
"outputs": [],
"source": [
"# Traditional Linear Regression!\n",
"\n",
"np.random.seed(42)\n",
"\n",
"# Separate features and target\n",
"feature_columns = ['weight', 'rank', 'text_length', 'is_top_electronics_brand']\n",
"\n",
"X_train = train_df[feature_columns]\n",
"y_train = train_df['price']\n",
"X_test = test_df[feature_columns]\n",
"y_test = test_df['price']\n",
"\n",
"# Train a Linear Regression\n",
"model = LinearRegression()\n",
"model.fit(X_train, y_train)\n",
"\n",
"for feature, coef in zip(feature_columns, model.coef_):\n",
" print(f\"{feature}: {coef}\")\n",
"print(f\"Intercept: {model.intercept_}\")\n",
"\n",
"# Predict the test set and evaluate\n",
"y_pred = model.predict(X_test)\n",
"mse = mean_squared_error(y_test, y_pred)\n",
"r2 = r2_score(y_test, y_pred)\n",
"\n",
"print(f\"Mean Squared Error: {mse}\")\n",
"print(f\"R-squared Score: {r2}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6561c3c7-ac7f-458b-983c-4a164b9d02c3",
"metadata": {},
"outputs": [],
"source": [
"# Function to predict price for a new item\n",
"\n",
"def linear_regression_pricer(item):\n",
" features = get_features(item)\n",
" features_df = pd.DataFrame([features])\n",
" return model.predict(features_df)[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bf2caa4-657a-4fc6-9dcb-bed7eaf8dd65",
"metadata": {},
"outputs": [],
"source": [
"# test it\n",
"\n",
"Tester.test(linear_regression_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79e1574b-52ef-49cc-bfb5-e97252ed5db8",
"metadata": {},
"outputs": [],
"source": [
"# For the next few models, we prepare our documents and prices\n",
"# Note that we use the test prompt for the documents, otherwise we'll reveal the answer!!\n",
"\n",
"prices = np.array([float(item.price) for item in train])\n",
"documents = [item.test_prompt() for item in train]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e126c22e-53e7-4967-9ebb-6b7dd7fe4ade",
"metadata": {},
"outputs": [],
"source": [
"# Use the CountVectorizer for a Bag of Words model\n",
"\n",
"np.random.seed(42)\n",
"vectorizer = CountVectorizer(max_features=1000, stop_words='english')\n",
"X = vectorizer.fit_transform(documents)\n",
"regressor = LinearRegression()\n",
"regressor.fit(X, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b7148d3-3202-4536-a75c-1627495c51d3",
"metadata": {},
"outputs": [],
"source": [
"def bow_lr_pricer(item):\n",
" x = vectorizer.transform([item.test_prompt()])\n",
" return max(regressor.predict(x)[0], 0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38f7f7d0-d22c-4282-92e5-9666a7b8535d",
"metadata": {},
"outputs": [],
"source": [
"# test it\n",
"\n",
"Tester.test(bow_lr_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b623079e-54fa-418f-b209-7d54ebbcc23a",
"metadata": {},
"outputs": [],
"source": [
"# The amazing word2vec model, implemented in gensim NLP library\n",
"\n",
"np.random.seed(42)\n",
"\n",
"# Preprocess the documents\n",
"processed_docs = [simple_preprocess(doc) for doc in documents]\n",
"\n",
"# Train Word2Vec model\n",
"w2v_model = Word2Vec(sentences=processed_docs, vector_size=400, window=5, min_count=1, workers=8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3de4efc7-68a6-4443-b9fd-70ee9d722362",
"metadata": {},
"outputs": [],
"source": [
"# This step of averaging vectors across the document is a weakness in our approach\n",
"\n",
"def document_vector(doc):\n",
" doc_words = simple_preprocess(doc)\n",
" word_vectors = [w2v_model.wv[word] for word in doc_words if word in w2v_model.wv]\n",
" return np.mean(word_vectors, axis=0) if word_vectors else np.zeros(w2v_model.vector_size)\n",
"\n",
"# Create feature matrix\n",
"X_w2v = np.array([document_vector(doc) for doc in documents])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f05eeec-dab8-4007-8e8c-dcf4175b8861",
"metadata": {},
"outputs": [],
"source": [
"# Run Linear Regression on word2vec\n",
"\n",
"word2vec_lr_regressor = LinearRegression()\n",
"word2vec_lr_regressor.fit(X_w2v, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e43d3fb9-e013-4573-90bf-9a522132b555",
"metadata": {},
"outputs": [],
"source": [
"def word2vec_lr_pricer(item):\n",
" doc = item.test_prompt()\n",
" doc_vector = document_vector(doc)\n",
" return max(0, word2vec_lr_regressor.predict([doc_vector])[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6740319d-5c8e-4125-9106-97e2e8ab72c7",
"metadata": {},
"outputs": [],
"source": [
"Tester.test(word2vec_lr_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d6d3265-37c1-464c-a489-5be4df0a7276",
"metadata": {},
"outputs": [],
"source": [
"# Support Vector Machines\n",
"\n",
"np.random.seed(42)\n",
"svr_regressor = LinearSVR()\n",
"\n",
"svr_regressor.fit(X_w2v, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fcc289e6-56a1-4119-864f-2fdf8efde643",
"metadata": {},
"outputs": [],
"source": [
"def svr_pricer(item):\n",
" np.random.seed(42)\n",
" doc = item.test_prompt()\n",
" doc_vector = document_vector(doc)\n",
" return max(float(svr_regressor.predict([doc_vector])[0]),0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80286a48-7cca-40e6-af76-a814a23bb9dc",
"metadata": {},
"outputs": [],
"source": [
"Tester.test(svr_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6c44fe4-e4d9-4559-a8ed-d8f97e25b69f",
"metadata": {},
"outputs": [],
"source": [
"# And the powerful Random Forest regression\n",
"\n",
"rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=8)\n",
"rf_model.fit(X_w2v, prices)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a38812d0-913b-400b-804f-51434d895d05",
"metadata": {},
"outputs": [],
"source": [
"def random_forest_pricer(item):\n",
" doc = item.test_prompt()\n",
" doc_vector = document_vector(doc)\n",
" return max(0, rf_model.predict([doc_vector])[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88b51c01-c791-4fdc-8010-00b2e486b8ce",
"metadata": {},
"outputs": [],
"source": [
"Tester.test(random_forest_pricer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc85b271-4c92-480c-8843-2d7713b0fa57",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}