From 851fb42c2f4731e2e9adf6f5ddd2f843cf6f5813 Mon Sep 17 00:00:00 2001 From: Sanjay Semwal Date: Mon, 2 Dec 2024 13:31:19 -0800 Subject: [PATCH] added flght status using tools: --- week2/day4.ipynb | 289 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 274 insertions(+), 15 deletions(-) diff --git a/week2/day4.ipynb b/week2/day4.ipynb index 06c3904..a3f1a5e 100644 --- a/week2/day4.ipynb +++ b/week2/day4.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4", "metadata": {}, "outputs": [], @@ -28,10 +28,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key exists and begins sk-proj-\n" + ] + } + ], "source": [ "# Initialization\n", "\n", @@ -49,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "0a521d84-d07c-49ab-a0df-d6451499ed97", "metadata": {}, "outputs": [], @@ -61,10 +69,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "61a2a15d-b559-4844-b377-6bd5cb4949f6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7861\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# This function looks rather simpler than the one from my video, because we're taking advantage of the latest Gradio updates\n", "\n", @@ -94,7 +132,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "0696acb1-0b05-4dc2-80d5-771be04f1fb2", "metadata": {}, "outputs": [], @@ -111,17 +149,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "80ca4e09-6287-4d3f-997d-fa6afbcf6c85", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tool get_ticket_price called for Berlin\n" + ] + }, + { + "data": { + "text/plain": [ + "'$499'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "get_ticket_price(\"Berlin\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "4afceded-7178-4c05-8fa6-9f2085e6a344", "metadata": {}, "outputs": [], @@ -147,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "bdca8679-935f-4e7f-97e6-e71a4d4f228c", "metadata": {}, "outputs": [], @@ -173,7 +229,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "id": "ce9b0744-9c78-408d-b9df-9f6fd9ed78cf", "metadata": {}, "outputs": [], @@ -194,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "b0992986-ea09-4912-a076-8e5603ee631f", "metadata": {}, "outputs": [], @@ -216,10 +272,213 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "f4be8a71-b19e-4c2f-80df-f59ff2661f14", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7862\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gr.ChatInterface(fn=chat, type=\"messages\").launch()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "11c9da69-d0cf-4cf2-a49e-e5669deec47b", + "metadata": {}, "outputs": [], + "source": [ + "#Assignment -- adding tool to book a flight\n", + "\n", + "flight_booking_status = {\"london\": \"booked\", \"paris\": \"cannot be booked\", \"tokyo\": \"pending\", \"berlin\": \"payment issue\" }\n", + "\n", + "#this is a tool function that will be used by a GPT model\n", + "\n", + "def get_flight_booking_status(destination_city):\n", + " print(f\"Tool get_flight_booking_status for {destination_city}\")\n", + " city = destination_city.lower()\n", + " return flight_booking_status.get(city, \"Unknown\")\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "073207fb-a161-49e0-9ca8-60ccf588aa72", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tool get_flight_booking_status for berlin\n" + ] + }, + { + "data": { + "text/plain": [ + "'payment issue'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#test the tool \n", + "get_flight_booking_status(\"berlin\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "76bc29f6-1bfc-4c4b-afb0-1b678b5b02d4", + "metadata": {}, + "outputs": [], + "source": [ + "# Describe metadata about the above tool function so that model can understand it\n", + "\n", + "flight_booking_status_function = {\n", + " \"name\": \"get_flight_booking_status\",\n", + " \"description\": \"Get the status to flight booking. Show your answer in polite way, like 'flight booked or cannotn be booked etc...'\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"destination_city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city that the customer wants to book a flight for\",\n", + " },\n", + " },\n", + " \"required\": [\"destination_city\"],\n", + " \"additionalProperties\": False\n", + " }\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "221e2406-f19c-4ca6-a384-9876904be1b5", + "metadata": {}, + "outputs": [], + "source": [ + "#add the above booking status function to the tools list\n", + "tools = [{\"type\": \"function\", \"function\": flight_booking_status_function}]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c067cd26-562d-487c-9d32-e2e8cdd17796", + "metadata": {}, + "outputs": [], + "source": [ + "#now let openAI use our tool to find the flight booking status\n", + "\n", + "def chat(message, history):\n", + " messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n", + "\n", + " if response.choices[0].finish_reason==\"tool_calls\":\n", + " message = response.choices[0].message\n", + " response, city = handle_flight_status_tool_call(message)\n", + " messages.append(message)\n", + " messages.append(response)\n", + " response = openai.chat.completions.create(model=MODEL, messages=messages)\n", + " \n", + " return response.choices[0].message.content\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "5bb6313b-d84c-4a77-b2f1-efd17554a553", + "metadata": {}, + "outputs": [], + "source": [ + "#function to handle call to our tool\n", + "\n", + "# We have to write that function handle_tool_call:\n", + "\n", + "def handle_flight_status_tool_call(message):\n", + " tool_call = message.tool_calls[0]\n", + " arguments = json.loads(tool_call.function.arguments)\n", + " city = arguments.get('destination_city')\n", + " status = get_flight_booking_status(city)\n", + " response = {\n", + " \"role\": \"tool\",\n", + " \"content\": json.dumps({\"destination_city\": city,\"status\": status}),\n", + " \"tool_call_id\": message.tool_calls[0].id\n", + " }\n", + " return response, city" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "b184d72c-da31-43ad-b7bf-27cf4525b35b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7863\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "gr.ChatInterface(fn=chat, type=\"messages\").launch()" ] @@ -227,7 +486,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11c9da69-d0cf-4cf2-a49e-e5669deec47b", + "id": "648a35d3-0365-48ce-958e-7531a98f9b8a", "metadata": {}, "outputs": [], "source": []