{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "db8736a7-ed94-441c-9556-831fa57b5a10",
   "metadata": {},
   "source": [
    "# The Product Pricer Continued...\n",
    "\n",
    "## Testing Gemini-1.5-pro model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "681c717b-4c24-4ac3-a5f3-3c5881d6e70a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from dotenv import load_dotenv\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import google.generativeai as google_genai\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a3833e-4093-43b0-8f7b-839c50b911ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from items import Item\n",
    "from testing import Tester "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d05bdc-0155-4c72-a7ee-aa4e614ffd3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# environment\n",
    "# Read variables via load_dotenv(), then make sure GOOGLE_API_KEY exists in the\n",
    "# process environment, using a placeholder value when it is not already defined.\n",
    "load_dotenv()\n",
    "if 'GOOGLE_API_KEY' not in os.environ:\n",
    "    os.environ['GOOGLE_API_KEY'] = 'your-key-if-not-using-env'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a6fb86-74a4-403c-ab25-6db2d74e9d2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the google.generativeai SDK. No api_key argument is passed here;\n",
    "# presumably the SDK picks up GOOGLE_API_KEY from the environment set in the\n",
    "# previous cell -- confirm against the SDK documentation.\n",
    "google_genai.configure()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c830ed3e-24ee-4af6-a07b-a1bfdcd39278",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render matplotlib figures inline in the notebook output.\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9b05f4-c9eb-462c-8d86-de9140a2d985",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled train and test datasets.\n",
    "# NOTE(review): both paths are relative to the notebook's working directory; if the files live in\n",
    "# the `pickled_dataset` folder, run the notebook from there or adjust the paths.\n",
    "# pickle.load can execute arbitrary code while unpickling, so only open files you trust.\n",
    "with open('train.pkl', 'rb') as file:\n",
    "    train = pickle.load(file)\n",
    "\n",
    "with open('test.pkl', 'rb') as file:\n",
    "    test = pickle.load(file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc5c807b-c14c-458e-8cca-32bc0cc5b7c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the message list used for the Gemini price queries in this notebook.\n",
    "def gemini_messages_for(item):\n",
    "    \"\"\"Build the system / user / model-prefix message list for one item.\n",
    "\n",
    "    Args:\n",
    "        item: object whose ``test_prompt()`` method returns the prompt text.\n",
    "\n",
    "    Returns:\n",
    "        A list of three ``{'role': ..., 'parts': [...]}`` dicts, in the order\n",
    "        system, user, model. The model entry carries the answer prefix.\n",
    "    \"\"\"\n",
    "    # Drop the rounding hint and the trailing answer stub from the test prompt,\n",
    "    # in that order, so the user turn contains only the question and description.\n",
    "    prompt_text = item.test_prompt()\n",
    "    for fragment in (\" to the nearest dollar\", \"\\n\\nPrice is $\"):\n",
    "        prompt_text = prompt_text.replace(fragment, \"\")\n",
    "\n",
    "    # (role, text) pairs in conversation order; each becomes one message dict.\n",
    "    turns = (\n",
    "        (\"system\", \"You estimate prices of items. Reply only with the price, no explanation\"),\n",
    "        (\"user\", prompt_text),\n",
    "        (\"model\", \"Price is $\"),\n",
    "    )\n",
    "    return [{\"role\": role, \"parts\": [text]} for role, text in turns]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6da66bb-bc4b-49ad-9224-a388470ef20b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example usage of the gemini_messages_for function.\n",
    "# As the last expression in the cell, the returned message list is shown as the cell output.\n",
    "gemini_messages_for(test[0])  # Generate message structure for the first test item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1af1888-f94a-4106-b0d8-8a70939eec4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility function to pull a numeric price out of free-form text\n",
    "def get_price(s):\n",
    "    \"\"\"Return the first number found in `s` as a float, or 0 when there is none.\n",
    "\n",
    "    Dollar signs and commas are dropped before searching, so \"$1,234\" is read as 1234.0.\n",
    "    \"\"\"\n",
    "    cleaned = s.replace('$', '').replace(',', '')\n",
    "    found = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", cleaned)\n",
    "    if found is None:\n",
    "        return 0\n",
    "    return float(found.group())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a053c1a9-f86e-427c-a6be-ed8ec7bd63a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example usage of get_price function: the '$' is stripped and the first number in the text is returned.\n",
    "get_price(\"The price is roughly $99.99 because blah blah\")  # Expected output: 99.99"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a88e34-1719-4d08-adbe-adb69dfe5e83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to get the estimated price using Gemini 1.5 Pro\n",
    "def gemini_1_point_5_pro(item):\n",
    "    \"\"\"Estimate the price of `item` with the gemini-1.5-pro model.\n",
    "\n",
    "    Args:\n",
    "        item: object accepted by gemini_messages_for.\n",
    "\n",
    "    Returns:\n",
    "        The number get_price parses from the model's reply. When the reply has\n",
    "        no usable text, get_price's no-match value (0) is returned.\n",
    "    \"\"\"\n",
    "    messages = gemini_messages_for(item)  # Generate messages for the model\n",
    "    # The system text is passed to the model constructor, so it is split off\n",
    "    # from the conversation turns sent as `contents`.\n",
    "    system_message = messages[0]['parts'][0]\n",
    "    user_messages = messages[1:]\n",
    "\n",
    "    # Initialize Gemini 1.5 Pro model with system instruction\n",
    "    gemini = google_genai.GenerativeModel(\n",
    "        model_name=\"gemini-1.5-pro\",\n",
    "        system_instruction=system_message\n",
    "    )\n",
    "\n",
    "    # Generate response using Gemini API; the reply is capped at 5 output tokens\n",
    "    response = gemini.generate_content(\n",
    "        contents=user_messages,\n",
    "        generation_config=google_genai.GenerationConfig(max_output_tokens=5)\n",
    "    )\n",
    "\n",
    "    # `response.text` raises ValueError when the response holds no text part\n",
    "    # (for example a blocked candidate). Treat that like an unparseable reply\n",
    "    # so a single bad item does not abort a whole evaluation run.\n",
    "    try:\n",
    "        reply_text = response.text\n",
    "    except ValueError:\n",
    "        reply_text = \"\"\n",
    "\n",
    "    # Extract text response and convert to numerical price\n",
    "    return get_price(reply_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d89b10bb-8ebb-42ef-9146-f6e64e6849f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example usage: a single live call to the Gemini API for the first test item.\n",
    "# Requires the SDK to have been configured with a valid API key in the earlier cells.\n",
    "gemini_1_point_5_pro(test[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89ad07e6-a28a-4625-b61e-d2ce12d440fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve the actual price of the test item (for comparison with the estimate above).\n",
    "# The value in the inline comment was noted by the author; this notebook is saved without outputs.\n",
    "test[0].price    # Output: 374.41"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "384f28e5-e51f-4cd3-8d74-30a8275530db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the function for gemini-1.5 pro using the Tester framework.\n",
    "# NOTE(review): Tester comes from the local `testing` module, which is not shown here;\n",
    "# presumably it calls the function on items from `test` and reports the error -- confirm there.\n",
    "# Every call to gemini_1_point_5_pro makes one live API request.\n",
    "Tester.test(gemini_1_point_5_pro, test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b627291-b02e-48dd-9130-703498135ddf",
   "metadata": {},
   "source": [
    "## Testing Gemini-2.0-flash-exp model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee393a9-7afd-404f-92f2-a64bb4d5fb8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to get the estimated price using Gemini-2.0-flash-exp\n",
    "def gemini_2_point_0_flash_exp(item):\n",
    "    \"\"\"Estimate the price of `item` with the gemini-2.0-flash-exp model.\n",
    "\n",
    "    Sleeps 5 seconds before each request to stay under the API rate limit.\n",
    "\n",
    "    Args:\n",
    "        item: object accepted by gemini_messages_for.\n",
    "\n",
    "    Returns:\n",
    "        The number get_price parses from the model's reply. When the reply has\n",
    "        no usable text, get_price's no-match value (0) is returned.\n",
    "    \"\"\"\n",
    "    messages = gemini_messages_for(item)  # Generate messages for the model\n",
    "    # The system text is passed to the model constructor, so it is split off\n",
    "    # from the conversation turns sent as `contents`.\n",
    "    system_message = messages[0]['parts'][0]\n",
    "    user_messages = messages[1:]\n",
    "\n",
    "    # Initialize Gemini-2.0-flash-exp model with system instruction\n",
    "    gemini = google_genai.GenerativeModel(\n",
    "        model_name=\"gemini-2.0-flash-exp\",\n",
    "        system_instruction=system_message\n",
    "    )\n",
    "\n",
    "    # Adding a delay to avoid hitting the API rate limit and getting a \"ResourceExhausted: 429\" error\n",
    "    time.sleep(5)\n",
    "\n",
    "    # Generate response using Gemini API; the reply is capped at 5 output tokens\n",
    "    response = gemini.generate_content(\n",
    "        contents=user_messages,\n",
    "        generation_config=google_genai.GenerationConfig(max_output_tokens=5)\n",
    "    )\n",
    "\n",
    "    # `response.text` raises ValueError when the response holds no text part\n",
    "    # (for example a blocked candidate). Treat that like an unparseable reply\n",
    "    # so a single bad item does not abort a whole evaluation run.\n",
    "    try:\n",
    "        reply_text = response.text\n",
    "    except ValueError:\n",
    "        reply_text = \"\"\n",
    "\n",
    "    # Extract text response and convert to numerical price\n",
    "    return get_price(reply_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "203dc6f1-309e-46eb-9957-e06eed803cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example usage: a single live call to the Gemini API for the first test item.\n",
    "# The function sleeps 5 seconds before sending the request, so this cell takes at least that long.\n",
    "gemini_2_point_0_flash_exp(test[0])  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a844df09-d347-40b9-bb79-006ec4160aab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve the actual price of the test item (for comparison with the estimate above).\n",
    "# The value in the inline comment was noted by the author; this notebook is saved without outputs.\n",
    "test[0].price    # Output: 374.41"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "500b45c7-e5c1-44f2-95c9-1c3c06365339",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the function for gemini-2.0-flash-exp using the Tester framework.\n",
    "# NOTE(review): Tester comes from the local `testing` module, which is not shown here;\n",
    "# presumably it calls the function on items from `test` and reports the error -- confirm there.\n",
    "# Each call to gemini_2_point_0_flash_exp sleeps 5 seconds, so a run over many items is slow.\n",
    "Tester.test(gemini_2_point_0_flash_exp, test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "746b2d12-ba92-48e2-9065-c9a108d1593b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}