From 203df5ed4c9d10ea0cc7f0f9b1b4c44cf70ee502 Mon Sep 17 00:00:00 2001 From: Sandeep Gangaram Date: Wed, 11 Dec 2024 23:13:48 +0530 Subject: [PATCH] week 4 solution with audio input --- .../week2_solution_with_audio.ipynb | 461 ++++++++++++++++++ 1 file changed, 461 insertions(+) create mode 100644 week2/community-contributions/week2_solution_with_audio.ipynb diff --git a/week2/community-contributions/week2_solution_with_audio.ipynb b/week2/community-contributions/week2_solution_with_audio.ipynb new file mode 100644 index 0000000..97a8e2c --- /dev/null +++ b/week2/community-contributions/week2_solution_with_audio.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c1070317-3ed9-4659-abe3-828943230e03", + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "\n", + "import os\n", + "import json\n", + "from dotenv import load_dotenv\n", + "from IPython.display import Markdown, display, update_display\n", + "from openai import OpenAI\n", + "import gradio as gr\n", + "import google.generativeai\n", + "import anthropic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a456906-915a-4bfd-bb9d-57e505c5093f", + "metadata": {}, + "outputs": [], + "source": [ + "# constants\n", + "\n", + "MODEL_GPT = 'gpt-4o-mini'\n", + "MODEL_CLAUDE = 'claude-3-5-sonnet-20240620'\n", + "MODEL_GEMINI = 'gemini-1.5-flash'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8d7923c-5f28-4c30-8556-342d7c8497c1", + "metadata": {}, + "outputs": [], + "source": [ + "# set up environment\n", + "\n", + "load_dotenv()\n", + "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n", + "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", + "os.environ['GOOGLE_API_KEY'] = os.getenv('GOOGLE_API_KEY', 'your-key-if-not-using-env')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6fd8538-0be6-4539-8add-00e42133a641", + "metadata": {}, + "outputs": [], + "source": [ + "# Connect to OpenAI, Anthropic and Google\n", + "\n", + "openai = OpenAI()\n", + "\n", + "claude = anthropic.Anthropic()\n", + "\n", + "google.generativeai.configure()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "852faee9-79aa-4741-a676-4f5145ccccdc", + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "import subprocess\n", + "from io import BytesIO\n", + "from pydub import AudioSegment\n", + "import time\n", + "\n", + "def play_audio(audio_segment):\n", + " temp_dir = tempfile.gettempdir()\n", + " temp_path = os.path.join(temp_dir, \"temp_audio.wav\")\n", + " try:\n", + " audio_segment.export(temp_path, format=\"wav\")\n", + " subprocess.call([\n", + " \"ffplay\",\n", + " \"-nodisp\",\n", + " \"-autoexit\",\n", + " \"-hide_banner\",\n", + " temp_path\n", + " ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)\n", + " finally:\n", + " try:\n", + " os.remove(temp_path)\n", + " except Exception:\n", + " pass\n", + " \n", + "def talker(message):\n", + " response = openai.audio.speech.create(\n", + " model=\"tts-1\",\n", + " voice=\"onyx\", # Also, try replacing onyx with alloy\n", + " input=message\n", + " )\n", + " audio_stream = BytesIO(response.content)\n", + " audio = AudioSegment.from_file(audio_stream, format=\"mp3\")\n", + " play_audio(audio)\n", + "\n", + "talker(\"Well hi there\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8595807b-8ae2-4e1b-95d9-e8532142e8bb", + "metadata": {}, + "outputs": [], + "source": [ + "# prompts\n", + "general_prompt = \"Please be as technical as possible with your answers.\\\n", + "Only answer questions about topics you have expertise in.\\\n", + "If you do not know something say so.\"\n", + "\n", + "additional_prompt_gpt = \"Analyze the user query and determine if the content is primarily related to \\\n", + "coding, software engineering, data science and LLMs. \\\n", + "If so please answer it yourself else if it is primarily related to \\\n", + "physics, chemistry or biology get answers from tool ask_gemini or \\\n", + "if it belongs to subject related to finance, business or economics get answers from tool ask_claude.\"\n", + "\n", + "system_prompt_gpt = \"You are a helpful technical tutor who is an expert in \\\n", + "coding, software engineering, data science and LLMs.\"+ additional_prompt_gpt + general_prompt\n", + "system_prompt_gemini = \"You are a helpful technical tutor who is an expert in physics, chemistry and biology.\" + general_prompt\n", + "system_prompt_claude = \"You are a helpful technical tutor who is an expert in finance, business and economics.\" + general_prompt\n", + "\n", + "def get_user_prompt(question):\n", + " return \"Please give a detailed explanation to the following question: \" + question" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24d4a313-60b0-4696-b455-6cfef95ad2fe", + "metadata": {}, + "outputs": [], + "source": [ + "def call_claude(question):\n", + " result = claude.messages.create(\n", + " model=MODEL_CLAUDE,\n", + " max_tokens=200,\n", + " temperature=0.7,\n", + " system=system_prompt_claude,\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": get_user_prompt(question)},\n", + " ],\n", + " )\n", + " \n", + " return result.content[0].text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd5d5345-54ab-470b-9b5b-5611a7981458", + "metadata": {}, + "outputs": [], + "source": [ + "def call_gemini(question):\n", + " gemini = google.generativeai.GenerativeModel(\n", + " model_name=MODEL_GEMINI,\n", + " system_instruction=system_prompt_gemini\n", + " )\n", + " response = gemini.generate_content(get_user_prompt(question))\n", + " response = response.text\n", + " return response" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f74da8f-56d1-405e-bc81-040f5428d296", + "metadata": {}, + "outputs": [], + "source": [ + "# tools and functions\n", + "\n", + "def ask_claude(question):\n", + " print(f\"Tool ask_claude called for {question}\")\n", + " return call_claude(question)\n", + "def ask_gemini(question):\n", + " print(f\"Tool ask_gemini called for {question}\")\n", + " return call_gemini(question)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c469304d-99b4-42ee-ab02-c9216b61594b", + "metadata": {}, + "outputs": [], + "source": [ + "ask_claude_function = {\n", + " \"name\": \"ask_claude\",\n", + " \"description\": \"Get the answer to the question related to a topic this agent is faimiliar with. Call this whenever you need to answer something related to finance, marketing, sales or business in general.For example 'What is gross margin' or 'Explain stock market'\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"question_for_topic\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The question which is related to finance, business or economics.\",\n", + " },\n", + " },\n", + " \"required\": [\"question_for_topic\"],\n", + " \"additionalProperties\": False\n", + " }\n", + "}\n", + "\n", + "ask_gemini_function = {\n", + " \"name\": \"ask_gemini\",\n", + " \"description\": \"Get the answer to the question related to a topic this agent is faimiliar with. Call this whenever you need to answer something related to physics, chemistry or biology.Few examples: 'What is gravity','How do rockets work?', 'What is ATP'\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"question_for_topic\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The question which is related to physics, chemistry or biology\",\n", + " },\n", + " },\n", + " \"required\": [\"question_for_topic\"],\n", + " \"additionalProperties\": False\n", + " }\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73a60096-c49b-401f-bfd3-d1d40f4563d2", + "metadata": {}, + "outputs": [], + "source": [ + "tools = [{\"type\": \"function\", \"function\": ask_claude_function},\n", + " {\"type\": \"function\", \"function\": ask_gemini_function}]\n", + "tools_functions_map = {\n", + " \"ask_claude\":ask_claude,\n", + " \"ask_gemini\":ask_gemini\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d54e758-42b2-42f2-a8eb-49c35d44acc6", + "metadata": {}, + "outputs": [], + "source": [ + "def chat(history):\n", + " messages = [{\"role\": \"system\", \"content\": system_prompt_gpt}] + history\n", + " stream = openai.chat.completions.create(model=MODEL_GPT, messages=messages, tools=tools, stream=True)\n", + " \n", + " full_response = \"\"\n", + " history += [{\"role\":\"assistant\", \"content\":full_response}]\n", + " \n", + " tool_call_accumulator = \"\" # Accumulator for JSON fragments of tool call arguments\n", + " tool_call_id = None # Current tool call ID\n", + " tool_call_function_name = None # Function name\n", + " tool_calls = [] # List to store complete tool calls\n", + "\n", + " for chunk in stream:\n", + " if chunk.choices[0].delta.content:\n", + " full_response += chunk.choices[0].delta.content or \"\"\n", + " history[-1]['content']=full_response\n", + " yield history\n", + " \n", + " if chunk.choices[0].delta.tool_calls:\n", + " message = chunk.choices[0].delta\n", + " for tc in chunk.choices[0].delta.tool_calls:\n", + " if tc.id: # New tool call detected here\n", + " tool_call_id = tc.id\n", + " if tool_call_function_name is None:\n", + " tool_call_function_name = tc.function.name\n", + " \n", + " tool_call_accumulator += tc.function.arguments if tc.function.arguments else \"\"\n", + " \n", + " # When the accumulated JSON string seems complete then:\n", + " try:\n", + " func_args = json.loads(tool_call_accumulator)\n", + " \n", + " # Handle tool call and get response\n", + " tool_response, tool_call = handle_tool_call(tool_call_function_name, func_args, tool_call_id)\n", + " \n", + " tool_calls.append(tool_call)\n", + "\n", + " # Add tool call and tool response to messages this is required by openAI api\n", + " messages.append({\n", + " \"role\": \"assistant\",\n", + " \"tool_calls\": tool_calls\n", + " })\n", + " messages.append(tool_response)\n", + " \n", + " # Create new response with full context\n", + " response = openai.chat.completions.create(\n", + " model=MODEL_GPT, \n", + " messages=messages, \n", + " stream=True\n", + " )\n", + " \n", + " # Reset and accumulate new full response\n", + " full_response = \"\"\n", + " for chunk in response:\n", + " if chunk.choices[0].delta.content:\n", + " full_response += chunk.choices[0].delta.content or \"\"\n", + " history[-1]['content'] = full_response\n", + " yield history\n", + " \n", + " # Reset tool call accumulator and related variables\n", + " tool_call_accumulator = \"\"\n", + " tool_call_id = None\n", + " tool_call_function_name = None\n", + " tool_calls = []\n", + "\n", + " except json.JSONDecodeError:\n", + " # Incomplete JSON; continue accumulating\n", + " pass\n", + "\n", + " # trigger text-to-audio once full response available\n", + " talker(full_response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "515d3774-cc2c-44cd-af9b-768a63ed90dc", + "metadata": {}, + "outputs": [], + "source": [ + "# We have to write that function handle_tool_call:\n", + "def handle_tool_call(function_name, arguments, tool_call_id):\n", + " question = arguments.get('question_for_topic')\n", + " \n", + " # Prepare tool call information\n", + " tool_call = {\n", + " \"id\": tool_call_id,\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": function_name,\n", + " \"arguments\": json.dumps(arguments)\n", + " }\n", + " }\n", + " \n", + " if function_name in tools_functions_map:\n", + " answer = tools_functions_map[function_name](question)\n", + " response = {\n", + " \"role\": \"tool\",\n", + " \"content\": json.dumps({\"question\": question, \"answer\" : answer}),\n", + " \"tool_call_id\": tool_call_id\n", + " }\n", + "\n", + " return response, tool_call" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d7cc622-8635-4693-afa3-b5bcc2f9a63d", + "metadata": {}, + "outputs": [], + "source": [ + "def transcribe_audio(audio_file_path):\n", + " try:\n", + " audio_file = open(audio_file_path, \"rb\")\n", + " response = openai.audio.transcriptions.create(model=\"whisper-1\", file=audio_file) \n", + " return response.text\n", + " except Exception as e:\n", + " return f\"An error occurred: {e}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ded9b3f-83e1-4971-9714-4894f2982b5a", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "with gr.Blocks() as ui:\n", + " with gr.Row():\n", + " chatbot = gr.Chatbot(height=500, type=\"messages\")\n", + " # image_output = gr.Image(height=500)\n", + " with gr.Row():\n", + " entry = gr.Textbox(label=\"Ask our technical expert anything:\")\n", + " audio_input = gr.Audio(\n", + " sources=\"microphone\", \n", + " type=\"filepath\",\n", + " label=\"Record audio\",\n", + " editable=False,\n", + " waveform_options=gr.WaveformOptions(\n", + " show_recording_waveform=False,\n", + " ),\n", + " )\n", + "\n", + " # Add event listener for audio stop recording and show text on input area\n", + " audio_input.stop_recording(\n", + " fn=transcribe_audio, \n", + " inputs=audio_input, \n", + " outputs=entry\n", + " )\n", + " \n", + " with gr.Row():\n", + " clear = gr.Button(\"Clear\")\n", + "\n", + " def do_entry(message, history):\n", + " history += [{\"role\":\"user\", \"content\":message}]\n", + " yield \"\", history\n", + " \n", + " entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry,chatbot]).then(\n", + " chat, inputs=chatbot, outputs=chatbot)\n", + " \n", + " clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n", + "\n", + "ui.launch(inbrowser=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da663d73-dd2a-4fff-84df-2209cf2b330b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "532cb948-7733-4323-b85f-febfe2631e66", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}