{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ae6f09ff-a7f8-411f-b3e5-9ae502af35c6",
   "metadata": {},
   "source": [
    "### To do:\n",
    "\n",
    "- [x] get some extra **Python practice**\n",
    "- [x] get lots of **tool practice**\n",
    "- [x] increase your **Gradio proficiency**\n",
    "- [x] try **picture output** (on command only!)\n",
    "- [x] try **audio output** (on command only?)\n",
    "- [ ] try **audio input?** (most importantly, do anything you saw in the video; the rest is nice-to-have)\n",
    "\n",
    "_Extra: delve into Claude's function calling documentation (can be left for later)_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e0cd9c-6622-4fa9-a4f8-b3da1b9b836e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n",
    "import random\n",
    "import re\n",
    "import base64\n",
    "from io import BytesIO\n",
    "from PIL import Image\n",
    "from IPython.display import Audio, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57fc95b9-043c-4a38-83aa-365cc3b285ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv()\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key exists and begins with {openai_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"OpenAI API Key? As if!\")\n",
    "\n",
    "MODEL = \"gpt-4o-mini\"\n",
    "openai = OpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e633ee2a-bbaa-47a4-95ef-b1d8773866aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n",
    "system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n",
    "system_message += \"Always be accurate. If you don't know the answer, say so. \"\n",
    "system_message += \"You can book flights directly. \"\n",
    "system_message += \"You can generate beautiful artistic renditions of the cities we fly to.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c123af78-b5d6-4cc9-8f18-c492b1f30c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ticket price function\n",
    "\n",
    "ticket_prices = {\"valletta\": \"799 $\", \"turin\": \"899 $\", \"sacramento\": \"1400 $\", \"montreal\": \"499 $\"} #awkward currency for better tts rendition\n",
    "\n",
    "def get_ticket_price(destination_city):\n",
    "    \"\"\"Return the return-ticket price for destination_city, or 'Unknown' if it is not in ticket_prices.\n",
    "\n",
    "    The city name comes from the model, so the lookup ignores letter case and surrounding whitespace.\n",
    "    \"\"\"\n",
    "    print(f\"Tool get_ticket_price called for {destination_city}\")\n",
    "    city = destination_city.strip().lower()\n",
    "    return ticket_prices.get(city, \"Unknown\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e486fb-709e-4b8e-a029-9e2b225ddc25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# travel booking function\n",
    "\n",
    "# A booking code is three 2-character groups, each group drawn from its own alphabet.\n",
    "# check_code validates codes against these same constants, so the two cannot drift apart.\n",
    "# NOTE(review): the second alphabet has no '5' - confirm that this is intentional.\n",
    "BOOKING_CODE_ALPHABETS = (\"0123456789BCDFXYZ\", \"012346789HIJKLMNOPQRS\", \"0123456789GHIJKLMNUOP\")\n",
    "BOOKING_CODE_GROUP_LENGTH = 2\n",
    "\n",
    "def book_flight(destination_city):\n",
    "    \"\"\"Simulate booking a flight to destination_city and return a random booking code.\"\"\"\n",
    "    booking_code = \"\".join(\n",
    "        random.choice(alphabet)\n",
    "        for alphabet in BOOKING_CODE_ALPHABETS\n",
    "        for _ in range(BOOKING_CODE_GROUP_LENGTH)\n",
    "    )\n",
    "    print(f\"Booking code {booking_code} generated for flight to {destination_city}.\")\n",
    "    return booking_code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0600b4e-fa4e-4c34-b317-fac1e60b5f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# verify if booking code is correct\n",
    "\n",
    "# Built from the alphabets that book_flight draws from, e.g. \"[0123456789BCDFXYZ]{2}...\".\n",
    "BOOKING_CODE_PATTERN = re.compile(\n",
    "    \"\".join(f\"[{alphabet}]{{{BOOKING_CODE_GROUP_LENGTH}}}\" for alphabet in BOOKING_CODE_ALPHABETS)\n",
    ")\n",
    "\n",
    "def check_code(code):\n",
    "    \"\"\"Return True if code has the format produced by book_flight, False otherwise.\n",
    "\n",
    "    Surrounding whitespace and letter case are ignored. fullmatch is used because\n",
    "    re.match with '$' would also accept a code followed by a trailing newline.\n",
    "    \"\"\"\n",
    "    is_valid = BOOKING_CODE_PATTERN.fullmatch(code.strip().upper()) is not None\n",
    "    status = \"valid\" if is_valid else \"not valid\"\n",
    "    print(f\"Code checker called for code {code}, which is {status}.\")\n",
    "    return is_valid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1d1b1c2-089c-41e5-b1bd-900632271093",
   "metadata": {},
   "outputs": [],
   "source": [
    "# make a nice preview of the travel destination\n",
    "\n",
    "def artist(city):\n",
    "    \"\"\"Generate a DALL-E 3 collage of city, save it to img001.png and return it as a PIL Image.\"\"\"\n",
    "    image_response = openai.images.generate(\n",
    "        model=\"dall-e-3\",\n",
    "        prompt=f\"Make an image in the style of a vibrant, artistically filtered photo that is a collage of the best sights and views in {city}.\",\n",
    "        size=\"1024x1024\",\n",
    "        n=1,\n",
    "        response_format=\"b64_json\",\n",
    "    )\n",
    "    image_base64 = image_response.data[0].b64_json\n",
    "    image_data = base64.b64decode(image_base64)\n",
    "    img = Image.open(BytesIO(image_data))\n",
    "\n",
    "    img.save(\"img001.png\") #make them 4 cents count! .save is from PIL library, btw\n",
    "\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "626d99af-90de-4594-9ffd-b87a8b6ef4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "price_function = {\n",
    "    \"name\": \"get_ticket_price\",\n",
    "    \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to travel to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e7bc09c-665b-4885-823c-f145cefe8c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "booking_function = {\n",
    "    \"name\": \"book_flight\",\n",
    "    \"description\": \"Call this whenever you have to book a flight. Give it the destination city and you will get a booking code. Tell the customer \\\n",
    "that the flight is booked and give them the booking code obtained through this function. Never give any other codes to the customer.\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"destination_city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city that the customer wants to book their flight to\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"destination_city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc365d87-fed2-41ff-9232-850fdce1cff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "artist_function = {\n",
    "    \"name\": \"artist\",\n",
    "    \"description\": \"Call this whenever you need to generate a picture, photo, or graphic impression of a city.\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"city\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city of which an image is to be generated\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"city\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99b0a0e3-db44-49f9-8d27-349b9f04c680",
   "metadata": {},
   "outputs": [],
   "source": [
    "codecheck_function = {\n",
    "    \"name\": \"check_code\",\n",
    "    \"description\": \"Call this whenever you need to verify if a booking code for a flight (also called 'flight code', 'booking reference', \\\n",
    "or variations thereof) is valid.\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"code\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The code that you or the user needs to verify\",\n",
    "            },\n",
    "        },\n",
    "        \"required\": [\"code\"],\n",
    "        \"additionalProperties\": False\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa371c4-91ff-41ae-9b10-23fe617022d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List of tools:\n",
    "\n",
    "tools = [{\"type\": \"function\", \"function\": price_function}, {\"type\": \"function\", \"function\": booking_function}, {\"type\": \"function\", \"function\": codecheck_function}, {\"type\": \"function\", \"function\": artist_function}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d34942a-f0c7-4835-ba07-746104a8c524",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(history):\n",
    "    \"\"\"Answer the latest user message in history, running any tools the model asks for.\n",
    "\n",
    "    Returns (history, image): history with the assistant's reply appended, and the image\n",
    "    produced by the artist tool (None if no image was generated).\n",
    "    \"\"\"\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history\n",
    "    response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
    "    image = None\n",
    "\n",
    "    if response.choices[0].finish_reason == \"tool_calls\":\n",
    "        message = response.choices[0].message\n",
    "        # Run the tools exactly once: each run has side effects and costs (a fresh random\n",
    "        # booking code, a paid DALL-E image). Calling handle_tool_call twice would tell the\n",
    "        # model one booking code while discarding another, and pay for two images.\n",
    "        tool_responses, image = handle_tool_call(message)\n",
    "        messages.append(message)\n",
    "        messages.extend(tool_responses)\n",
    "        response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
    "\n",
    "    reply = response.choices[0].message.content\n",
    "\n",
    "    #talker(reply) #current cost: $0.015 per 1000 characters (not tokens!)\n",
    "\n",
    "    history += [{\"role\": \"assistant\", \"content\": reply}]\n",
    "\n",
    "    return history, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5413f7fb-c5f7-44c4-a63d-3d0465eb0af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_call(message):\n",
    "    \"\"\"Run every tool call requested in an assistant message.\n",
    "\n",
    "    Returns (responses, image): a list with one \"tool\" message per tool call, ready to be\n",
    "    appended to the conversation, and the image produced by the artist tool (None if the\n",
    "    artist tool was not called).\n",
    "    \"\"\"\n",
    "    responses = []\n",
    "    image = None\n",
    "\n",
    "    for tool_call in message.tool_calls:\n",
    "        arguments = json.loads(tool_call.function.arguments)\n",
    "        indata = arguments[list(arguments.keys())[0]] # works for now because we only have one argument in each of our functions\n",
    "        function_name = tool_call.function.name\n",
    "        if function_name == 'get_ticket_price':\n",
    "            outdata = get_ticket_price(indata)\n",
    "            input_name = \"destination city\"\n",
    "            output_name = \"price\"\n",
    "        elif function_name == 'book_flight':\n",
    "            outdata = book_flight(indata)\n",
    "            input_name = \"destination city\"\n",
    "            output_name = \"booking code\"\n",
    "        elif function_name == \"check_code\":\n",
    "            outdata = check_code(indata)\n",
    "            input_name = \"booking code\"\n",
    "            output_name = \"validity\"\n",
    "        elif function_name == \"artist\":\n",
    "            image = artist(indata)\n",
    "            outdata = f\"artistic rendition of {indata}\"\n",
    "            input_name = \"city\"\n",
    "            output_name = \"image\"\n",
    "        else:\n",
    "            # Unknown tool name: without this branch the names below would be undefined\n",
    "            # (NameError) or silently reuse the previous iteration's values. Every\n",
    "            # tool_call_id still needs a tool message, so report the problem to the model.\n",
    "            outdata = f\"unknown tool: {function_name}\"\n",
    "            input_name = \"input\"\n",
    "            output_name = \"error\"\n",
    "\n",
    "        responses.append({\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": json.dumps({input_name: indata, output_name: outdata}),\n",
    "            \"tool_call_id\": tool_call.id\n",
    "        })\n",
    "\n",
    "    return responses, image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "505b585e-e9f9-4326-8455-184398bc82d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def talker(message):\n",
    "    \"\"\"Convert message to speech with OpenAI TTS, save it as output_audio.mp3 and autoplay it.\n",
    "\n",
    "    NOTE(review): display() targets the notebook output area, so this presumably does not\n",
    "    play inside the Gradio UI - confirm before wiring it into chat().\n",
    "    \"\"\"\n",
    "    response = openai.audio.speech.create(\n",
    "        model=\"tts-1\",\n",
    "        voice=\"onyx\",\n",
    "        input=message)\n",
    "\n",
    "    audio_stream = BytesIO(response.content)\n",
    "    output_filename = \"output_audio.mp3\"\n",
    "    with open(output_filename, \"wb\") as f:\n",
    "        f.write(audio_stream.read())\n",
    "\n",
    "    # Play the generated audio\n",
    "    display(Audio(output_filename, autoplay=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a31bcf-71d5-4537-a7bf-92385dc6e26e",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Gradio\n",
    "\n",
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        chatbot = gr.Chatbot(height=500, type=\"messages\")\n",
    "        image_output = gr.Image(height=500)\n",
    "    with gr.Row():\n",
    "        entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n",
    "    with gr.Row():\n",
    "        clear = gr.Button(\"Clear\")\n",
    "\n",
    "    def do_entry(message, history):\n",
    "        history += [{\"role\":\"user\", \"content\":message}]\n",
    "        return \"\", history\n",
    "\n",
    "    entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n",
    "        chat, inputs=chatbot, outputs=[chatbot, image_output]\n",
    "    )\n",
    "    # Clear both the conversation and the last generated image.\n",
    "    clear.click(lambda: (None, None), inputs=None, outputs=[chatbot, image_output], queue=False)\n",
    "\n",
    "ui.launch(inbrowser=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}