1 changed files with 313 additions and 0 deletions
@ -0,0 +1,313 @@
|
||||
{ |
||||
"cells": [ |
||||
{ |
||||
"cell_type": "markdown", |
||||
"id": "ae6f09ff-a7f8-411f-b3e5-9ae502af35c6", |
||||
"metadata": {}, |
||||
"source": [ |
||||
"### To do:\n", |
||||
"\n", |
||||
"- [x] get some extra **Python practice**\n", |
||||
"- [x] get lots of **tool practice**\n", |
||||
"- [x] increase your **Gradio proficiency**\n", |
||||
"- [ ] try **picture output** (on command only!)\n", |
||||
"- [ ] try **audio output** (on command only?)\n", |
||||
"- [ ] try **audio input?** (most importantly, do anything you saw in the video; the rest is nice-to-have)\n", |
||||
"_Extra: delve into Claude's function calling documentation (can be left for later)_" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "36e0cd9c-6622-4fa9-a4f8-b3da1b9b836e", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"import os\n", |
||||
"import json\n", |
||||
"from dotenv import load_dotenv\n", |
||||
"from openai import OpenAI\n", |
||||
"import gradio as gr\n", |
||||
"import random\n", |
||||
"import re" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "57fc95b9-043c-4a38-83aa-365cc3b285ba", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"load_dotenv()\n", |
||||
"\n", |
||||
"openai_api_key = os.getenv('OPENAI_API_KEY')\n", |
||||
"if openai_api_key:\n", |
||||
" print(f\"OpenAI API Key exists and begins with {openai_api_key[:8]}\")\n", |
||||
"else:\n", |
||||
" print(\"OpenAI API Key? As if!\")\n", |
||||
" \n", |
||||
"MODEL = \"gpt-4o-mini\"\n", |
||||
"openai = OpenAI()" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "e633ee2a-bbaa-47a4-95ef-b1d8773866aa", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"system_message = \"You are a helpful assistant for an Airline called FlightAI. \"\n", |
||||
"system_message += \"Give short, courteous answers, no more than 1 sentence. \"\n", |
||||
"system_message += \"Always be accurate. If you don't know the answer, say so. \"" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "c123af78-b5d6-4cc9-8f18-c492b1f30c85", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"# ticket price function\n", |
||||
"\n", |
||||
"ticket_prices = {\"valletta\": \"799 $\", \"turin\": \"899 $\", \"sacramento\": \"1400 $\", \"montreal\": \"499 $\"} #awkward currency for better tts rendition\n", |
||||
"\n", |
||||
"def get_ticket_price(destination_city):\n", |
||||
" print(f\"Tool get_ticket_price called for {destination_city}\")\n", |
||||
" city = destination_city.lower()\n", |
||||
" return ticket_prices.get(city, \"Unknown\")" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "00e486fb-709e-4b8e-a029-9e2b225ddc25", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"# travel booking function\n", |
||||
"\n", |
||||
"def book_flight(destination_city):\n", |
||||
" booking_code = ''.join(random.choice('0123456789BCDFXYZ') for i in range(2)) + ''.join(random.choice('012346789HIJKLMNOPQRS') for i in range(2)) + ''.join(random.choice('0123456789GHIJKLMNUOP') for i in range(2))\n", |
||||
" print(f\"Booking code {booking_code} generated for flight to {destination_city}.\")\n", |
||||
" \n", |
||||
" return booking_code" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "c0600b4e-fa4e-4c34-b317-fac1e60b5f95", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"# verify if booking code is correct\n", |
||||
"\n", |
||||
"def check_code(code):\n", |
||||
" valid = \"valid\" if re.match(\"^[0123456789BCDFXYZ]{2}[012346789HIJKLMNOPQRS]{2}[0123456789GHIJKLMNUOP]{2}$\", code) != None else \"not valid\"\n", |
||||
" println(f\"Code checker called for code {code}, which is {valid}.\")\n", |
||||
" return re.match(\"^[0123456789BCDFXYZ]{2}[012346789HIJKLMNOPQRS]{2}[0123456789GHIJKLMNUOP]{2}$\", code) != None" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "626d99af-90de-4594-9ffd-b87a8b6ef4fd", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"price_function = {\n", |
||||
" \"name\": \"get_ticket_price\",\n", |
||||
" \"description\": \"Get the price of a return ticket to the destination city. Call this whenever you need to know the ticket price, for example when a customer asks 'How much is a ticket to this city'\",\n", |
||||
" \"parameters\": {\n", |
||||
" \"type\": \"object\",\n", |
||||
" \"properties\": {\n", |
||||
" \"destination_city\": {\n", |
||||
" \"type\": \"string\",\n", |
||||
" \"description\": \"The city that the customer wants to travel to\",\n", |
||||
" },\n", |
||||
" },\n", |
||||
" \"required\": [\"destination_city\"],\n", |
||||
" \"additionalProperties\": False\n", |
||||
" }\n", |
||||
"}" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "6e7bc09c-665b-4885-823c-f145cefe8c23", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"booking_function = {\n", |
||||
" \"name\": \"book_flight\",\n", |
||||
" \"description\": \"Call this whenever you have to book a flight. Give it the destination city and you will get a booking code. Tell the customer \\\n", |
||||
"that the flight is booked and give them the booking code obtained through this function. Never give any other codes to the customer.\",\n", |
||||
" \"parameters\": {\n", |
||||
" \"type\": \"object\",\n", |
||||
" \"properties\": {\n", |
||||
" \"destination_city\": {\n", |
||||
" \"type\": \"string\",\n", |
||||
" \"description\": \"The city that the customer wants to book their flight to\",\n", |
||||
" },\n", |
||||
" },\n", |
||||
" \"required\": [\"destination_city\"],\n", |
||||
" \"additionalProperties\": False\n", |
||||
" }\n", |
||||
"}" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "99b0a0e3-db44-49f9-8d27-349b9f04c680", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"codecheck_function = {\n", |
||||
" \"name\": \"check_code\",\n", |
||||
" \"description\": \"Call this whenever you need to verify if a booking code for a flight (also called 'flight code', 'booking reference', \\\n", |
||||
"or variations thereof) is valid.\",\n", |
||||
" \"parameters\": {\n", |
||||
" \"type\": \"object\",\n", |
||||
" \"properties\": {\n", |
||||
" \"code\": {\n", |
||||
" \"type\": \"string\",\n", |
||||
" \"description\": \"The code that you or the user needs to verify\",\n", |
||||
" },\n", |
||||
" },\n", |
||||
" \"required\": [\"code\"],\n", |
||||
" \"additionalProperties\": False\n", |
||||
" }\n", |
||||
"}" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "3fa371c4-91ff-41ae-9b10-23fe617022d1", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"# List of tools:\n", |
||||
"\n", |
||||
"tools = [{\"type\": \"function\", \"function\": price_function}, {\"type\": \"function\", \"function\": booking_function}, {\"type\": \"function\", \"function\": codecheck_function}]" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "4d34942a-f0c7-4835-ba07-746104a8c524", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"def chat(history):\n", |
||||
" messages = [{\"role\": \"system\", \"content\": system_message}] + history\n", |
||||
" response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", |
||||
" image = None\n", |
||||
" \n", |
||||
" if response.choices[0].finish_reason==\"tool_calls\":\n", |
||||
" message = response.choices[0].message\n", |
||||
" responses = handle_tool_call(message)\n", |
||||
" messages.append(message)\n", |
||||
" for response in responses:\n", |
||||
" messages.append(response)\n", |
||||
" response = openai.chat.completions.create(model=MODEL, messages=messages)\n", |
||||
" \n", |
||||
" reply = response.choices[0].message.content\n", |
||||
" history += [{\"role\": \"assistant\", \"content\": reply}]\n", |
||||
" \n", |
||||
" return history, image" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "5413f7fb-c5f7-44c4-a63d-3d0465eb0af4", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"def handle_tool_call(message):\n", |
||||
" responses = []\n", |
||||
" for tool_call in message.tool_calls:\n", |
||||
" arguments = json.loads(tool_call.function.arguments)\n", |
||||
" indata = arguments[list(arguments.keys())[0]] # works for now because we only have one argument in each of our functions\n", |
||||
" function_name = tool_call.function.name\n", |
||||
" if function_name == 'get_ticket_price':\n", |
||||
" outdata = get_ticket_price(indata)\n", |
||||
" input_name = \"destination city\"\n", |
||||
" output_name = \"price\"\n", |
||||
" elif function_name == 'book_travel':\n", |
||||
" outdata = book_flight(indata)\n", |
||||
" input_name = \"destination city\"\n", |
||||
" output_name = \"booking code\"\n", |
||||
" else:\n", |
||||
" outdata = check_code(indata)\n", |
||||
" input_name = \"booking code\"\n", |
||||
" output_name = \"validity\"\n", |
||||
" \n", |
||||
" responses.append({\n", |
||||
" \"role\": \"tool\",\n", |
||||
" \"content\": json.dumps({input_name: indata, output_name: outdata}),\n", |
||||
" \"tool_call_id\": tool_call.id\n", |
||||
" })\n", |
||||
" return responses" |
||||
] |
||||
}, |
||||
{ |
||||
"cell_type": "code", |
||||
"execution_count": null, |
||||
"id": "a5a31bcf-71d5-4537-a7bf-92385dc6e26e", |
||||
"metadata": {}, |
||||
"outputs": [], |
||||
"source": [ |
||||
"# More involved Gradio code as we're not using the preset Chat interface!\n", |
||||
"# Passing in inbrowser=True in the last line will cause a Gradio window to pop up immediately.\n", |
||||
"\n", |
||||
"with gr.Blocks() as ui:\n", |
||||
" with gr.Row():\n", |
||||
" chatbot = gr.Chatbot(height=500, type=\"messages\")\n", |
||||
" image_output = gr.Image(height=500)\n", |
||||
" with gr.Row():\n", |
||||
" entry = gr.Textbox(label=\"Chat with our AI Assistant:\")\n", |
||||
" with gr.Row():\n", |
||||
" clear = gr.Button(\"Clear\")\n", |
||||
"\n", |
||||
" def do_entry(message, history):\n", |
||||
" history += [{\"role\":\"user\", \"content\":message}]\n", |
||||
" return \"\", history\n", |
||||
"\n", |
||||
" entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry, chatbot]).then(\n", |
||||
" chat, inputs=chatbot, outputs=[chatbot, image_output]\n", |
||||
" )\n", |
||||
" clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n", |
||||
"\n", |
||||
"ui.launch(inbrowser=True)" |
||||
] |
||||
} |
||||
], |
||||
"metadata": { |
||||
"kernelspec": { |
||||
"display_name": "Python 3 (ipykernel)", |
||||
"language": "python", |
||||
"name": "python3" |
||||
}, |
||||
"language_info": { |
||||
"codemirror_mode": { |
||||
"name": "ipython", |
||||
"version": 3 |
||||
}, |
||||
"file_extension": ".py", |
||||
"mimetype": "text/x-python", |
||||
"name": "python", |
||||
"nbconvert_exporter": "python", |
||||
"pygments_lexer": "ipython3", |
||||
"version": "3.11.11" |
||||
} |
||||
}, |
||||
"nbformat": 4, |
||||
"nbformat_minor": 5 |
||||
} |
Loading…
Reference in new issue