Browse Source

Complete exercise week 2, day 1: 3-way conversation between AI's

- Make claude and two gpt's have a conversation
pull/292/head
Sandra Neuhäußer 1 month ago
parent
commit
c68bde8ba5
  1. 363
      week2/community-contributions/day1_three_party_chat.ipynb

363
week2/community-contributions/day1_three_party_chat.ipynb

@ -0,0 +1,363 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5df0164c-1980-4fd7-94e4-a71b485a41fd",
"metadata": {},
"source": [
"# Week 2 Day 1 - Conversation between three AI's\n",
"\n",
"This notebook defines three classes (`ThreePartyChat`, `Participant` and `Model`) that implement a 3-party chat between different AI's. \n",
"\n",
"At the bottom there is an example conversation between a Claude model and two GPT models.\n",
"\n",
"The implementation works with models available via the `openai` and `anthropic` libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b466547-809a-4b81-bfd7-ce9a1ac4bb2b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"import re\n",
"\n",
"from dotenv import load_dotenv\n",
"from openai import OpenAI\n",
"import anthropic"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "acaff46f-e43e-4527-a404-a5b3ae830e51",
"metadata": {},
"outputs": [],
"source": [
"logging.basicConfig(\n",
" level=logging.WARNING,\n",
" format=\"%(levelname)s:%(name)s:%(funcName)s:%(message)s\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aca57918-0271-4574-918b-2808f51698d1",
"metadata": {},
"outputs": [],
"source": [
"# check if API keys are in .env\n",
"load_dotenv(override=True)\n",
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
"\n",
"assert openai_api_key, \"OpenAI API key is missing\"\n",
"assert anthropic_api_key, \"Anthropic API key is missing\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25c37440-8692-4a8d-95e6-998691b4acf6",
"metadata": {},
"outputs": [],
"source": [
"class Model:\n",
" \"\"\"One class for different API's.\n",
" \n",
" This implementation allows the use of the OpenAI and Anthropic API. Other endpoints,\n",
" such as Ollama, can be used as well, as long as they are used via the OpenAI\n",
" Python library.\n",
" \n",
" \"\"\"\n",
" def __init__(self, api=None, model_name=\"mock\"):\n",
" \"\"\"\n",
" Args:\n",
" api: Can be an OpenAI or anthropic.Anthropic object or None to make a mock run.\n",
" model_name (str): Identifies the model used via the API.\n",
"\n",
" \"\"\"\n",
" self.api = api\n",
" self.name = model_name\n",
" if type(self.api) not in {OpenAI, anthropic.Anthropic} and self.name not in {\"mock\", \"\"}:\n",
" logging.warning(f\"Unknown API '{self.api}'. Using mock.\")\n",
"\n",
" def complete(self, messages, system=\"\"):\n",
" \"\"\"Make API call.\"\"\"\n",
" completion = \"\"\n",
" if isinstance(self.api, OpenAI):\n",
" completion = self.api.chat.completions.create(\n",
" model=self.name,\n",
" messages=messages,\n",
" max_tokens=300\n",
" )\n",
" completion = completion.choices[0].message.content\n",
"\n",
" elif isinstance(self.api, anthropic.Anthropic):\n",
" completion = self.api.messages.create(\n",
" model=self.name,\n",
" system=system,\n",
" messages=messages,\n",
" max_tokens=300\n",
" )\n",
" completion = completion.content[0].text\n",
" \n",
" else:\n",
" completion = \"Mock answer.\"\n",
"\n",
" return self.parse_answer(completion)\n",
"\n",
" def parse_answer(self, answer):\n",
" # \n",
" # Remove prefix 'Name:' from answer if present.\n",
" regex = r\"(?P<name>\\w+): (?P<content>.*)\"\n",
" match = re.match(regex, answer, re.DOTALL)\n",
" if match:\n",
" logging.info(f\"{self.name} generated {match.group('name')}\")\n",
" return match.group(\"content\")\n",
" return answer\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "462df0ba-36b5-4043-b0d0-a1d68edb968a",
"metadata": {},
"outputs": [],
"source": [
"class Participant:\n",
" \"\"\"Represents one participant in a conversation.\"\"\"\n",
" def __init__(self, name, model=Model(), system_prompt=\"\", initial_message=\"\"):\n",
" \"\"\"\n",
" Args:\n",
" model (Model): The model that is called to get participant's answer.\n",
" name (str): Used to assign answers to different participants. Is inserted in the\n",
" messages list, so the model knows who's spoken. Is also\n",
" displayed in the output.\n",
" system_prompt (str): The system prompt overgiven to the model backend.\n",
" initial_message (str): An optional conversation start.\n",
" \"\"\"\n",
" self.model = model\n",
" self.name = name\n",
" self.role = system_prompt\n",
" self.initial_msg = initial_message\n",
" self.messages = [] # keeps conversation history\n",
" if isinstance(self.model.api, OpenAI) and self.role:\n",
" self.messages = [{\"role\": \"system\", \"content\": self.role}]\n",
" self.last_msg = \"\"\n",
"\n",
" def speak(self):\n",
" if self.initial_msg:\n",
" self.last_msg = self.initial_msg\n",
" self.initial_msg = \"\"\n",
" else:\n",
" self.last_msg = self.model.complete(self.messages, self.role)\n",
" self.update_messages(role=\"assistant\", content=self.last_msg)\n",
" return self.last_msg\n",
"\n",
" def listen(self, message: str, speaker_name: str):\n",
" # Insert the speaker name, so the model can distinguish them\n",
" self.update_messages(role=\"user\", content=f\"{speaker_name}: {message}\")\n",
"\n",
" def update_messages(self, role, content):\n",
" self.messages.append({\"role\": role, \"content\": content})\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e838901f-9a50-4f6b-b30f-e78c27e86bd7",
"metadata": {},
"outputs": [],
"source": [
"class ThreePartyChat:\n",
" \"\"\"Make three Participants communicate.\"\"\"\n",
" def __init__(self, participants, n_turns=4):\n",
" \"\"\"\n",
" Args:\n",
" participants (tuple[Participant]): Three objects. The order determines the speaking order.\n",
" n_turns (int): Number of turns per participant, incl. Participant.initial_message.\n",
"\n",
" \"\"\"\n",
" self.n_turns = n_turns\n",
" self.p1, self.p2, self.p3 = participants\n",
" if len({bool(self.p1.initial_msg), bool(self.p2.initial_msg), bool(self.p3.initial_msg)}) != 1:\n",
" logging.warning(\"At least one Participant has gotten a value for initial_message while another hasn't.\")\n",
" if len({self.p1.name, self.p2.name, self.p3.name}) != 3:\n",
" raise ValueError(f\"Some Participants have the same name. \"\n",
" f\"Please use unique names.\"\n",
" f\"\\nNames you've given: {self.p1.name}, {self.p2.name} and {self.p3.name}. \")\n",
"\n",
" def start(self, n_turns=None):\n",
" \"\"\"Start a conversation with n_turns rounds.\n",
" \n",
" Args:\n",
" n_turns (int): If None, self.n_turns is used.\n",
"\n",
" \"\"\"\n",
" for i in range(n_turns or self.n_turns):\n",
" # Make each participant speak and display their answers\n",
" self.make_display_turn(self.p1, self.p2, self.p3)\n",
" self.make_display_turn(self.p2, self.p1, self.p3)\n",
" self.make_display_turn(self.p3, self.p2, self.p1)\n",
"\n",
" def make_display_turn(self, speaker, *listeners):\n",
" self.speaker_to_listeners(speaker, *listeners)\n",
" self.display_last_utterance(speaker)\n",
" \n",
" def speaker_to_listeners(self, speaker, *listeners):\n",
" \"\"\"Get answer from speaker and update conversation histories.\"\"\"\n",
" speaker_text = speaker.speak()\n",
" for listener in listeners:\n",
" listener.listen(speaker_text, speaker.name)\n",
"\n",
" def display_last_utterance(self, speaker):\n",
" print(\"{} ({}):\\n{}\\n\".format(\n",
" speaker.name.upper(), speaker.model.name, speaker.last_msg\n",
" ))\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "80294493-04ff-4bec-af88-c3fc11d21c54",
"metadata": {},
"source": [
"#### Example system prompts:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "997841b1-d547-472b-a298-a60be2f9b90f",
"metadata": {},
"outputs": [],
"source": [
"name1 = \"Austin\"\n",
"name2 = \"Jonas\"\n",
"name3 = \"Tim\"\n",
"\n",
"general_system = (\n",
" \"\\n\\nYou've entered a chatroom with two other participants. \"\n",
" 'Their names are \"{}\" and \"{}\". Your name is \"{}\".'\n",
" \"\\nGenerate a maximum of 100 words per turn.\"\n",
")\n",
"\n",
"system1 = (\n",
" \"You are very argumentative; \"\n",
" \"You always find something to discuss. \"\n",
" \"When someone says their opinion, you often disagree. \"\n",
" \"You enjoy swimming against the tide and mocking mainstream opinions.\"\n",
" + general_system.format(name3, name2, name1)\n",
")\n",
"\n",
"system2 = (\n",
" \"You have a very conservative and clear opinion on most things. \"\n",
" \"You feel safest in your familiar surroundings. You are very reluctant to try out new things. \"\n",
" \"In discourses you are stubborn and want to convince others from your gridlocked beliefs.\"\n",
" + general_system.format(name1, name3, name2)\n",
")\n",
"\n",
"system3 = (\n",
" \"You are very humorous and like to be ironic. Sometimes you tell silly jokes. \"\n",
" \"You like variation; If a discussion about a topic takes too long, you start a new topic.\"\n",
" + general_system.format(name1, name2, name3)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0f455bb6-c6a8-4f75-a003-4bfda8dcff8a",
"metadata": {},
"source": [
"#### Example with **Claude-3-Haiku** and *two instances* of **GPT-4o-mini**:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6953f270-6a59-4c73-aad9-0284580adccd",
"metadata": {},
"outputs": [],
"source": [
"openai_api = OpenAI()\n",
"claude_api = anthropic.Anthropic()\n",
"# ollama could be used like this:\n",
"# ollama_api = OpenAI(base_url=\"http://localhost:11434/v1\", api_key=\"ollama\")\n",
"\n",
"claude_model_str = \"claude-3-haiku-20240307\"\n",
"gpt_model_str = \"gpt-4o-mini\"\n",
"# llama_model_str = \"llama3.2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2fadb8db-41e6-4362-a2fe-3e0902ff7116",
"metadata": {},
"outputs": [],
"source": [
"# Create Model objects\n",
"gpt_model = Model(openai_api, gpt_model_str)\n",
"claude_model = Model(claude_api, claude_model_str)\n",
"\n",
"# Create three Participants\n",
"p1 = Participant(name=name1, model=gpt_model, system_prompt=system1, initial_message=\"Hello there\")\n",
"p2 = Participant(name=name2, model=claude_model, system_prompt=system2, initial_message=\"Good evening.\")\n",
"p3 = Participant(name=name3, model=gpt_model, system_prompt=system3, initial_message=\"Hey guys\")\n",
"\n",
"# To make a mock run without API calls:\n",
"p1 = Participant(name=name1, system_prompt=system1, initial_message=\"Hello there\")\n",
"p2 = Participant(name=name2, system_prompt=system2, initial_message=\"Good evening.\")\n",
"p3 = Participant(name=name3, system_prompt=system3, initial_message=\"Hey guys\")\n",
"\n",
"# Create Chat\n",
"chat = ThreePartyChat((p1, p2, p3))"
]
},
{
"cell_type": "markdown",
"id": "7f0daa3e-b97e-48ad-aa24-bff728234241",
"metadata": {},
"source": [
"#### Start the conversation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b377d50-52a1-4f3e-a7ed-bdc8a6abe710",
"metadata": {},
"outputs": [],
"source": [
"chat.start() # starts a chat with 4 rounds\n",
"# chat.start(2) # 2 rounds"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading…
Cancel
Save