diff --git a/week2/day4.ipynb b/week2/day4.ipynb
index 06c3904..a3f1a5e 100644
--- a/week2/day4.ipynb
+++ b/week2/day4.ipynb
@@ -12,7 +12,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"id": "8b50bbe2-c0b1-49c3-9a5c-1ba7efa2bcb4",
"metadata": {},
"outputs": [],
@@ -28,10 +28,18 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"id": "747e8786-9da8-4342-b6c9-f5f69c2e22ae",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "OpenAI API Key exists and begins sk-proj-\n"
+ ]
+ }
+ ],
"source": [
"# Initialization\n",
"\n",
@@ -49,7 +57,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"id": "0a521d84-d07c-49ab-a0df-d6451499ed97",
"metadata": {},
"outputs": [],
@@ -61,10 +69,40 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 14,
"id": "61a2a15d-b559-4844-b377-6bd5cb4949f6",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7861\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# This function looks rather simpler than the one from my video, because we're taking advantage of the latest Gradio updates\n",
"\n",
@@ -94,7 +132,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 15,
"id": "0696acb1-0b05-4dc2-80d5-771be04f1fb2",
"metadata": {},
"outputs": [],
@@ -111,17 +149,35 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"id": "80ca4e09-6287-4d3f-997d-fa6afbcf6c85",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tool get_ticket_price called for Berlin\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'$499'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"get_ticket_price(\"Berlin\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"id": "4afceded-7178-4c05-8fa6-9f2085e6a344",
"metadata": {},
"outputs": [],
@@ -147,7 +203,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"id": "bdca8679-935f-4e7f-97e6-e71a4d4f228c",
"metadata": {},
"outputs": [],
@@ -173,7 +229,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"id": "ce9b0744-9c78-408d-b9df-9f6fd9ed78cf",
"metadata": {},
"outputs": [],
@@ -194,7 +250,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
"id": "b0992986-ea09-4912-a076-8e5603ee631f",
"metadata": {},
"outputs": [],
@@ -216,10 +272,213 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"id": "f4be8a71-b19e-4c2f-80df-f59ff2661f14",
"metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7862\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gr.ChatInterface(fn=chat, type=\"messages\").launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "11c9da69-d0cf-4cf2-a49e-e5669deec47b",
+ "metadata": {},
"outputs": [],
+ "source": [
+ "#Assignment -- adding tool to book a flight\n",
+ "\n",
+ "flight_booking_status = {\"london\": \"booked\", \"paris\": \"cannot be booked\", \"tokyo\": \"pending\", \"berlin\": \"payment issue\" }\n",
+ "\n",
+ "#this is a tool function that will be used by a GPT model\n",
+ "\n",
+ "def get_flight_booking_status(destination_city):\n",
+ " print(f\"Tool get_flight_booking_status for {destination_city}\")\n",
+ " city = destination_city.lower()\n",
+ " return flight_booking_status.get(city, \"Unknown\")\n",
+ " \n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "073207fb-a161-49e0-9ca8-60ccf588aa72",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tool get_flight_booking_status for berlin\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'payment issue'"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "#test the tool \n",
+ "get_flight_booking_status(\"berlin\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "76bc29f6-1bfc-4c4b-afb0-1b678b5b02d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Describe metadata about the above tool function so that model can understand it\n",
+ "\n",
+ "flight_booking_status_function = {\n",
+ " \"name\": \"get_flight_booking_status\",\n",
+ " \"description\": \"Get the status to flight booking. Show your answer in polite way, like 'flight booked or cannotn be booked etc...'\",\n",
+ " \"parameters\": {\n",
+ " \"type\": \"object\",\n",
+ " \"properties\": {\n",
+ " \"destination_city\": {\n",
+ " \"type\": \"string\",\n",
+ " \"description\": \"The city that the customer wants to book a flight for\",\n",
+ " },\n",
+ " },\n",
+ " \"required\": [\"destination_city\"],\n",
+ " \"additionalProperties\": False\n",
+ " }\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "221e2406-f19c-4ca6-a384-9876904be1b5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#add the above booking status function to the tools list\n",
+ "tools = [{\"type\": \"function\", \"function\": flight_booking_status_function}]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "c067cd26-562d-487c-9d32-e2e8cdd17796",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#now let openAI use our tool to find the flight booking status\n",
+ "\n",
+ "def chat(message, history):\n",
+ " messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
+ " response = openai.chat.completions.create(model=MODEL, messages=messages, tools=tools)\n",
+ "\n",
+ " if response.choices[0].finish_reason==\"tool_calls\":\n",
+ " message = response.choices[0].message\n",
+ " response, city = handle_flight_status_tool_call(message)\n",
+ " messages.append(message)\n",
+ " messages.append(response)\n",
+ " response = openai.chat.completions.create(model=MODEL, messages=messages)\n",
+ " \n",
+ " return response.choices[0].message.content\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "5bb6313b-d84c-4a77-b2f1-efd17554a553",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#function to handle call to our tool\n",
+ "\n",
+ "# We have to write that function handle_tool_call:\n",
+ "\n",
+ "def handle_flight_status_tool_call(message):\n",
+ " tool_call = message.tool_calls[0]\n",
+ " arguments = json.loads(tool_call.function.arguments)\n",
+ " city = arguments.get('destination_city')\n",
+ " status = get_flight_booking_status(city)\n",
+ " response = {\n",
+ " \"role\": \"tool\",\n",
+ " \"content\": json.dumps({\"destination_city\": city,\"status\": status}),\n",
+ " \"tool_call_id\": message.tool_calls[0].id\n",
+ " }\n",
+ " return response, city"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "b184d72c-da31-43ad-b7bf-27cf4525b35b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7863\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": []
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"gr.ChatInterface(fn=chat, type=\"messages\").launch()"
]
@@ -227,7 +486,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "11c9da69-d0cf-4cf2-a49e-e5669deec47b",
+ "id": "648a35d3-0365-48ce-958e-7531a98f9b8a",
"metadata": {},
"outputs": [],
"source": []