Browse Source

Merge pull request #35 from sandeepgangaram/voice-to-text

week 2 solution with audio input
pull/34/merge
Ed Donner 5 months ago committed by GitHub
parent
commit
f7f150eaf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 475
      week2/community-contributions/week2_multimodal_chatbot_with_audio.ipynb

475
week2/community-contributions/week2_multimodal_chatbot_with_audio.ipynb

@ -0,0 +1,475 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ad900e1c-b4a9-4f05-93d5-e364fae208dd",
"metadata": {},
"source": [
"# Multimodal Expert Tutor\n",
"\n",
"An AI assistant which leverages expertise from other sources for you.\n",
"\n",
"Features:\n",
"- Multimodal\n",
"- Uses tools\n",
"- Streams responses\n",
"- Reads out the responses after streaming\n",
"- Coverts voice to text during input\n",
"\n",
"Scope for Improvement\n",
"- Read response faster (as streaming starts)\n",
"- code optimization\n",
"- UI enhancements\n",
"- Make it more real time"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c1070317-3ed9-4659-abe3-828943230e03",
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import os\n",
"import json\n",
"from dotenv import load_dotenv\n",
"from IPython.display import Markdown, display, update_display\n",
"from openai import OpenAI\n",
"import gradio as gr\n",
"import google.generativeai\n",
"import anthropic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a456906-915a-4bfd-bb9d-57e505c5093f",
"metadata": {},
"outputs": [],
"source": [
"# constants\n",
"\n",
"MODEL_GPT = 'gpt-4o-mini'\n",
"MODEL_CLAUDE = 'claude-3-5-sonnet-20240620'\n",
"MODEL_GEMINI = 'gemini-1.5-flash'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d7923c-5f28-4c30-8556-342d7c8497c1",
"metadata": {},
"outputs": [],
"source": [
"# set up environment\n",
"\n",
"load_dotenv()\n",
"os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
"os.environ['GOOGLE_API_KEY'] = os.getenv('GOOGLE_API_KEY', 'your-key-if-not-using-env')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6fd8538-0be6-4539-8add-00e42133a641",
"metadata": {},
"outputs": [],
"source": [
"# Connect to OpenAI, Anthropic and Google\n",
"\n",
"openai = OpenAI()\n",
"\n",
"claude = anthropic.Anthropic()\n",
"\n",
"google.generativeai.configure()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "852faee9-79aa-4741-a676-4f5145ccccdc",
"metadata": {},
"outputs": [],
"source": [
"import tempfile\n",
"import subprocess\n",
"from io import BytesIO\n",
"from pydub import AudioSegment\n",
"import time\n",
"\n",
"def play_audio(audio_segment):\n",
" temp_dir = tempfile.gettempdir()\n",
" temp_path = os.path.join(temp_dir, \"temp_audio.wav\")\n",
" try:\n",
" audio_segment.export(temp_path, format=\"wav\")\n",
" subprocess.call([\n",
" \"ffplay\",\n",
" \"-nodisp\",\n",
" \"-autoexit\",\n",
" \"-hide_banner\",\n",
" temp_path\n",
" ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)\n",
" finally:\n",
" try:\n",
" os.remove(temp_path)\n",
" except Exception:\n",
" pass\n",
" \n",
"def talker(message):\n",
" response = openai.audio.speech.create(\n",
" model=\"tts-1\",\n",
" voice=\"onyx\", # Also, try replacing onyx with alloy\n",
" input=message\n",
" )\n",
" audio_stream = BytesIO(response.content)\n",
" audio = AudioSegment.from_file(audio_stream, format=\"mp3\")\n",
" play_audio(audio)\n",
"\n",
"talker(\"Well hi there\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8595807b-8ae2-4e1b-95d9-e8532142e8bb",
"metadata": {},
"outputs": [],
"source": [
"# prompts\n",
"general_prompt = \"Please be as technical as possible with your answers.\\\n",
"Only answer questions about topics you have expertise in.\\\n",
"If you do not know something say so.\"\n",
"\n",
"additional_prompt_gpt = \"Analyze the user query and determine if the content is primarily related to \\\n",
"coding, software engineering, data science and LLMs. \\\n",
"If so please answer it yourself else if it is primarily related to \\\n",
"physics, chemistry or biology get answers from tool ask_gemini or \\\n",
"if it belongs to subject related to finance, business or economics get answers from tool ask_claude.\"\n",
"\n",
"system_prompt_gpt = \"You are a helpful technical tutor who is an expert in \\\n",
"coding, software engineering, data science and LLMs.\"+ additional_prompt_gpt + general_prompt\n",
"system_prompt_gemini = \"You are a helpful technical tutor who is an expert in physics, chemistry and biology.\" + general_prompt\n",
"system_prompt_claude = \"You are a helpful technical tutor who is an expert in finance, business and economics.\" + general_prompt\n",
"\n",
"def get_user_prompt(question):\n",
" return \"Please give a detailed explanation to the following question: \" + question"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24d4a313-60b0-4696-b455-6cfef95ad2fe",
"metadata": {},
"outputs": [],
"source": [
"def call_claude(question):\n",
" result = claude.messages.create(\n",
" model=MODEL_CLAUDE,\n",
" max_tokens=200,\n",
" temperature=0.7,\n",
" system=system_prompt_claude,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": get_user_prompt(question)},\n",
" ],\n",
" )\n",
" \n",
" return result.content[0].text"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd5d5345-54ab-470b-9b5b-5611a7981458",
"metadata": {},
"outputs": [],
"source": [
"def call_gemini(question):\n",
" gemini = google.generativeai.GenerativeModel(\n",
" model_name=MODEL_GEMINI,\n",
" system_instruction=system_prompt_gemini\n",
" )\n",
" response = gemini.generate_content(get_user_prompt(question))\n",
" response = response.text\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f74da8f-56d1-405e-bc81-040f5428d296",
"metadata": {},
"outputs": [],
"source": [
"# tools and functions\n",
"\n",
"def ask_claude(question):\n",
" print(f\"Tool ask_claude called for {question}\")\n",
" return call_claude(question)\n",
"def ask_gemini(question):\n",
" print(f\"Tool ask_gemini called for {question}\")\n",
" return call_gemini(question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c469304d-99b4-42ee-ab02-c9216b61594b",
"metadata": {},
"outputs": [],
"source": [
"ask_claude_function = {\n",
" \"name\": \"ask_claude\",\n",
" \"description\": \"Get the answer to the question related to a topic this agent is faimiliar with. Call this whenever you need to answer something related to finance, marketing, sales or business in general.For example 'What is gross margin' or 'Explain stock market'\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"question_for_topic\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The question which is related to finance, business or economics.\",\n",
" },\n",
" },\n",
" \"required\": [\"question_for_topic\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}\n",
"\n",
"ask_gemini_function = {\n",
" \"name\": \"ask_gemini\",\n",
" \"description\": \"Get the answer to the question related to a topic this agent is faimiliar with. Call this whenever you need to answer something related to physics, chemistry or biology.Few examples: 'What is gravity','How do rockets work?', 'What is ATP'\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"question_for_topic\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The question which is related to physics, chemistry or biology\",\n",
" },\n",
" },\n",
" \"required\": [\"question_for_topic\"],\n",
" \"additionalProperties\": False\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "73a60096-c49b-401f-bfd3-d1d40f4563d2",
"metadata": {},
"outputs": [],
"source": [
"tools = [{\"type\": \"function\", \"function\": ask_claude_function},\n",
" {\"type\": \"function\", \"function\": ask_gemini_function}]\n",
"tools_functions_map = {\n",
" \"ask_claude\":ask_claude,\n",
" \"ask_gemini\":ask_gemini\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d54e758-42b2-42f2-a8eb-49c35d44acc6",
"metadata": {},
"outputs": [],
"source": [
"def chat(history):\n",
" messages = [{\"role\": \"system\", \"content\": system_prompt_gpt}] + history\n",
" stream = openai.chat.completions.create(model=MODEL_GPT, messages=messages, tools=tools, stream=True)\n",
" \n",
" full_response = \"\"\n",
" history += [{\"role\":\"assistant\", \"content\":full_response}]\n",
" \n",
" tool_call_accumulator = \"\" # Accumulator for JSON fragments of tool call arguments\n",
" tool_call_id = None # Current tool call ID\n",
" tool_call_function_name = None # Function name\n",
" tool_calls = [] # List to store complete tool calls\n",
"\n",
" for chunk in stream:\n",
" if chunk.choices[0].delta.content:\n",
" full_response += chunk.choices[0].delta.content or \"\"\n",
" history[-1]['content']=full_response\n",
" yield history\n",
" \n",
" if chunk.choices[0].delta.tool_calls:\n",
" message = chunk.choices[0].delta\n",
" for tc in chunk.choices[0].delta.tool_calls:\n",
" if tc.id: # New tool call detected here\n",
" tool_call_id = tc.id\n",
" if tool_call_function_name is None:\n",
" tool_call_function_name = tc.function.name\n",
" \n",
" tool_call_accumulator += tc.function.arguments if tc.function.arguments else \"\"\n",
" \n",
" # When the accumulated JSON string seems complete then:\n",
" try:\n",
" func_args = json.loads(tool_call_accumulator)\n",
" \n",
" # Handle tool call and get response\n",
" tool_response, tool_call = handle_tool_call(tool_call_function_name, func_args, tool_call_id)\n",
" \n",
" tool_calls.append(tool_call)\n",
"\n",
" # Add tool call and tool response to messages this is required by openAI api\n",
" messages.append({\n",
" \"role\": \"assistant\",\n",
" \"tool_calls\": tool_calls\n",
" })\n",
" messages.append(tool_response)\n",
" \n",
" # Create new response with full context\n",
" response = openai.chat.completions.create(\n",
" model=MODEL_GPT, \n",
" messages=messages, \n",
" stream=True\n",
" )\n",
" \n",
" # Reset and accumulate new full response\n",
" full_response = \"\"\n",
" for chunk in response:\n",
" if chunk.choices[0].delta.content:\n",
" full_response += chunk.choices[0].delta.content or \"\"\n",
" history[-1]['content'] = full_response\n",
" yield history\n",
" \n",
" # Reset tool call accumulator and related variables\n",
" tool_call_accumulator = \"\"\n",
" tool_call_id = None\n",
" tool_call_function_name = None\n",
" tool_calls = []\n",
"\n",
" except json.JSONDecodeError:\n",
" # Incomplete JSON; continue accumulating\n",
" pass\n",
"\n",
" # trigger text-to-audio once full response available\n",
" talker(full_response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "515d3774-cc2c-44cd-af9b-768a63ed90dc",
"metadata": {},
"outputs": [],
"source": [
"# We have to write that function handle_tool_call:\n",
"def handle_tool_call(function_name, arguments, tool_call_id):\n",
" question = arguments.get('question_for_topic')\n",
" \n",
" # Prepare tool call information\n",
" tool_call = {\n",
" \"id\": tool_call_id,\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": function_name,\n",
" \"arguments\": json.dumps(arguments)\n",
" }\n",
" }\n",
" \n",
" if function_name in tools_functions_map:\n",
" answer = tools_functions_map[function_name](question)\n",
" response = {\n",
" \"role\": \"tool\",\n",
" \"content\": json.dumps({\"question\": question, \"answer\" : answer}),\n",
" \"tool_call_id\": tool_call_id\n",
" }\n",
"\n",
" return response, tool_call"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d7cc622-8635-4693-afa3-b5bcc2f9a63d",
"metadata": {},
"outputs": [],
"source": [
"def transcribe_audio(audio_file_path):\n",
" try:\n",
" audio_file = open(audio_file_path, \"rb\")\n",
" response = openai.audio.transcriptions.create(model=\"whisper-1\", file=audio_file) \n",
" return response.text\n",
" except Exception as e:\n",
" return f\"An error occurred: {e}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ded9b3f-83e1-4971-9714-4894f2982b5a",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"with gr.Blocks() as ui:\n",
" with gr.Row():\n",
" chatbot = gr.Chatbot(height=500, type=\"messages\", label=\"Multimodal Technical Expert Chatbot\")\n",
" with gr.Row():\n",
" entry = gr.Textbox(label=\"Ask our technical expert anything:\")\n",
" audio_input = gr.Audio(\n",
" sources=\"microphone\", \n",
" type=\"filepath\",\n",
" label=\"Record audio\",\n",
" editable=False,\n",
" waveform_options=gr.WaveformOptions(\n",
" show_recording_waveform=False,\n",
" ),\n",
" )\n",
"\n",
" # Add event listener for audio stop recording and show text on input area\n",
" audio_input.stop_recording(\n",
" fn=transcribe_audio, \n",
" inputs=audio_input, \n",
" outputs=entry\n",
" )\n",
" \n",
" with gr.Row():\n",
" clear = gr.Button(\"Clear\")\n",
"\n",
" def do_entry(message, history):\n",
" history += [{\"role\":\"user\", \"content\":message}]\n",
" yield \"\", history\n",
" \n",
" entry.submit(do_entry, inputs=[entry, chatbot], outputs=[entry,chatbot]).then(\n",
" chat, inputs=chatbot, outputs=chatbot)\n",
" \n",
" clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)\n",
"\n",
"ui.launch(inbrowser=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "532cb948-7733-4323-b85f-febfe2631e66",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save