1 changed files with 363 additions and 0 deletions
@ -0,0 +1,363 @@ |
|||||||
|
{ |
||||||
|
"cells": [ |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "5df0164c-1980-4fd7-94e4-a71b485a41fd", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"# Week 2 Day 1 - Conversation between three AI's\n", |
||||||
|
"\n", |
||||||
|
"This notebook defines three classes (`ThreePartyChat`, `Participant` and `Model`) that implement a 3-party chat between different AI's. \n", |
||||||
|
"\n", |
||||||
|
"At the bottom there is an example conversation between a Claude model and two GPT models.\n", |
||||||
|
"\n", |
||||||
|
"The implementation works with models available via the `openai` and `anthropic` libraries." |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "8b466547-809a-4b81-bfd7-ce9a1ac4bb2b", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"import os\n", |
||||||
|
"import logging\n", |
||||||
|
"import re\n", |
||||||
|
"\n", |
||||||
|
"from dotenv import load_dotenv\n", |
||||||
|
"from openai import OpenAI\n", |
||||||
|
"import anthropic" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "acaff46f-e43e-4527-a404-a5b3ae830e51", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"logging.basicConfig(\n", |
||||||
|
" level=logging.WARNING,\n", |
||||||
|
" format=\"%(levelname)s:%(name)s:%(funcName)s:%(message)s\"\n", |
||||||
|
")" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "aca57918-0271-4574-918b-2808f51698d1", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"# check if API keys are in .env\n", |
||||||
|
"load_dotenv(override=True)\n", |
||||||
|
"openai_api_key = os.getenv('OPENAI_API_KEY')\n", |
||||||
|
"anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n", |
||||||
|
"\n", |
||||||
|
"assert openai_api_key, \"OpenAI API key is missing\"\n", |
||||||
|
"assert anthropic_api_key, \"Anthropic API key is missing\"" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "25c37440-8692-4a8d-95e6-998691b4acf6", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"class Model:\n", |
||||||
|
" \"\"\"One class for different API's.\n", |
||||||
|
" \n", |
||||||
|
" This implementation allows the use of the OpenAI and Anthropic API. Other endpoints,\n", |
||||||
|
" such as Ollama, can be used as well, as long as they are used via the OpenAI\n", |
||||||
|
" Python library.\n", |
||||||
|
" \n", |
||||||
|
" \"\"\"\n", |
||||||
|
" def __init__(self, api=None, model_name=\"mock\"):\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" Args:\n", |
||||||
|
" api: Can be an OpenAI or anthropic.Anthropic object or None to make a mock run.\n", |
||||||
|
" model_name (str): Identifies the model used via the API.\n", |
||||||
|
"\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" self.api = api\n", |
||||||
|
" self.name = model_name\n", |
||||||
|
" if type(self.api) not in {OpenAI, anthropic.Anthropic} and self.name not in {\"mock\", \"\"}:\n", |
||||||
|
" logging.warning(f\"Unknown API '{self.api}'. Using mock.\")\n", |
||||||
|
"\n", |
||||||
|
" def complete(self, messages, system=\"\"):\n", |
||||||
|
" \"\"\"Make API call.\"\"\"\n", |
||||||
|
" completion = \"\"\n", |
||||||
|
" if isinstance(self.api, OpenAI):\n", |
||||||
|
" completion = self.api.chat.completions.create(\n", |
||||||
|
" model=self.name,\n", |
||||||
|
" messages=messages,\n", |
||||||
|
" max_tokens=300\n", |
||||||
|
" )\n", |
||||||
|
" completion = completion.choices[0].message.content\n", |
||||||
|
"\n", |
||||||
|
" elif isinstance(self.api, anthropic.Anthropic):\n", |
||||||
|
" completion = self.api.messages.create(\n", |
||||||
|
" model=self.name,\n", |
||||||
|
" system=system,\n", |
||||||
|
" messages=messages,\n", |
||||||
|
" max_tokens=300\n", |
||||||
|
" )\n", |
||||||
|
" completion = completion.content[0].text\n", |
||||||
|
" \n", |
||||||
|
" else:\n", |
||||||
|
" completion = \"Mock answer.\"\n", |
||||||
|
"\n", |
||||||
|
" return self.parse_answer(completion)\n", |
||||||
|
"\n", |
||||||
|
" def parse_answer(self, answer):\n", |
||||||
|
" # \n", |
||||||
|
" # Remove prefix 'Name:' from answer if present.\n", |
||||||
|
" regex = r\"(?P<name>\\w+): (?P<content>.*)\"\n", |
||||||
|
" match = re.match(regex, answer, re.DOTALL)\n", |
||||||
|
" if match:\n", |
||||||
|
" logging.info(f\"{self.name} generated {match.group('name')}\")\n", |
||||||
|
" return match.group(\"content\")\n", |
||||||
|
" return answer\n" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "462df0ba-36b5-4043-b0d0-a1d68edb968a", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"class Participant:\n", |
||||||
|
" \"\"\"Represents one participant in a conversation.\"\"\"\n", |
||||||
|
" def __init__(self, name, model=Model(), system_prompt=\"\", initial_message=\"\"):\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" Args:\n", |
||||||
|
" model (Model): The model that is called to get participant's answer.\n", |
||||||
|
" name (str): Used to assign answers to different participants. Is inserted in the\n", |
||||||
|
" messages list, so the model knows who's spoken. Is also\n", |
||||||
|
" displayed in the output.\n", |
||||||
|
" system_prompt (str): The system prompt overgiven to the model backend.\n", |
||||||
|
" initial_message (str): An optional conversation start.\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" self.model = model\n", |
||||||
|
" self.name = name\n", |
||||||
|
" self.role = system_prompt\n", |
||||||
|
" self.initial_msg = initial_message\n", |
||||||
|
" self.messages = [] # keeps conversation history\n", |
||||||
|
" if isinstance(self.model.api, OpenAI) and self.role:\n", |
||||||
|
" self.messages = [{\"role\": \"system\", \"content\": self.role}]\n", |
||||||
|
" self.last_msg = \"\"\n", |
||||||
|
"\n", |
||||||
|
" def speak(self):\n", |
||||||
|
" if self.initial_msg:\n", |
||||||
|
" self.last_msg = self.initial_msg\n", |
||||||
|
" self.initial_msg = \"\"\n", |
||||||
|
" else:\n", |
||||||
|
" self.last_msg = self.model.complete(self.messages, self.role)\n", |
||||||
|
" self.update_messages(role=\"assistant\", content=self.last_msg)\n", |
||||||
|
" return self.last_msg\n", |
||||||
|
"\n", |
||||||
|
" def listen(self, message: str, speaker_name: str):\n", |
||||||
|
" # Insert the speaker name, so the model can distinguish them\n", |
||||||
|
" self.update_messages(role=\"user\", content=f\"{speaker_name}: {message}\")\n", |
||||||
|
"\n", |
||||||
|
" def update_messages(self, role, content):\n", |
||||||
|
" self.messages.append({\"role\": role, \"content\": content})\n" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "e838901f-9a50-4f6b-b30f-e78c27e86bd7", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"class ThreePartyChat:\n", |
||||||
|
" \"\"\"Make three Participants communicate.\"\"\"\n", |
||||||
|
" def __init__(self, participants, n_turns=4):\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" Args:\n", |
||||||
|
" participants (tuple[Participant]): Three objects. The order determines the speaking order.\n", |
||||||
|
" n_turns (int): Number of turns per participant, incl. Participant.initial_message.\n", |
||||||
|
"\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" self.n_turns = n_turns\n", |
||||||
|
" self.p1, self.p2, self.p3 = participants\n", |
||||||
|
" if len({bool(self.p1.initial_msg), bool(self.p2.initial_msg), bool(self.p3.initial_msg)}) != 1:\n", |
||||||
|
" logging.warning(\"At least one Participant has gotten a value for initial_message while another hasn't.\")\n", |
||||||
|
" if len({self.p1.name, self.p2.name, self.p3.name}) != 3:\n", |
||||||
|
" raise ValueError(f\"Some Participants have the same name. \"\n", |
||||||
|
" f\"Please use unique names.\"\n", |
||||||
|
" f\"\\nNames you've given: {self.p1.name}, {self.p2.name} and {self.p3.name}. \")\n", |
||||||
|
"\n", |
||||||
|
" def start(self, n_turns=None):\n", |
||||||
|
" \"\"\"Start a conversation with n_turns rounds.\n", |
||||||
|
" \n", |
||||||
|
" Args:\n", |
||||||
|
" n_turns (int): If None, self.n_turns is used.\n", |
||||||
|
"\n", |
||||||
|
" \"\"\"\n", |
||||||
|
" for i in range(n_turns or self.n_turns):\n", |
||||||
|
" # Make each participant speak and display their answers\n", |
||||||
|
" self.make_display_turn(self.p1, self.p2, self.p3)\n", |
||||||
|
" self.make_display_turn(self.p2, self.p1, self.p3)\n", |
||||||
|
" self.make_display_turn(self.p3, self.p2, self.p1)\n", |
||||||
|
"\n", |
||||||
|
" def make_display_turn(self, speaker, *listeners):\n", |
||||||
|
" self.speaker_to_listeners(speaker, *listeners)\n", |
||||||
|
" self.display_last_utterance(speaker)\n", |
||||||
|
" \n", |
||||||
|
" def speaker_to_listeners(self, speaker, *listeners):\n", |
||||||
|
" \"\"\"Get answer from speaker and update conversation histories.\"\"\"\n", |
||||||
|
" speaker_text = speaker.speak()\n", |
||||||
|
" for listener in listeners:\n", |
||||||
|
" listener.listen(speaker_text, speaker.name)\n", |
||||||
|
"\n", |
||||||
|
" def display_last_utterance(self, speaker):\n", |
||||||
|
" print(\"{} ({}):\\n{}\\n\".format(\n", |
||||||
|
" speaker.name.upper(), speaker.model.name, speaker.last_msg\n", |
||||||
|
" ))\n", |
||||||
|
"\n" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "80294493-04ff-4bec-af88-c3fc11d21c54", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"#### Example system prompts:" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "997841b1-d547-472b-a298-a60be2f9b90f", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"name1 = \"Austin\"\n", |
||||||
|
"name2 = \"Jonas\"\n", |
||||||
|
"name3 = \"Tim\"\n", |
||||||
|
"\n", |
||||||
|
"general_system = (\n", |
||||||
|
" \"\\n\\nYou've entered a chatroom with two other participants. \"\n", |
||||||
|
" 'Their names are \"{}\" and \"{}\". Your name is \"{}\".'\n", |
||||||
|
" \"\\nGenerate a maximum of 100 words per turn.\"\n", |
||||||
|
")\n", |
||||||
|
"\n", |
||||||
|
"system1 = (\n", |
||||||
|
" \"You are very argumentative; \"\n", |
||||||
|
" \"You always find something to discuss. \"\n", |
||||||
|
" \"When someone says their opinion, you often disagree. \"\n", |
||||||
|
" \"You enjoy swimming against the tide and mocking mainstream opinions.\"\n", |
||||||
|
" + general_system.format(name3, name2, name1)\n", |
||||||
|
")\n", |
||||||
|
"\n", |
||||||
|
"system2 = (\n", |
||||||
|
" \"You have a very conservative and clear opinion on most things. \"\n", |
||||||
|
" \"You feel safest in your familiar surroundings. You are very reluctant to try out new things. \"\n", |
||||||
|
" \"In discourses you are stubborn and want to convince others from your gridlocked beliefs.\"\n", |
||||||
|
" + general_system.format(name1, name3, name2)\n", |
||||||
|
")\n", |
||||||
|
"\n", |
||||||
|
"system3 = (\n", |
||||||
|
" \"You are very humorous and like to be ironic. Sometimes you tell silly jokes. \"\n", |
||||||
|
" \"You like variation; If a discussion about a topic takes too long, you start a new topic.\"\n", |
||||||
|
" + general_system.format(name1, name2, name3)\n", |
||||||
|
")" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "0f455bb6-c6a8-4f75-a003-4bfda8dcff8a", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"#### Example with **Claude-3-Haiku** and *two instances* of **GPT-4o-mini**:" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "6953f270-6a59-4c73-aad9-0284580adccd", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"openai_api = OpenAI()\n", |
||||||
|
"claude_api = anthropic.Anthropic()\n", |
||||||
|
"# ollama could be used like this:\n", |
||||||
|
"# ollama_api = OpenAI(base_url=\"http://localhost:11434/v1\", api_key=\"ollama\")\n", |
||||||
|
"\n", |
||||||
|
"claude_model_str = \"claude-3-haiku-20240307\"\n", |
||||||
|
"gpt_model_str = \"gpt-4o-mini\"\n", |
||||||
|
"# llama_model_str = \"llama3.2\"" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "2fadb8db-41e6-4362-a2fe-3e0902ff7116", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"# Create Model objects\n", |
||||||
|
"gpt_model = Model(openai_api, gpt_model_str)\n", |
||||||
|
"claude_model = Model(claude_api, claude_model_str)\n", |
||||||
|
"\n", |
||||||
|
"# Create three Participants\n", |
||||||
|
"p1 = Participant(name=name1, model=gpt_model, system_prompt=system1, initial_message=\"Hello there\")\n", |
||||||
|
"p2 = Participant(name=name2, model=claude_model, system_prompt=system2, initial_message=\"Good evening.\")\n", |
||||||
|
"p3 = Participant(name=name3, model=gpt_model, system_prompt=system3, initial_message=\"Hey guys\")\n", |
||||||
|
"\n", |
||||||
|
"# To make a mock run without API calls:\n", |
||||||
|
"p1 = Participant(name=name1, system_prompt=system1, initial_message=\"Hello there\")\n", |
||||||
|
"p2 = Participant(name=name2, system_prompt=system2, initial_message=\"Good evening.\")\n", |
||||||
|
"p3 = Participant(name=name3, system_prompt=system3, initial_message=\"Hey guys\")\n", |
||||||
|
"\n", |
||||||
|
"# Create Chat\n", |
||||||
|
"chat = ThreePartyChat((p1, p2, p3))" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "markdown", |
||||||
|
"id": "7f0daa3e-b97e-48ad-aa24-bff728234241", |
||||||
|
"metadata": {}, |
||||||
|
"source": [ |
||||||
|
"#### Start the conversation:" |
||||||
|
] |
||||||
|
}, |
||||||
|
{ |
||||||
|
"cell_type": "code", |
||||||
|
"execution_count": null, |
||||||
|
"id": "4b377d50-52a1-4f3e-a7ed-bdc8a6abe710", |
||||||
|
"metadata": {}, |
||||||
|
"outputs": [], |
||||||
|
"source": [ |
||||||
|
"chat.start() # starts a chat with 4 rounds\n", |
||||||
|
"# chat.start(2) # 2 rounds" |
||||||
|
] |
||||||
|
} |
||||||
|
], |
||||||
|
"metadata": { |
||||||
|
"kernelspec": { |
||||||
|
"display_name": "Python 3 (ipykernel)", |
||||||
|
"language": "python", |
||||||
|
"name": "python3" |
||||||
|
}, |
||||||
|
"language_info": { |
||||||
|
"codemirror_mode": { |
||||||
|
"name": "ipython", |
||||||
|
"version": 3 |
||||||
|
}, |
||||||
|
"file_extension": ".py", |
||||||
|
"mimetype": "text/x-python", |
||||||
|
"name": "python", |
||||||
|
"nbconvert_exporter": "python", |
||||||
|
"pygments_lexer": "ipython3", |
||||||
|
"version": "3.11.11" |
||||||
|
} |
||||||
|
}, |
||||||
|
"nbformat": 4, |
||||||
|
"nbformat_minor": 5 |
||||||
|
} |
Loading…
Reference in new issue