diff --git a/week2/community-contributions/day5-book-flight.ipynb b/week2/community-contributions/day5-book-flight.ipynb new file mode 100644 index 0000000..00bf022 --- /dev/null +++ b/week2/community-contributions/day5-book-flight.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "df2fc552-2c56-45bd-ac4e-d1554c022605", + "metadata": {}, + "source": [ + "# Project - Airline AI Assistant\n", + "I've added database connectivity to enable Openai to:\n", + "- Retrieve ticket prices\n", + "- Display the number of available seats for each flight\n", + "- List all available destination cities\n", + "- Facilitate seat bookings\n", + "\n", + "Once a booking is confirmed, an image of the booked destination city is displayed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "908cb842-c8a1-467d-8422-8834f8b7aecf", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import os\n", + "import json\n", + "import gradio as gr\n", + "import mysql.connector\n", + "import base64\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "from io import BytesIO\n", + "from pydub import AudioSegment\n", + "from pydub.playback import play\n", + "from PIL import Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7997c30-26f2-4f2e-957f-c1fade2ad101", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialization\n", + "load_dotenv()\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "MODEL = \"gpt-4o-mini\"\n", + "openai = OpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfa898fc-bfec-44ce-81fc-c6efed9b826f", + "metadata": {}, + "outputs": [], + "source": [ + "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n", + "system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n", + "system_message += \"Always be accurate. If you don't know the answer, say so.\"\n", + "system_message += \"Make sure you ask if they want to book a flight when appropriate.\"\n", + "system_message += \"If they book a flight make sure you respond with 'Booking confirmed' in your reply.\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07076d5b-2603-4fa4-a2ed-aa95d4a94131", + "metadata": {}, + "outputs": [], + "source": [ + "def get_db_connection():\n", + " return mysql.connector.connect(\n", + " host=os.getenv(\"DB_HOST\"),\n", + " user=os.getenv(\"DB_USER\"),\n", + " password=os.getenv(\"DB_PASSWORD\"),\n", + " database=os.getenv(\"DB_NAME\")\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a575906-943f-4733-85d4-b854eb27b318", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def get_ticket_price(destination_city):\n", + " db_connection = get_db_connection()\n", + " cursor = db_connection.cursor()\n", + " select_query = \"SELECT price FROM flights WHERE z_city = %s;\"\n", + " cursor.execute(select_query, (destination_city,))\n", + " # print(f\"QUERY: {select_query}\")\n", + " row = cursor.fetchone()\n", + " cursor.close()\n", + " db_connection.close()\n", + "\n", + " return float(row[0]) if row else None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "574fc230-137f-4085-93ac-ebbd01dc7d1e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_avail_seats(destination_city):\n", + " db_connection = get_db_connection()\n", + " cursor = db_connection.cursor()\n", + " select_query = \"\"\"\n", + " SELECT f.seats - COALESCE(b.booked, 0) AS available\n", + " FROM flights f\n", + " LEFT JOIN (\n", + " SELECT flight_number, COUNT(*) AS booked\n", + " FROM bookings\n", + " GROUP BY flight_number\n", + " ) b ON f.flight_number = b.flight_number\n", + " WHERE f.z_city = %s;\n", + " \"\"\"\n", + " cursor.execute(select_query, (destination_city,))\n", + " row = cursor.fetchone()\n", + "\n", + " cursor.close()\n", + " db_connection.close()\n", + "\n", + " return row[0] if row else None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26ff9b4b-2943-43d9-8c1a-8d7f3d528143", + "metadata": {}, + "outputs": [], + "source": [ + "def book_seat(destination_city, passenger):\n", + " db_connection = get_db_connection()\n", + " cursor = db_connection.cursor()\n", + "\n", + " cursor.execute(\"SELECT flight_number FROM flights WHERE z_city = %s LIMIT 1;\", (destination_city,))\n", + " flight = cursor.fetchone()\n", + "\n", + " if not flight:\n", + " cursor.close()\n", + " db_connection.close()\n", + " return {\"error\": f\"No available flights to {destination_city}.\"}\n", + "\n", + " flight_number = flight[0] # Extract the flight number from the result\n", + "\n", + " insert_query = \"INSERT INTO bookings (`name`, `flight_number`) VALUES (%s, %s);\"\n", + " cursor.execute(insert_query, (passenger, flight_number))\n", + " db_connection.commit()\n", + "\n", + " confirmation = {\n", + " \"message\": f\"Booking confirmed for {passenger} to {destination_city}.\",\n", + " \"flight_number\": flight_number\n", + " }\n", + "\n", + " cursor.close()\n", + " db_connection.close()\n", + " \n", + " return confirmation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "231eb10d-88ca-4f39-83e0-c4548149917e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_destinations():\n", + " db_connection = get_db_connection()\n", + " cursor = db_connection.cursor()\n", + " \n", + " select_query = \"SELECT DISTINCT z_city FROM flights;\" # Ensure unique destinations\n", + " cursor.execute(select_query)\n", + " rows = cursor.fetchall() # Fetch all rows\n", + " destinations = [row[0] for row in rows] if rows else [] # Extract city names\n", + " cursor.close()\n", + " db_connection.close()\n", + " \n", + " return destinations # Returns a list of destination cities" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "938f0d86-8cef-4f7f-bc82-7453ca3c096c", + "metadata": {}, + "outputs": [], + "source": [ + "tool_call = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_ticket_price\",\n", + " \"description\": \"Get the price of a return ticket to the destination city.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"destination_city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city that the customer wants to travel to\"\n", + " }\n", + " },\n", + " \"required\": [\"destination_city\"]\n", + " }\n", + " }\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_avail_seats\",\n", + " \"description\": \"Get the number of available seats to the destination city.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"destination_city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city that the customer wants to travel to\"\n", + " }\n", + " },\n", + " \"required\": [\"destination_city\"]\n", + " }\n", + " }\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_destinations\",\n", + " \"description\": \"Fetches available flight destinations (city pairs) and their corresponding prices.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {},\n", + " \"required\": []\n", + " }\n", + " }\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"book_seat\",\n", + " \"description\": \"Book seat to the destination city.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"destination_city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city that the customer wants to travel to\"\n", + " },\n", + " \"passenger\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The passenger booking the flight\"\n", + " }\n", + " },\n", + " \"required\": [\"destination_city\",\"passenger\"]\n", + " }\n", + " }\n", + " }\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7c02377-78d3-4f6d-88eb-d36c0124fdd4", + "metadata": {}, + "outputs": [], + "source": [ + "def handle_tool_call(message):\n", + " if not message.tool_calls:\n", + " raise ValueError(\"No tool calls found in the message.\")\n", + "\n", + " tool_call = message.tool_calls[0] \n", + " arguments = json.loads(tool_call.function.arguments)\n", + " city = arguments.get(\"destination_city\")\n", + " function_name = tool_call.function.name\n", + "\n", + " # Handle function calls\n", + " if function_name == \"get_ticket_price\":\n", + " reply = get_ticket_price(city)\n", + " key = \"price\"\n", + " elif function_name == \"get_avail_seats\":\n", + " reply = get_avail_seats(city)\n", + " key = \"seats\"\n", + " elif function_name == \"get_destinations\":\n", + " reply = get_destinations()\n", + " key = \"destinations\"\n", + " elif function_name == \"book_seat\":\n", + " passenger = arguments.get(\"passenger\") # Extract passenger name\n", + " if not passenger:\n", + " raise ValueError(\"Passenger name is required for booking.\")\n", + " reply = book_seat(city, passenger)\n", + " key = \"booking\"\n", + " else:\n", + " raise ValueError(f\"Unknown function: {function_name}\")\n", + "\n", + " response = {\n", + " \"role\": \"tool\",\n", + " \"content\": json.dumps({\"destination_city\": city, key: reply}),\n", + " \"tool_call_id\": tool_call.id\n", + " }\n", + "\n", + " return response, city" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb1ebaee-434c-4b24-87b9-3c179d0527c7", + "metadata": {}, + "outputs": [], + "source": [ + "def talker(message):\n", + " response = openai.audio.speech.create(\n", + " model=\"tts-1\",\n", + " voice=\"alloy\",\n", + " input=message\n", + " )\n", + " \n", + " audio_stream = BytesIO(response.content)\n", + " audio = AudioSegment.from_file(audio_stream, format=\"mp3\")\n", + " play(audio)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c8f675b-f8bb-4173-9e47-24508778f224", + "metadata": {}, + "outputs": [], + "source": [ + "def draw_city(city):\n", + " image_response = openai.images.generate(\n", + " model=\"dall-e-3\",\n", + " prompt=f\"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, in a vibrant pop-art style\",\n", + " size=\"1024x1024\",\n", + " n=1,\n", + " response_format=\"b64_json\",\n", + " )\n", + " image_base64 = image_response.data[0].b64_json\n", + " image_data = base64.b64decode(image_base64)\n", + " return Image.open(BytesIO(image_data))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f38fed6-bcd9-4ad2-848a-16193c14a659", + "metadata": {}, + "outputs": [], + "source": [ + "def chat(message, history):\n", + " history.append({\"role\": \"user\", \"content\": message})\n", + " messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", + " # print(f\"BEFORE TOOL CALL: {message} \\n\")\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tool_call)\n", + " image = None\n", + " city = None\n", + " \n", + " if response.choices[0].finish_reason == \"tool_calls\":\n", + " tool_message = response.choices[0].message\n", + " response, city = handle_tool_call(tool_message)\n", + " messages.append(tool_message)\n", + " messages.append(response)\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages)\n", + " talker(response.choices[0].message.content) \n", + " \n", + " if \"Booking confirmed\" in response.choices[0].message.content and city:\n", + " image = draw_city(city)\n", + "\n", + " new_message = response.choices[0].message.content\n", + " history.append({\"role\": \"assistant\", \"content\": new_message})\n", + "\n", + " return \"\", history, image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "992bc241-ce17-4d57-9f9c-1baaf2088162", + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=500, type=\"messages\")\n", + " image_output = gr.Image(height=600)\n", + " with gr.Row():\n", + " entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n", + " with gr.Row():\n", + " clear = gr.Button(\"Clear\")\n", + "\n", + " entry.submit(chat, inputs=[entry, chatbot], outputs=[entry, chatbot, image_output])\n", + " clear.click(lambda: ([], None), inputs=None, outputs=[chatbot, image_output], queue=False)\n", + "\n", + "ui.launch(inbrowser=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}