{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5df0164c-1980-4fd7-94e4-a71b485a41fd",
   "metadata": {},
   "source": [
    "# Week 2 Day 1 - Conversation between three AI's\n",
    "\n",
    "This notebook defines three classes (`ThreeWayChat`, `Participant` and `Model`) that implement a 3-way conversation between different AI's.  \n",
    "\n",
    "At the bottom there is an example conversation between a Claude model and two GPT models.\n",
    "\n",
    "The implementation works with models available via the `openai` and `anthropic` libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b466547-809a-4b81-bfd7-ce9a1ac4bb2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "import re\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import anthropic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acaff46f-e43e-4527-a404-a5b3ae830e51",
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(\n",
    "    level=logging.WARNING,\n",
    "    format=\"%(levelname)s:%(name)s:%(funcName)s:%(message)s\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aca57918-0271-4574-918b-2808f51698d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if API keys are in .env\n",
    "load_dotenv(override=True)\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "\n",
    "assert openai_api_key, \"OpenAI API key is missing\"\n",
    "assert anthropic_api_key, \"Anthropic API key is missing\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25c37440-8692-4a8d-95e6-998691b4acf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"One class for different API's.\n",
    "    \n",
    "    This implementation allows the use of the OpenAI and Anthropic API. Other endpoints,\n",
    "    such as Ollama, can be used as well, as long as they are used via the OpenAI\n",
    "    Python library.\n",
    "    \n",
    "    \"\"\"\n",
    "    def __init__(self, api=None, model_name=\"mock\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            api: Can be an OpenAI or anthropic.Anthropic object or None to make a mock run.\n",
    "            model_name (str): Identifies the model used via the API.\n",
    "\n",
    "        \"\"\"\n",
    "        self.api = api\n",
    "        self.name = model_name\n",
    "        if type(self.api) not in {OpenAI, anthropic.Anthropic} and self.name not in {\"mock\", \"\"}:\n",
    "            logging.warning(f\"Unknown API '{self.api}'. Using mock.\")\n",
    "\n",
    "    def complete(self, messages, system=\"\"):\n",
    "        \"\"\"Make API call.\"\"\"\n",
    "        completion = \"\"\n",
    "        if isinstance(self.api, OpenAI):\n",
    "            completion = self.api.chat.completions.create(\n",
    "                model=self.name,\n",
    "                messages=[{\"role\": \"system\", \"content\": system}] + messages,\n",
    "                max_tokens=300\n",
    "            )\n",
    "            completion = completion.choices[0].message.content\n",
    "\n",
    "        elif isinstance(self.api, anthropic.Anthropic):\n",
    "            completion = self.api.messages.create(\n",
    "                model=self.name,\n",
    "                system=system,\n",
    "                messages=messages,\n",
    "                max_tokens=300\n",
    "            )\n",
    "            completion = completion.content[0].text\n",
    "        \n",
    "        else:\n",
    "            completion = \"Mock answer.\"\n",
    "\n",
    "        return self.parse_answer(completion)\n",
    "\n",
    "    def parse_answer(self, answer):\n",
    "        # Remove prefix 'Name:' from answer if present.\n",
    "        regex = r\"(?P<name>\\w+): (?P<content>.*)\"\n",
    "        match = re.match(regex, answer, re.DOTALL)\n",
    "        if match:\n",
    "            logging.info(f\"{self.name} generated {match.group('name')}\")\n",
    "            return match.group(\"content\")\n",
    "        return answer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "462df0ba-36b5-4043-b0d0-a1d68edb968a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Participant:\n",
    "    \"\"\"Represents one participant in a conversation.\"\"\"\n",
    "    def __init__(self, name, model=Model(), system_prompt=\"\", initial_message=\"\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            model (Model): The model that is called to get participant's answer.\n",
    "            name (str): Used to assign answers to different participants. Is inserted in the\n",
    "                messages list, so the model knows who's spoken. Is also\n",
    "                displayed in the output.\n",
    "            system_prompt (str): The system prompt overgiven to the model backend.\n",
    "            initial_message (str): An optional conversation start.\n",
    "            \"\"\"\n",
    "        self.model = model\n",
    "        self.name = name\n",
    "        self.role = system_prompt\n",
    "        self.initial_msg = initial_message\n",
    "        self.messages = []  # keeps conversation history\n",
    "        self.last_msg = \"\"\n",
    "\n",
    "    def speak(self):\n",
    "        if self.initial_msg:\n",
    "            self.last_msg = self.initial_msg\n",
    "            self.initial_msg = \"\"\n",
    "        else:\n",
    "            self.last_msg = self.model.complete(self.messages, self.role)\n",
    "        self.update_messages(role=\"assistant\", content=self.last_msg)\n",
    "        return self.last_msg\n",
    "\n",
    "    def listen(self, message: str, speaker_name: str):\n",
    "        # Insert the speaker name, so the model can distinguish them\n",
    "        self.update_messages(role=\"user\", content=f\"{speaker_name}: {message}\")\n",
    "\n",
    "    def update_messages(self, role, content):\n",
    "        self.messages.append({\"role\": role, \"content\": content})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e838901f-9a50-4f6b-b30f-e78c27e86bd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ThreeWayChat:\n",
    "    \"\"\"Make three Participants communicate.\"\"\"\n",
    "    def __init__(self, participants, n_turns=4):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            participants (tuple[Participant]): Three objects. The order determines the speaking order.\n",
    "            n_turns (int): Number of turns per participant, incl. Participant.initial_message.\n",
    "\n",
    "        \"\"\"\n",
    "        self.n_turns = n_turns\n",
    "        self.p1, self.p2, self.p3 = participants\n",
    "        if len({bool(self.p1.initial_msg), bool(self.p2.initial_msg), bool(self.p3.initial_msg)}) != 1:\n",
    "            logging.warning(\"At least one Participant has gotten a value for initial_message while another hasn't.\")\n",
    "        if len({self.p1.name, self.p2.name, self.p3.name}) != 3:\n",
    "            raise ValueError(f\"Some Participants have the same name. \"\n",
    "                             f\"Please use unique names.\"\n",
    "                             f\"\\nNames you've given: {self.p1.name}, {self.p2.name} and {self.p3.name}. \")\n",
    "\n",
    "    def start(self, n_turns=None):\n",
    "        \"\"\"Start a conversation with n_turns rounds.\n",
    "        \n",
    "        Args:\n",
    "            n_turns (int): If None, self.n_turns is used.\n",
    "\n",
    "        \"\"\"\n",
    "        for i in range(n_turns or self.n_turns):\n",
    "            # Make each participant speak and display their answers\n",
    "            self.make_display_turn(self.p1, self.p2, self.p3)\n",
    "            self.make_display_turn(self.p2, self.p1, self.p3)\n",
    "            self.make_display_turn(self.p3, self.p2, self.p1)\n",
    "\n",
    "    def make_display_turn(self, speaker, *listeners):\n",
    "        self.speaker_to_listeners(speaker, *listeners)\n",
    "        self.display_last_utterance(speaker)\n",
    "    \n",
    "    def speaker_to_listeners(self, speaker, *listeners):\n",
    "        \"\"\"Get answer from speaker and update conversation histories.\"\"\"\n",
    "        speaker_text = speaker.speak()\n",
    "        for listener in listeners:\n",
    "            listener.listen(speaker_text, speaker.name)\n",
    "\n",
    "    def display_last_utterance(self, speaker):\n",
    "        print(\"{} ({}):\\n{}\\n\".format(\n",
    "                speaker.name.upper(), speaker.model.name, speaker.last_msg\n",
    "            ))\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80294493-04ff-4bec-af88-c3fc11d21c54",
   "metadata": {},
   "source": [
    "#### Example system prompts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "997841b1-d547-472b-a298-a60be2f9b90f",
   "metadata": {},
   "outputs": [],
   "source": [
    "name1 = \"Austin\"\n",
    "name2 = \"Jonas\"\n",
    "name3 = \"Tim\"\n",
    "\n",
    "general_system = (\n",
    "    \"\\n\\nYou've entered a chatroom with two other participants. \"\n",
    "    'Their names are \"{}\" and \"{}\". Your name is \"{}\".'\n",
    "    \"\\nGenerate a maximum of 100 words per turn.\"\n",
    ")\n",
    "\n",
    "system1 = (\n",
    "    \"You are very argumentative; \"\n",
    "    \"You always find something to discuss. \"\n",
    "    \"When someone says their opinion, you often disagree. \"\n",
    "    \"You enjoy swimming against the tide and mocking mainstream opinions.\"\n",
    "    + general_system.format(name3, name2, name1)\n",
    ")\n",
    "\n",
    "system2 = (\n",
    "    \"You have a very conservative and clear opinion on most things. \"\n",
    "    \"You feel safest in your familiar surroundings. You are very reluctant to try out new things. \"\n",
    "    \"In discourses you are stubborn and want to convince others from your gridlocked beliefs.\"\n",
    "    + general_system.format(name1, name3, name2)\n",
    ")\n",
    "\n",
    "system3 = (\n",
    "    \"You are very humorous and like to be ironic. Sometimes you tell silly jokes. \"\n",
    "    \"You like variation; If a discussion about a topic takes too long, you start a new topic.\"\n",
    "    + general_system.format(name1, name2, name3)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f455bb6-c6a8-4f75-a003-4bfda8dcff8a",
   "metadata": {},
   "source": [
    "#### Example with **Claude-3-Haiku** and *two instances* of **GPT-4o-mini**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6953f270-6a59-4c73-aad9-0284580adccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "openai_api = OpenAI()\n",
    "claude_api = anthropic.Anthropic()\n",
    "# ollama could be used like this:\n",
    "# ollama_api = OpenAI(base_url=\"http://localhost:11434/v1\", api_key=\"ollama\")\n",
    "\n",
    "claude_model_str = \"claude-3-haiku-20240307\"\n",
    "gpt_model_str = \"gpt-4o-mini\"\n",
    "# llama_model_str = \"llama3.2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fadb8db-41e6-4362-a2fe-3e0902ff7116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Model objects\n",
    "gpt_model = Model(openai_api, gpt_model_str)\n",
    "claude_model = Model(claude_api, claude_model_str)\n",
    "\n",
    "# Create three Participants\n",
    "p1 = Participant(name=name1, model=gpt_model, system_prompt=system1, initial_message=\"Hello there\")\n",
    "p2 = Participant(name=name2, model=claude_model, system_prompt=system2, initial_message=\"Good evening.\")\n",
    "p3 = Participant(name=name3, model=gpt_model, system_prompt=system3, initial_message=\"Hey guys\")\n",
    "\n",
    "# To make a mock run without API calls:\n",
    "# p1 = Participant(name=name1, system_prompt=system1, initial_message=\"Hello there\")\n",
    "# p2 = Participant(name=name2, system_prompt=system2, initial_message=\"Good evening.\")\n",
    "# p3 = Participant(name=name3, system_prompt=system3, initial_message=\"Hey guys\")\n",
    "\n",
    "# Create Chat\n",
    "chat = ThreeWayChat((p1, p2, p3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f0daa3e-b97e-48ad-aa24-bff728234241",
   "metadata": {},
   "source": [
    "#### Start the conversation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b377d50-52a1-4f3e-a7ed-bdc8a6abe710",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat.start() # starts a chat with 4 rounds\n",
    "# chat.start(2) # 2 rounds"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}