{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "df2fc552-2c56-45bd-ac4e-d1554c022605",
   "metadata": {},
   "source": [
    "# Project - Airline AI Assistant\n",
    "I've added MySQL database connectivity so that the OpenAI model can:\n",
    "- Retrieve ticket prices\n",
    "- Report the number of available seats for each destination\n",
    "- List all available destination cities\n",
    "- Book seats for passengers\n",
    "\n",
    "Once a booking is confirmed, an AI-generated image of the booked destination city is displayed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "908cb842-c8a1-467d-8422-8834f8b7aecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import base64\n",
    "import json\n",
    "import os\n",
    "from contextlib import closing\n",
    "from io import BytesIO\n",
    "\n",
    "import gradio as gr\n",
    "import mysql.connector\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "from PIL import Image\n",
    "from pydub import AudioSegment\n",
    "from pydub.playback import play"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7997c30-26f2-4f2e-957f-c1fade2ad101",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization: load settings from .env (e.g. OPENAI_API_KEY and the DB_* values) into os.environ\n",
    "load_dotenv()\n",
    "\n",
    "# Confirm the key is present while echoing only its first 8 characters\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if not openai_api_key:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "else:\n",
    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
    "\n",
    "# Chat model used for every completion; the client picks up OPENAI_API_KEY from the environment\n",
    "MODEL = \"gpt-4o-mini\"\n",
    "openai = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfa898fc-bfec-44ce-81fc-c6efed9b826f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persona and rules for the assistant. Every fragment except the last ends with a space so the\n",
    "# concatenated sentences do not run together.\n",
    "# chat() looks for the exact phrase 'Booking confirmed' in replies to decide when to draw the city.\n",
    "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
    "system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
    "system_message += \"Always be accurate. If you don't know the answer, say so. \"\n",
    "system_message += \"Make sure you ask if they want to book a flight when appropriate. \"\n",
    "system_message += \"If they book a flight make sure you respond with 'Booking confirmed' in your reply.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07076d5b-2603-4fa4-a2ed-aa95d4a94131",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_db_connection():\n",
    "    \"\"\"Open and return a new MySQL connection.\n",
    "\n",
    "    Host, user, password and database name are read from the DB_HOST, DB_USER,\n",
    "    DB_PASSWORD and DB_NAME environment variables. The caller must close the\n",
    "    returned connection.\n",
    "    \"\"\"\n",
    "    settings = {\n",
    "        \"host\": os.getenv(\"DB_HOST\"),\n",
    "        \"user\": os.getenv(\"DB_USER\"),\n",
    "        \"password\": os.getenv(\"DB_PASSWORD\"),\n",
    "        \"database\": os.getenv(\"DB_NAME\"),\n",
    "    }\n",
    "    return mysql.connector.connect(**settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a575906-943f-4733-85d4-b854eb27b318",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_ticket_price(destination_city):\n",
    "    \"\"\"Return the ticket price to `destination_city` as a float, or None if unknown.\n",
    "\n",
    "    Looks the city up in `flights.z_city`. If several flights serve the same city, the\n",
    "    one with the lowest flight_number is used so the answer is deterministic.\n",
    "    Returns None when no flight matches or its stored price is NULL.\n",
    "    \"\"\"\n",
    "    # ORDER BY + LIMIT 1 picks one row deterministically and leaves no unread rows behind;\n",
    "    # mysql.connector can raise 'Unread result found' when a cursor with pending rows is closed.\n",
    "    select_query = \"SELECT price FROM flights WHERE z_city = %s ORDER BY flight_number LIMIT 1;\"\n",
    "    # closing() releases the cursor and connection even if the query raises\n",
    "    with closing(get_db_connection()) as db_connection, closing(db_connection.cursor()) as cursor:\n",
    "        cursor.execute(select_query, (destination_city,))\n",
    "        row = cursor.fetchone()\n",
    "\n",
    "    if row is None or row[0] is None:\n",
    "        return None\n",
    "    # Convert the (presumably DECIMAL) price so the tool reply can be JSON-serialized\n",
    "    return float(row[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "574fc230-137f-4085-93ac-ebbd01dc7d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_avail_seats(destination_city):\n",
    "    \"\"\"Return the number of unbooked seats on a flight to `destination_city`, or None.\n",
    "\n",
    "    Available seats = `flights.seats` minus the number of `bookings` rows for that\n",
    "    flight_number. If several flights serve the city, the lowest flight_number is\n",
    "    reported. Returns None when no flight to the city exists.\n",
    "    \"\"\"\n",
    "    # ORDER BY + LIMIT 1: deterministic row choice, and no unread rows left on the cursor\n",
    "    # (mysql.connector can raise 'Unread result found' when closing a cursor with pending rows).\n",
    "    select_query = \"\"\"\n",
    "        SELECT f.seats - COALESCE(b.booked, 0) AS available\n",
    "        FROM flights f\n",
    "        LEFT JOIN (\n",
    "            SELECT flight_number, COUNT(*) AS booked\n",
    "            FROM bookings\n",
    "            GROUP BY flight_number\n",
    "        ) b ON f.flight_number = b.flight_number\n",
    "        WHERE f.z_city = %s\n",
    "        ORDER BY f.flight_number\n",
    "        LIMIT 1;\n",
    "    \"\"\"\n",
    "    # closing() releases the cursor and connection even if the query raises\n",
    "    with closing(get_db_connection()) as db_connection, closing(db_connection.cursor()) as cursor:\n",
    "        cursor.execute(select_query, (destination_city,))\n",
    "        row = cursor.fetchone()\n",
    "\n",
    "    return row[0] if row else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ff9b4b-2943-43d9-8c1a-8d7f3d528143",
   "metadata": {},
   "outputs": [],
   "source": [
    "def book_seat(destination_city, passenger):\n",
    "    \"\"\"Book a seat for `passenger` on a flight to `destination_city`.\n",
    "\n",
    "    Picks the lowest-numbered flight to the city that still has unbooked seats\n",
    "    (`flights.seats` minus existing `bookings` rows) and inserts a booking for it.\n",
    "\n",
    "    Returns:\n",
    "        On success: {\"message\": \"Booking confirmed for ...\", \"flight_number\": <flight>}.\n",
    "        Otherwise {\"error\": \"No available flights to <city>.\"} when no flight to the\n",
    "        city has a free seat.\n",
    "    \"\"\"\n",
    "    # Only consider flights that are not sold out, so bookings cannot exceed capacity.\n",
    "    # LIMIT 1 also ensures no unread rows remain on the cursor before the INSERT.\n",
    "    select_query = \"\"\"\n",
    "        SELECT f.flight_number\n",
    "        FROM flights f\n",
    "        LEFT JOIN (\n",
    "            SELECT flight_number, COUNT(*) AS booked\n",
    "            FROM bookings\n",
    "            GROUP BY flight_number\n",
    "        ) b ON f.flight_number = b.flight_number\n",
    "        WHERE f.z_city = %s AND f.seats - COALESCE(b.booked, 0) > 0\n",
    "        ORDER BY f.flight_number\n",
    "        LIMIT 1;\n",
    "    \"\"\"\n",
    "    insert_query = \"INSERT INTO bookings (`name`, `flight_number`) VALUES (%s, %s);\"\n",
    "\n",
    "    # closing() releases the cursor and connection on every exit path, including the early return\n",
    "    with closing(get_db_connection()) as db_connection, closing(db_connection.cursor()) as cursor:\n",
    "        cursor.execute(select_query, (destination_city,))\n",
    "        flight = cursor.fetchone()\n",
    "        if not flight:\n",
    "            return {\"error\": f\"No available flights to {destination_city}.\"}\n",
    "\n",
    "        flight_number = flight[0]\n",
    "        # NOTE(review): check-then-insert is not atomic; concurrent sessions could still oversell\n",
    "        # the last seat. A transaction with SELECT ... FOR UPDATE would close that gap.\n",
    "        cursor.execute(insert_query, (passenger, flight_number))\n",
    "        db_connection.commit()\n",
    "\n",
    "    return {\n",
    "        \"message\": f\"Booking confirmed for {passenger} to {destination_city}.\",\n",
    "        \"flight_number\": flight_number\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "231eb10d-88ca-4f39-83e0-c4548149917e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_destinations():\n",
    "    \"\"\"Return the distinct destination cities in `flights.z_city`, ordered by name.\n",
    "\n",
    "    Returns an empty list when the flights table has no rows.\n",
    "    \"\"\"\n",
    "    # ORDER BY gives the model a stable list to read back to the customer\n",
    "    select_query = \"SELECT DISTINCT z_city FROM flights ORDER BY z_city;\"\n",
    "    # closing() releases the cursor and connection even if the query raises\n",
    "    with closing(get_db_connection()) as db_connection, closing(db_connection.cursor()) as cursor:\n",
    "        cursor.execute(select_query)\n",
    "        rows = cursor.fetchall()\n",
    "    return [row[0] for row in rows]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "938f0d86-8cef-4f7f-bc82-7453ca3c096c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI function-calling schemas for the database helpers; passed as `tools=` in chat().\n",
    "# Each \"name\" must match a branch in handle_tool_call, and each description must match what\n",
    "# the Python function actually returns, because the model relies on it when choosing a tool.\n",
    "tool_call = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_ticket_price\",\n",
    "            \"description\": \"Get the price of a return ticket to the destination city.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"destination_city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city that the customer wants to travel to\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"destination_city\"]\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_avail_seats\",\n",
    "            \"description\": \"Get the number of available seats to the destination city.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"destination_city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city that the customer wants to travel to\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"destination_city\"]\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_destinations\",\n",
    "            \"description\": \"List all destination cities that FlightAI flies to.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {},\n",
    "                \"required\": []\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"book_seat\",\n",
    "            \"description\": \"Book a seat on a flight to the destination city for the named passenger.\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"destination_city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city that the customer wants to travel to\"\n",
    "                    },\n",
    "                    \"passenger\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The name of the passenger the seat is booked for\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"destination_city\",\"passenger\"]\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7c02377-78d3-4f6d-88eb-d36c0124fdd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_call(message):\n",
    "    \"\"\"Run the first tool call in an assistant message and build the matching tool reply.\n",
    "\n",
    "    Args:\n",
    "        message: assistant message from a chat completion that requested a tool.\n",
    "\n",
    "    Returns:\n",
    "        (tool_message, city): tool_message is a {\"role\": \"tool\", ...} dict answering the\n",
    "        call's id with JSON content {\"destination_city\": city, <result key>: result};\n",
    "        city is the call's \"destination_city\" argument (None if it had none).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the message has no tool calls, names an unknown function, or\n",
    "            calls book_seat without a passenger name.\n",
    "\n",
    "    Only message.tool_calls[0] is executed; any further calls in the message are ignored.\n",
    "    \"\"\"\n",
    "    if not message.tool_calls:\n",
    "        raise ValueError(\"No tool calls found in the message.\")\n",
    "\n",
    "    first_call = message.tool_calls[0]\n",
    "    args = json.loads(first_call.function.arguments)\n",
    "    city = args.get(\"destination_city\")\n",
    "    fn_name = first_call.function.name\n",
    "\n",
    "    if fn_name == \"book_seat\":\n",
    "        # Booking is the only tool with a second required argument\n",
    "        passenger = args.get(\"passenger\")\n",
    "        if not passenger:\n",
    "            raise ValueError(\"Passenger name is required for booking.\")\n",
    "        result_key, result = \"booking\", book_seat(city, passenger)\n",
    "    else:\n",
    "        # Lookup tools: function name -> (key in the reply JSON, zero-arg callable running the query)\n",
    "        lookups = {\n",
    "            \"get_ticket_price\": (\"price\", lambda: get_ticket_price(city)),\n",
    "            \"get_avail_seats\": (\"seats\", lambda: get_avail_seats(city)),\n",
    "            \"get_destinations\": (\"destinations\", get_destinations),\n",
    "        }\n",
    "        if fn_name not in lookups:\n",
    "            raise ValueError(f\"Unknown function: {fn_name}\")\n",
    "        result_key, run_lookup = lookups[fn_name]\n",
    "        result = run_lookup()\n",
    "\n",
    "    content = json.dumps({\"destination_city\": city, result_key: result})\n",
    "    tool_message = {\"role\": \"tool\", \"content\": content, \"tool_call_id\": first_call.id}\n",
    "    return tool_message, city"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1ebaee-434c-4b24-87b9-3c179d0527c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def talker(message):\n",
    "    \"\"\"Speak `message` aloud on the machine running this notebook.\n",
    "\n",
    "    Synthesizes MP3 speech with OpenAI's \"tts-1\" model (\"alloy\" voice), decodes it in\n",
    "    memory with pydub, and plays it through the local audio output.\n",
    "    \"\"\"\n",
    "    speech = openai.audio.speech.create(model=\"tts-1\", voice=\"alloy\", input=message)\n",
    "    mp3_buffer = BytesIO(speech.content)\n",
    "    play(AudioSegment.from_file(mp3_buffer, format=\"mp3\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c8f675b-f8bb-4173-9e47-24508778f224",
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_city(city):\n",
    "    \"\"\"Generate a pop-art vacation image of `city` with DALL-E 3 and return it as a PIL Image.\n",
    "\n",
    "    A single 1024x1024 image is requested as base64 JSON and decoded in memory.\n",
    "    \"\"\"\n",
    "    prompt = (\n",
    "        f\"An image representing a vacation in {city}, showing tourist spots and everything unique about {city}, \"\n",
    "        \"in a vibrant pop-art style\"\n",
    "    )\n",
    "    generated = openai.images.generate(\n",
    "        model=\"dall-e-3\",\n",
    "        prompt=prompt,\n",
    "        size=\"1024x1024\",\n",
    "        n=1,\n",
    "        response_format=\"b64_json\",\n",
    "    )\n",
    "    image_bytes = base64.b64decode(generated.data[0].b64_json)\n",
    "    return Image.open(BytesIO(image_bytes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f38fed6-bcd9-4ad2-848a-16193c14a659",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history):\n",
    "    \"\"\"Gradio submit handler: answer the user's message, running a tool if the model asks.\n",
    "\n",
    "    Args:\n",
    "        message: text the user just submitted.\n",
    "        history: messages-format chat history (list of {\"role\", \"content\"} dicts);\n",
    "            it is extended in place with the user turn and the assistant reply.\n",
    "\n",
    "    Returns:\n",
    "        (\"\", history, image): clears the textbox, refreshes the chatbot, and sets the image\n",
    "        panel to a generated picture of the destination after a confirmed booking (else None).\n",
    "\n",
    "    When a tool is used, its result is sent back to the model for a final answer, which is\n",
    "    also spoken aloud via talker().\n",
    "    \"\"\"\n",
    "    history.append({\"role\": \"user\", \"content\": message})\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
    "\n",
    "    # handle_tool_call answers only the first tool call, and the API rejects the follow-up\n",
    "    # request unless every tool_call_id has a tool reply -- so allow at most one call per turn.\n",
    "    response = openai.chat.completions.create(\n",
    "        model=MODEL, messages=messages, tools=tool_call, parallel_tool_calls=False\n",
    "    )\n",
    "    image = None\n",
    "    city = None\n",
    "    used_tool = response.choices[0].finish_reason == \"tool_calls\"\n",
    "\n",
    "    if used_tool:\n",
    "        assistant_message = response.choices[0].message\n",
    "        tool_message, city = handle_tool_call(assistant_message)\n",
    "        messages.append(assistant_message)\n",
    "        messages.append(tool_message)\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "\n",
    "    # content may be None (no text in the reply); normalize so the checks below cannot crash\n",
    "    reply = response.choices[0].message.content or \"\"\n",
    "\n",
    "    if used_tool and reply:\n",
    "        talker(reply)\n",
    "\n",
    "    # system_message instructs the model to say 'Booking confirmed' once a booking is made\n",
    "    if \"Booking confirmed\" in reply and city:\n",
    "        image = draw_city(city)\n",
    "\n",
    "    history.append({\"role\": \"assistant\", \"content\": reply})\n",
    "    return \"\", history, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "992bc241-ce17-4d57-9f9c-1baaf2088162",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layout: chat transcript beside the destination image, then the input box and a Clear button\n",
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        chat_window = gr.Chatbot(height=500, type=\"messages\")\n",
    "        city_image = gr.Image(height=600)\n",
    "    with gr.Row():\n",
    "        message_box = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
    "    with gr.Row():\n",
    "        clear_button = gr.Button(\"Clear\")\n",
    "\n",
    "    def reset_conversation():\n",
    "        \"\"\"Empty the chat transcript and remove the displayed image.\"\"\"\n",
    "        return [], None\n",
    "\n",
    "    # Submitting the textbox runs chat(); its (text, history, image) result updates these three components\n",
    "    message_box.submit(chat, inputs=[message_box, chat_window], outputs=[message_box, chat_window, city_image])\n",
    "    clear_button.click(reset_conversation, inputs=None, outputs=[chat_window, city_image], queue=False)\n",
    "\n",
    "ui.launch(inbrowser=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}