From the Udemy course on LLM engineering.
https://www.udemy.com/course/llm-engineering-master-ai-and-large-language-models
{
"cells": [
{
"cell_type": "markdown",
"id": "db8736a7-ed94-441c-9556-831fa57b5a10",
"metadata": {},
"source": [
"# The Product Pricer Continued...\n",
"\n",
"## Testing Gemini-1.5-pro model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "681c717b-4c24-4ac3-a5f3-3c5881d6e70a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"from dotenv import load_dotenv\n",
"import matplotlib.pyplot as plt\n",
"import pickle\n",
"import google.generativeai as google_genai\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21a3833e-4093-43b0-8f7b-839c50b911ea",
"metadata": {},
"outputs": [],
"source": [
"from items import Item\n",
"from testing import Tester"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36d05bdc-0155-4c72-a7ee-aa4e614ffd3c",
"metadata": {},
"outputs": [],
"source": [
"# environment\n",
"load_dotenv()\n",
"os.environ['GOOGLE_API_KEY'] = os.getenv('GOOGLE_API_KEY', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0a6fb86-74a4-403c-ab25-6db2d74e9d2b",
"metadata": {},
"outputs": [],
"source": [
"# Configure the Gemini client; called with no arguments, it picks up the GOOGLE_API_KEY environment variable set above\n",
"google_genai.configure()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c830ed3e-24ee-4af6-a07b-a1bfdcd39278",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c9b05f4-c9eb-462c-8d86-de9140a2d985",
"metadata": {},
"outputs": [],
"source": [
"# Load in the train and test pickle files\n",
"with open('train.pkl', 'rb') as file:\n",
"    train = pickle.load(file)\n",
"\n",
"with open('test.pkl', 'rb') as file:\n",
"    test = pickle.load(file)"
]
},
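{
"cell_type": "markdown",
"id": "added-sanity-check-note",
"metadata": {},
"source": [
"A quick sanity check added here (not part of the original notebook): confirm the pickles loaded and peek at the first test `Item`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check - confirm the datasets loaded and look at one item\n",
"print(f\"Loaded {len(train):,} training items and {len(test):,} test items\")\n",
"print(test[0].price)  # the same item whose price is predicted below"
]
},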
{
"cell_type": "code",
"execution_count": null,
"id": "fc5c807b-c14c-458e-8cca-32bc0cc5b7c3",
"metadata": {},
"outputs": [],
"source": [
"# Function to create the messages format required for Gemini 1.5 Pro\n",
"# This function prepares the system and user messages in the format expected by Gemini models.\n",
"def gemini_messages_for(item):\n",
"    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
"    \n",
"    # Modify the test prompt by removing \"to the nearest dollar\" and \"Price is $\"\n",
"    # This ensures that the model receives a cleaner, simpler prompt.\n",
"    user_prompt = item.test_prompt().replace(\" to the nearest dollar\", \"\").replace(\"\\n\\nPrice is $\", \"\")\n",
"\n",
"    # Return messages in Gemini's expected format: [{'role': ..., 'parts': [...]}]\n",
"    # Note: the 'system' entry is not sent as-is; the calling function extracts it\n",
"    # and passes it to the model as system_instruction.\n",
"    return [\n",
"        {\"role\": \"system\", \"parts\": [system_message]},   # System-level instruction\n",
"        {\"role\": \"user\", \"parts\": [user_prompt]},        # User's query\n",
"        {\"role\": \"model\", \"parts\": [\"Price is $\"]}       # Assistant's expected prefix for response\n",
"    ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6da66bb-bc4b-49ad-9224-a388470ef20b",
"metadata": {},
"outputs": [],
"source": [
"# Example usage of the gemini_messages_for function\n",
"gemini_messages_for(test[0])  # Generate message structure for the first test item"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1af1888-f94a-4106-b0d8-8a70939eec4e",
"metadata": {},
"outputs": [],
"source": [
"# Utility function to extract the numerical price from a given string\n",
"# This function removes currency symbols and commas, then extracts the first number found.\n",
"def get_price(s):\n",
"    s = s.replace('$', '').replace(',', '')  # Remove currency symbols and formatting\n",
"    match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", s)  # Regular expression to find a number\n",
"    return float(match.group()) if match else 0  # Convert matched value to float, return 0 if no match"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a053c1a9-f86e-427c-a6be-ed8ec7bd63a5",
"metadata": {},
"outputs": [],
"source": [
"# Example usage of get_price function\n",
"get_price(\"The price is roughly $99.99 because blah blah\")  # Expected output: 99.99"
]
},
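{
"cell_type": "markdown",
"id": "added-get-price-checks-note",
"metadata": {},
"source": [
"A few extra checks added for illustration (not part of the original notebook), showing how `get_price` handles commas, whole numbers, and strings with no number at all."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-get-price-checks-code",
"metadata": {},
"outputs": [],
"source": [
"# Added checks - expected values follow from the regex above\n",
"print(get_price(\"$1,234.56\"))           # 1234.56 - dollar sign and comma are stripped first\n",
"print(get_price(\"Around 50 dollars\"))   # 50.0    - falls back to the first integer found\n",
"print(get_price(\"no number here\"))      # 0       - returns 0 when nothing matches"
]
},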
{
"cell_type": "code",
"execution_count": null,
"id": "34a88e34-1719-4d08-adbe-adb69dfe5e83",
"metadata": {},
"outputs": [],
"source": [
"# Function to get the estimated price using Gemini 1.5 Pro\n",
"def gemini_1_point_5_pro(item):\n",
"    messages = gemini_messages_for(item)      # Generate messages for the model\n",
"    system_message = messages[0]['parts'][0]  # Extract system-level instruction\n",
"    user_messages = messages[1:]              # Remove system message from messages list\n",
"    \n",
"    # Initialize Gemini 1.5 Pro model with system instruction\n",
"    gemini = google_genai.GenerativeModel(\n",
"        model_name=\"gemini-1.5-pro\",\n",
"        system_instruction=system_message\n",
"    )\n",
"\n",
"    # Generate response using Gemini API\n",
"    response = gemini.generate_content(\n",
"        contents=user_messages,\n",
"        generation_config=google_genai.GenerationConfig(max_output_tokens=5)\n",
"    )\n",
"\n",
"    # Extract text response and convert to numerical price\n",
"    return get_price(response.text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d89b10bb-8ebb-42ef-9146-f6e64e6849f9",
"metadata": {},
"outputs": [],
"source": [
"# Example usage:\n",
"gemini_1_point_5_pro(test[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "89ad07e6-a28a-4625-b61e-d2ce12d440fc",
"metadata": {},
"outputs": [],
"source": [
"# Retrieve the actual price of the test item (for comparison)\n",
"test[0].price  # Output: 374.41"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "384f28e5-e51f-4cd3-8d74-30a8275530db",
"metadata": {},
"outputs": [],
"source": [
"# Test the function for gemini-1.5-pro using the Tester framework\n",
"Tester.test(gemini_1_point_5_pro, test)"
]
},
{
"cell_type": "markdown",
"id": "9b627291-b02e-48dd-9130-703498135ddf",
"metadata": {},
"source": [
"## Five: Gemini-2.0-flash"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ee393a9-7afd-404f-92f2-a64bb4d5fb8b",
"metadata": {},
"outputs": [],
"source": [
"# Function to get the estimated price using Gemini-2.0-flash-exp\n",
"def gemini_2_point_0_flash_exp(item):\n",
"    messages = gemini_messages_for(item)      # Generate messages for the model\n",
"    system_message = messages[0]['parts'][0]  # Extract system-level instruction\n",
"    user_messages = messages[1:]              # Remove system message from messages list\n",
"    \n",
"    # Initialize Gemini-2.0-flash-exp model with system instruction\n",
"    gemini = google_genai.GenerativeModel(\n",
"        model_name=\"gemini-2.0-flash-exp\",\n",
"        system_instruction=system_message\n",
"    )\n",
"\n",
"    # Adding a delay to avoid hitting the API rate limit and getting a \"ResourceExhausted: 429\" error\n",
"    time.sleep(5)\n",
"    \n",
"    # Generate response using Gemini API\n",
"    response = gemini.generate_content(\n",
"        contents=user_messages,\n",
"        generation_config=google_genai.GenerationConfig(max_output_tokens=5)\n",
"    )\n",
"\n",
"    # Extract text response and convert to numerical price\n",
"    return get_price(response.text)"
]
},
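{
"cell_type": "markdown",
"id": "added-backoff-note",
"metadata": {},
"source": [
"The fixed `time.sleep(5)` above is a simple way to stay under the rate limit. Below is a rough alternative sketch, added here and not part of the original notebook: retry with exponential backoff when a rate-limit error is raised. It assumes the 429 surfaces as `google.api_core.exceptions.ResourceExhausted`, matching the error message mentioned in the comment above; adjust the exception type if your client raises something different."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-backoff-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sketch only (not from the course): retry generate_content with exponential backoff.\n",
"# Assumes rate-limit errors are raised as google.api_core.exceptions.ResourceExhausted.\n",
"from google.api_core import exceptions as google_exceptions\n",
"\n",
"def generate_with_backoff(model, contents, max_retries=5):\n",
"    delay = 2  # initial wait in seconds\n",
"    for attempt in range(max_retries):\n",
"        try:\n",
"            return model.generate_content(\n",
"                contents=contents,\n",
"                generation_config=google_genai.GenerationConfig(max_output_tokens=5)\n",
"            )\n",
"        except google_exceptions.ResourceExhausted:\n",
"            if attempt == max_retries - 1:\n",
"                raise  # give up after the final attempt\n",
"            time.sleep(delay)\n",
"            delay *= 2  # double the wait before retrying"
]
},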
{
"cell_type": "code",
"execution_count": null,
"id": "203dc6f1-309e-46eb-9957-e06eed803cc8",
"metadata": {},
"outputs": [],
"source": [
"# Example usage:\n",
"gemini_2_point_0_flash_exp(test[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a844df09-d347-40b9-bb79-006ec4160aab",
"metadata": {},
"outputs": [],
"source": [
"# Retrieve the actual price of the test item (for comparison)\n",
"test[0].price  # Output: 374.41"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "500b45c7-e5c1-44f2-95c9-1c3c06365339",
"metadata": {},
"outputs": [],
"source": [
"# Test the function for gemini-2.0-flash-exp using the Tester framework\n",
"Tester.test(gemini_2_point_0_flash_exp, test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "746b2d12-ba92-48e2-9065-c9a108d1593b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|
|
|